Merge pull request #554 from khoj-ai/fix/issues-with-prod-chat

Fix misc. issues with chat configuration
This commit is contained in:
sabaimran
2023-11-18 14:45:06 -08:00
committed by GitHub
5 changed files with 49 additions and 11 deletions

View File

@@ -240,10 +240,18 @@ class ConversationAdapters:
def get_openai_conversation_config(): def get_openai_conversation_config():
return OpenAIProcessorConversationConfig.objects.filter().first() return OpenAIProcessorConversationConfig.objects.filter().first()
@staticmethod
async def aget_openai_conversation_config():
return await OpenAIProcessorConversationConfig.objects.filter().afirst()
@staticmethod @staticmethod
def get_offline_chat_conversation_config(): def get_offline_chat_conversation_config():
return OfflineChatProcessorConversationConfig.objects.filter().first() return OfflineChatProcessorConversationConfig.objects.filter().first()
@staticmethod
async def aget_offline_chat_conversation_config():
return await OfflineChatProcessorConversationConfig.objects.filter().afirst()
@staticmethod @staticmethod
def has_valid_offline_conversation_config(): def has_valid_offline_conversation_config():
return OfflineChatProcessorConversationConfig.objects.filter(enabled=True).exists() return OfflineChatProcessorConversationConfig.objects.filter(enabled=True).exists()
@@ -267,10 +275,21 @@ class ConversationAdapters:
return None return None
return config.setting return config.setting
@staticmethod
async def aget_conversation_config(user: KhojUser):
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
return None
return config.setting
@staticmethod @staticmethod
def get_default_conversation_config(): def get_default_conversation_config():
return ChatModelOptions.objects.filter().first() return ChatModelOptions.objects.filter().first()
@staticmethod
async def aget_default_conversation_config():
return await ChatModelOptions.objects.filter().afirst()
@staticmethod @staticmethod
def save_conversation(user: KhojUser, conversation_log: dict): def save_conversation(user: KhojUser, conversation_log: dict):
conversation = Conversation.objects.filter(user=user) conversation = Conversation.objects.filter(user=user)
@@ -320,10 +339,6 @@ class ConversationAdapters:
async def get_openai_chat_config(): async def get_openai_chat_config():
return await OpenAIProcessorConversationConfig.objects.filter().afirst() return await OpenAIProcessorConversationConfig.objects.filter().afirst()
@staticmethod
async def aget_default_conversation_config():
return await ChatModelOptions.objects.filter().afirst()
class EntryAdapters: class EntryAdapters:
word_filer = WordFilter() word_filer = WordFilter()

View File

@@ -1,5 +1,6 @@
# Standard Packages # Standard Packages
import logging import logging
import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional
@@ -31,6 +32,10 @@ def extract_questions(
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
def _valid_question(question: str):
return not is_none_or_empty(question) and question != "[]"
# Extract Past User Message and Inferred Questions from Conversation Log # Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join( chat_history = "".join(
[ [
@@ -70,7 +75,7 @@ def extract_questions(
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
try: try:
questions = ( split_questions = (
response.content.strip(empty_escape_sequences) response.content.strip(empty_escape_sequences)
.replace("['", '["') .replace("['", '["')
.replace("']", '"]') .replace("']", '"]')
@@ -79,9 +84,18 @@ def extract_questions(
.replace('"]', "") .replace('"]', "")
.split('", "') .split('", "')
) )
questions = []
for question in split_questions:
if question not in questions and _valid_question(question):
questions.append(question)
if is_none_or_empty(questions):
raise ValueError("GPT returned empty JSON")
except: except:
logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}") logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
questions = [text] questions = [text]
logger.debug(f"Extracted Questions by GPT: {questions}") logger.debug(f"Extracted Questions by GPT: {questions}")
return questions return questions

View File

@@ -55,6 +55,7 @@ from database.models import (
Entry as DbEntry, Entry as DbEntry,
GithubConfig, GithubConfig,
NotionConfig, NotionConfig,
ChatModelOptions,
) )
@@ -122,7 +123,7 @@ async def map_config_to_db(config: FullConfig, user: KhojUser):
def _initialize_config(): def _initialize_config():
if state.config is None: if state.config is None:
state.config = FullConfig() state.config = FullConfig()
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"]) state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
@api.get("/config/data", response_model=FullConfig) @api.get("/config/data", response_model=FullConfig)
@@ -669,7 +670,16 @@ async def extract_references_and_questions(
# Infer search queries from user message # Infer search queries from user message
with timer("Extracting search queries took", logger): with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled. # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
if await ConversationAdapters.ahas_offline_chat(): offline_chat_config = await ConversationAdapters.aget_offline_chat_conversation_config()
conversation_config = await ConversationAdapters.aget_conversation_config(user)
if conversation_config is None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
if (
offline_chat_config
and offline_chat_config.enabled
and conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE
):
using_offline_chat = True using_offline_chat = True
offline_chat = await ConversationAdapters.get_offline_chat() offline_chat = await ConversationAdapters.get_offline_chat()
chat_model = offline_chat.chat_model chat_model = offline_chat.chat_model
@@ -681,7 +691,7 @@ async def extract_references_and_questions(
inferred_queries = extract_questions_offline( inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
) )
elif await ConversationAdapters.has_openai_chat(): elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config() openai_chat_config = await ConversationAdapters.get_openai_chat_config()
openai_chat = await ConversationAdapters.get_openai_chat() openai_chat = await ConversationAdapters.get_openai_chat()
api_key = openai_chat_config.api_key api_key = openai_chat_config.api_key
@@ -706,7 +716,6 @@ async def extract_references_and_questions(
common=common, common=common,
) )
) )
# Dedupe the results again, as duplicates may be returned across queries.
result_list = text_search.deduplicated_search_responses(result_list) result_list = text_search.deduplicated_search_responses(result_list)
compiled_references = [item.additional["compiled"] for item in result_list] compiled_references = [item.additional["compiled"] for item in result_list]

View File

@@ -163,7 +163,7 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
else: else:
hit_ids.add(hit.corpus_id) hit_ids.add(hit.corpus_id)
yield SearchResponse.parse_obj( yield SearchResponse.model_validate(
{ {
"entry": hit.entry, "entry": hit.entry,
"score": hit.score, "score": hit.score,

View File

@@ -39,7 +39,7 @@ def load_config_from_file(yaml_config_file: Path) -> dict:
def parse_config_from_string(yaml_config: dict) -> FullConfig: def parse_config_from_string(yaml_config: dict) -> FullConfig:
"Parse and validate config in YML string" "Parse and validate config in YML string"
return FullConfig.parse_obj(yaml_config) return FullConfig.model_validate(yaml_config)
def parse_config_from_file(yaml_config_file): def parse_config_from_file(yaml_config_file):