diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py
index 4bcb9c8e..4141d3bb 100644
--- a/src/database/adapters/__init__.py
+++ b/src/database/adapters/__init__.py
@@ -240,10 +240,18 @@ class ConversationAdapters:
     def get_openai_conversation_config():
         return OpenAIProcessorConversationConfig.objects.filter().first()
 
+    @staticmethod
+    async def aget_openai_conversation_config():
+        return await OpenAIProcessorConversationConfig.objects.filter().afirst()
+
     @staticmethod
     def get_offline_chat_conversation_config():
         return OfflineChatProcessorConversationConfig.objects.filter().first()
 
+    @staticmethod
+    async def aget_offline_chat_conversation_config():
+        return await OfflineChatProcessorConversationConfig.objects.filter().afirst()
+
     @staticmethod
     def has_valid_offline_conversation_config():
         return OfflineChatProcessorConversationConfig.objects.filter(enabled=True).exists()
@@ -267,10 +275,21 @@ class ConversationAdapters:
             return None
         return config.setting
 
+    @staticmethod
+    async def aget_conversation_config(user: KhojUser):
+        config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
+        if not config:
+            return None
+        return config.setting
+
     @staticmethod
     def get_default_conversation_config():
         return ChatModelOptions.objects.filter().first()
 
+    @staticmethod
+    async def aget_default_conversation_config():
+        return await ChatModelOptions.objects.filter().afirst()
+
     @staticmethod
     def save_conversation(user: KhojUser, conversation_log: dict):
         conversation = Conversation.objects.filter(user=user)
@@ -320,10 +339,6 @@ class ConversationAdapters:
     async def get_openai_chat_config():
         return await OpenAIProcessorConversationConfig.objects.filter().afirst()
 
-    @staticmethod
-    async def aget_default_conversation_config():
-        return await ChatModelOptions.objects.filter().afirst()
-
 
 class EntryAdapters:
     word_filer = WordFilter()
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index b86ebc6b..31cfda1e 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -1,5 +1,6 @@
 # Standard Packages
 import logging
+import json
 from datetime import datetime, timedelta
 from typing import Optional
 
@@ -31,6 +32,10 @@ def extract_questions(
     """
     Infer search queries to retrieve relevant notes to answer user query
     """
+
+    def _valid_question(question: str):
+        return not is_none_or_empty(question) and question != "[]"
+
     # Extract Past User Message and Inferred Questions from Conversation Log
     chat_history = "".join(
         [
@@ -70,7 +75,7 @@ def extract_questions(
 
     # Extract, Clean Message from GPT's Response
     try:
-        questions = (
+        split_questions = (
             response.content.strip(empty_escape_sequences)
             .replace("['", '["')
             .replace("']", '"]')
@@ -79,9 +84,18 @@ def extract_questions(
             .replace('["', "")
             .replace('"]', "")
             .split('", "')
         )
+        questions = []
+
+        for question in split_questions:
+            if question not in questions and _valid_question(question):
+                questions.append(question)
+
+        if is_none_or_empty(questions):
+            raise ValueError("GPT returned empty JSON")
     except:
         logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
         questions = [text]
+
     logger.debug(f"Extracted Questions by GPT: {questions}")
     return questions
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index b384d8a3..6d67fcbe 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -55,6 +55,7 @@ from database.models import (
     Entry as DbEntry,
     GithubConfig,
     NotionConfig,
+    ChatModelOptions,
 )
 
 
@@ -122,7 +123,7 @@
 def _initialize_config():
     if state.config is None:
         state.config = FullConfig()
-        state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
+        state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
 
 
 @api.get("/config/data", response_model=FullConfig)
@@ -669,7 +670,16 @@
     # Infer search queries from user message
     with timer("Extracting search queries took", logger):
         # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
-        if await ConversationAdapters.ahas_offline_chat():
+        offline_chat_config = await ConversationAdapters.aget_offline_chat_conversation_config()
+        conversation_config = await ConversationAdapters.aget_conversation_config(user)
+        if conversation_config is None:
+            conversation_config = await ConversationAdapters.aget_default_conversation_config()
+        openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
+        if (
+            offline_chat_config
+            and offline_chat_config.enabled
+            and conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE
+        ):
             using_offline_chat = True
             offline_chat = await ConversationAdapters.get_offline_chat()
             chat_model = offline_chat.chat_model
@@ -681,7 +691,7 @@
             inferred_queries = extract_questions_offline(
                 defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
             )
-        elif await ConversationAdapters.has_openai_chat():
+        elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
            openai_chat_config = await ConversationAdapters.get_openai_chat_config()
            openai_chat = await ConversationAdapters.get_openai_chat()
            api_key = openai_chat_config.api_key
@@ -706,7 +716,6 @@
                         common=common,
                     )
                 )
-
     # Dedupe the results again, as duplicates may be returned across queries.
     result_list = text_search.deduplicated_search_responses(result_list)
     compiled_references = [item.additional["compiled"] for item in result_list]
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index f07eb580..7e295903 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -163,7 +163,7 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
         else:
             hit_ids.add(hit.corpus_id)
 
-            yield SearchResponse.parse_obj(
+            yield SearchResponse.model_validate(
                 {
                     "entry": hit.entry,
                     "score": hit.score,
diff --git a/src/khoj/utils/yaml.py b/src/khoj/utils/yaml.py
index abfe270a..36546688 100644
--- a/src/khoj/utils/yaml.py
+++ b/src/khoj/utils/yaml.py
@@ -39,7 +39,7 @@ def load_config_from_file(yaml_config_file: Path) -> dict:
 
 def parse_config_from_string(yaml_config: dict) -> FullConfig:
     "Parse and validate config in YML string"
-    return FullConfig.parse_obj(yaml_config)
+    return FullConfig.model_validate(yaml_config)
 
 
 def parse_config_from_file(yaml_config_file):