[Multi-User Part 3]: Separate chat sesssions based on authenticated users (#511)

- Add a data model which allows us to store Conversations with users. This does a minimal lift over the current setup, where the underlying data is stored in a JSON file. This maintains parity with that configuration.
- There does _seem_ to be some regression in chat quality, which is most likely attributable to search results.

This will help us with #275. It should become much easier to maintain multiple Conversations in a given table in the backend now. We will have to do some thinking on the UI.
This commit is contained in:
sabaimran
2023-10-26 11:37:41 -07:00
committed by GitHub
parent a8a82d274a
commit 4b6ec248a6
24 changed files with 719 additions and 626 deletions

View File

@@ -1,5 +1,4 @@
from typing import Type, TypeVar, List
import uuid
from datetime import date
from django.db import models
@@ -21,6 +20,13 @@ from database.models import (
GithubConfig,
Embeddings,
GithubRepoConfig,
Conversation,
ConversationProcessorConfig,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
)
from khoj.utils.rawconfig import (
ConversationProcessorConfig as UserConversationProcessorConfig,
)
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
@@ -54,18 +60,17 @@ async def get_or_create_user(token: dict) -> KhojUser:
async def create_google_user(token: dict) -> KhojUser:
user_info = token.get("userinfo")
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
await user.asave()
await GoogleUser.objects.acreate(
sub=user_info.get("sub"),
azp=user_info.get("azp"),
email=user_info.get("email"),
name=user_info.get("name"),
given_name=user_info.get("given_name"),
family_name=user_info.get("family_name"),
picture=user_info.get("picture"),
locale=user_info.get("locale"),
sub=token.get("sub"),
azp=token.get("azp"),
email=token.get("email"),
name=token.get("name"),
given_name=token.get("given_name"),
family_name=token.get("family_name"),
picture=token.get("picture"),
locale=token.get("locale"),
user=user,
)
@@ -137,6 +142,124 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config
class ConversationAdapters:
@staticmethod
def get_conversation_by_user(user: KhojUser):
conversation = Conversation.objects.filter(user=user)
if conversation.exists():
return conversation.first()
return Conversation.objects.create(user=user)
@staticmethod
async def aget_conversation_by_user(user: KhojUser):
conversation = Conversation.objects.filter(user=user)
if await conversation.aexists():
return await conversation.afirst()
return await Conversation.objects.acreate(user=user)
@staticmethod
def has_any_conversation_config(user: KhojUser):
return ConversationProcessorConfig.objects.filter(user=user).exists()
@staticmethod
def get_openai_conversation_config(user: KhojUser):
return OpenAIProcessorConversationConfig.objects.filter(user=user).first()
@staticmethod
def get_offline_chat_conversation_config(user: KhojUser):
return OfflineChatProcessorConversationConfig.objects.filter(user=user).first()
@staticmethod
def has_valid_offline_conversation_config(user: KhojUser):
return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()
@staticmethod
def has_valid_openai_conversation_config(user: KhojUser):
return OpenAIProcessorConversationConfig.objects.filter(user=user).exists()
@staticmethod
def get_conversation_config(user: KhojUser):
return ConversationProcessorConfig.objects.filter(user=user).first()
@staticmethod
def save_conversation(user: KhojUser, conversation_log: dict):
conversation = Conversation.objects.filter(user=user)
if conversation.exists():
conversation.update(conversation_log=conversation_log)
else:
Conversation.objects.create(user=user, conversation_log=conversation_log)
@staticmethod
def set_conversation_processor_config(user: KhojUser, new_config: UserConversationProcessorConfig):
conversation_config, _ = ConversationProcessorConfig.objects.get_or_create(user=user)
conversation_config.max_prompt_size = new_config.max_prompt_size
conversation_config.tokenizer = new_config.tokenizer
conversation_config.save()
if new_config.openai:
default_values = {
"api_key": new_config.openai.api_key,
}
if new_config.openai.chat_model:
default_values["chat_model"] = new_config.openai.chat_model
OpenAIProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)
if new_config.offline_chat:
default_values = {
"enable_offline_chat": str(new_config.offline_chat.enable_offline_chat),
}
if new_config.offline_chat.chat_model:
default_values["chat_model"] = new_config.offline_chat.chat_model
OfflineChatProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)
@staticmethod
def get_enabled_conversation_settings(user: KhojUser):
openai_config = ConversationAdapters.get_openai_conversation_config(user)
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)
return {
"openai": True if openai_config is not None else False,
"offline_chat": True
if (offline_chat_config is not None and offline_chat_config.enable_offline_chat)
else False,
}
@staticmethod
def clear_conversation_config(user: KhojUser):
ConversationProcessorConfig.objects.filter(user=user).delete()
ConversationAdapters.clear_openai_conversation_config(user)
ConversationAdapters.clear_offline_chat_conversation_config(user)
@staticmethod
def clear_openai_conversation_config(user: KhojUser):
OpenAIProcessorConversationConfig.objects.filter(user=user).delete()
@staticmethod
def clear_offline_chat_conversation_config(user: KhojUser):
OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()
@staticmethod
async def has_offline_chat(user: KhojUser):
return await OfflineChatProcessorConversationConfig.objects.filter(
user=user, enable_offline_chat=True
).aexists()
@staticmethod
async def get_offline_chat(user: KhojUser):
return await OfflineChatProcessorConversationConfig.objects.filter(user=user).afirst()
@staticmethod
async def has_openai_chat(user: KhojUser):
return await OpenAIProcessorConversationConfig.objects.filter(user=user).aexists()
@staticmethod
async def get_openai_chat(user: KhojUser):
return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
class EmbeddingsAdapters:
word_filer = WordFilter()
file_filter = FileFilter()