mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 21:29:11 +00:00
[Multi-User Part 3]: Separate chat sesssions based on authenticated users (#511)
- Add a data model which allows us to store Conversations with users. This does a minimal lift over the current setup, where the underlying data is stored in a JSON file. This maintains parity with that configuration. - There does _seem_ to be some regression in chat quality, which is most likely attributable to search results. This will help us with #275. It should become much easier to maintain multiple Conversations in a given table in the backend now. We will have to do some thinking on the UI.
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from typing import Type, TypeVar, List
|
||||
import uuid
|
||||
from datetime import date
|
||||
|
||||
from django.db import models
|
||||
@@ -21,6 +20,13 @@ from database.models import (
|
||||
GithubConfig,
|
||||
Embeddings,
|
||||
GithubRepoConfig,
|
||||
Conversation,
|
||||
ConversationProcessorConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
)
|
||||
from khoj.utils.rawconfig import (
|
||||
ConversationProcessorConfig as UserConversationProcessorConfig,
|
||||
)
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
@@ -54,18 +60,17 @@ async def get_or_create_user(token: dict) -> KhojUser:
|
||||
|
||||
|
||||
async def create_google_user(token: dict) -> KhojUser:
|
||||
user_info = token.get("userinfo")
|
||||
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
|
||||
user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
|
||||
await user.asave()
|
||||
await GoogleUser.objects.acreate(
|
||||
sub=user_info.get("sub"),
|
||||
azp=user_info.get("azp"),
|
||||
email=user_info.get("email"),
|
||||
name=user_info.get("name"),
|
||||
given_name=user_info.get("given_name"),
|
||||
family_name=user_info.get("family_name"),
|
||||
picture=user_info.get("picture"),
|
||||
locale=user_info.get("locale"),
|
||||
sub=token.get("sub"),
|
||||
azp=token.get("azp"),
|
||||
email=token.get("email"),
|
||||
name=token.get("name"),
|
||||
given_name=token.get("given_name"),
|
||||
family_name=token.get("family_name"),
|
||||
picture=token.get("picture"),
|
||||
locale=token.get("locale"),
|
||||
user=user,
|
||||
)
|
||||
|
||||
@@ -137,6 +142,124 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||
return config
|
||||
|
||||
|
||||
class ConversationAdapters:
|
||||
@staticmethod
|
||||
def get_conversation_by_user(user: KhojUser):
|
||||
conversation = Conversation.objects.filter(user=user)
|
||||
if conversation.exists():
|
||||
return conversation.first()
|
||||
return Conversation.objects.create(user=user)
|
||||
|
||||
@staticmethod
|
||||
async def aget_conversation_by_user(user: KhojUser):
|
||||
conversation = Conversation.objects.filter(user=user)
|
||||
if await conversation.aexists():
|
||||
return await conversation.afirst()
|
||||
return await Conversation.objects.acreate(user=user)
|
||||
|
||||
@staticmethod
|
||||
def has_any_conversation_config(user: KhojUser):
|
||||
return ConversationProcessorConfig.objects.filter(user=user).exists()
|
||||
|
||||
@staticmethod
|
||||
def get_openai_conversation_config(user: KhojUser):
|
||||
return OpenAIProcessorConversationConfig.objects.filter(user=user).first()
|
||||
|
||||
@staticmethod
|
||||
def get_offline_chat_conversation_config(user: KhojUser):
|
||||
return OfflineChatProcessorConversationConfig.objects.filter(user=user).first()
|
||||
|
||||
@staticmethod
|
||||
def has_valid_offline_conversation_config(user: KhojUser):
|
||||
return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()
|
||||
|
||||
@staticmethod
|
||||
def has_valid_openai_conversation_config(user: KhojUser):
|
||||
return OpenAIProcessorConversationConfig.objects.filter(user=user).exists()
|
||||
|
||||
@staticmethod
|
||||
def get_conversation_config(user: KhojUser):
|
||||
return ConversationProcessorConfig.objects.filter(user=user).first()
|
||||
|
||||
@staticmethod
|
||||
def save_conversation(user: KhojUser, conversation_log: dict):
|
||||
conversation = Conversation.objects.filter(user=user)
|
||||
if conversation.exists():
|
||||
conversation.update(conversation_log=conversation_log)
|
||||
else:
|
||||
Conversation.objects.create(user=user, conversation_log=conversation_log)
|
||||
|
||||
@staticmethod
|
||||
def set_conversation_processor_config(user: KhojUser, new_config: UserConversationProcessorConfig):
|
||||
conversation_config, _ = ConversationProcessorConfig.objects.get_or_create(user=user)
|
||||
conversation_config.max_prompt_size = new_config.max_prompt_size
|
||||
conversation_config.tokenizer = new_config.tokenizer
|
||||
conversation_config.save()
|
||||
|
||||
if new_config.openai:
|
||||
default_values = {
|
||||
"api_key": new_config.openai.api_key,
|
||||
}
|
||||
if new_config.openai.chat_model:
|
||||
default_values["chat_model"] = new_config.openai.chat_model
|
||||
|
||||
OpenAIProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)
|
||||
|
||||
if new_config.offline_chat:
|
||||
default_values = {
|
||||
"enable_offline_chat": str(new_config.offline_chat.enable_offline_chat),
|
||||
}
|
||||
|
||||
if new_config.offline_chat.chat_model:
|
||||
default_values["chat_model"] = new_config.offline_chat.chat_model
|
||||
|
||||
OfflineChatProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)
|
||||
|
||||
@staticmethod
|
||||
def get_enabled_conversation_settings(user: KhojUser):
|
||||
openai_config = ConversationAdapters.get_openai_conversation_config(user)
|
||||
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)
|
||||
|
||||
return {
|
||||
"openai": True if openai_config is not None else False,
|
||||
"offline_chat": True
|
||||
if (offline_chat_config is not None and offline_chat_config.enable_offline_chat)
|
||||
else False,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def clear_conversation_config(user: KhojUser):
|
||||
ConversationProcessorConfig.objects.filter(user=user).delete()
|
||||
ConversationAdapters.clear_openai_conversation_config(user)
|
||||
ConversationAdapters.clear_offline_chat_conversation_config(user)
|
||||
|
||||
@staticmethod
|
||||
def clear_openai_conversation_config(user: KhojUser):
|
||||
OpenAIProcessorConversationConfig.objects.filter(user=user).delete()
|
||||
|
||||
@staticmethod
|
||||
def clear_offline_chat_conversation_config(user: KhojUser):
|
||||
OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()
|
||||
|
||||
@staticmethod
|
||||
async def has_offline_chat(user: KhojUser):
|
||||
return await OfflineChatProcessorConversationConfig.objects.filter(
|
||||
user=user, enable_offline_chat=True
|
||||
).aexists()
|
||||
|
||||
@staticmethod
|
||||
async def get_offline_chat(user: KhojUser):
|
||||
return await OfflineChatProcessorConversationConfig.objects.filter(user=user).afirst()
|
||||
|
||||
@staticmethod
|
||||
async def has_openai_chat(user: KhojUser):
|
||||
return await OpenAIProcessorConversationConfig.objects.filter(user=user).aexists()
|
||||
|
||||
@staticmethod
|
||||
async def get_openai_chat(user: KhojUser):
|
||||
return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
|
||||
|
||||
|
||||
class EmbeddingsAdapters:
|
||||
word_filer = WordFilter()
|
||||
file_filter = FileFilter()
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
# Generated by Django 4.2.5 on 2023-10-18 05:31
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0006_embeddingsdates"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name="conversationprocessorconfig",
|
||||
name="conversation",
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name="conversationprocessorconfig",
|
||||
name="enable_offline_chat",
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="conversationprocessorconfig",
|
||||
name="max_prompt_size",
|
||||
field=models.IntegerField(blank=True, default=None, null=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="conversationprocessorconfig",
|
||||
name="tokenizer",
|
||||
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="conversationprocessorconfig",
|
||||
name="user",
|
||||
field=models.ForeignKey(
|
||||
default=1, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="OpenAIProcessorConversationConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("api_key", models.CharField(max_length=200)),
|
||||
("chat_model", models.CharField(max_length=200)),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="OfflineChatProcessorConversationConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("enable_offline_chat", models.BooleanField(default=False)),
|
||||
("chat_model", models.CharField(default="llama-2-7b-chat.ggmlv3.q4_0.bin", max_length=200)),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="Conversation",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("conversation_log", models.JSONField()),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,17 @@
|
||||
# Generated by Django 4.2.5 on 2023-10-18 16:46
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0007_remove_conversationprocessorconfig_conversation_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="conversation",
|
||||
name="conversation_log",
|
||||
field=models.JSONField(default=dict),
|
||||
),
|
||||
]
|
||||
@@ -82,9 +82,27 @@ class LocalPlaintextConfig(BaseModel):
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class ConversationProcessorConfig(BaseModel):
|
||||
conversation = models.JSONField()
|
||||
class OpenAIProcessorConversationConfig(BaseModel):
|
||||
api_key = models.CharField(max_length=200)
|
||||
chat_model = models.CharField(max_length=200)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class OfflineChatProcessorConversationConfig(BaseModel):
|
||||
enable_offline_chat = models.BooleanField(default=False)
|
||||
chat_model = models.CharField(max_length=200, default="llama-2-7b-chat.ggmlv3.q4_0.bin")
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class ConversationProcessorConfig(BaseModel):
|
||||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
conversation_log = models.JSONField(default=dict)
|
||||
|
||||
|
||||
class Embeddings(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user