diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index adb93dfd..33107d2d 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -39,6 +39,7 @@ from khoj.database.models import ( PublicConversation, ReflectiveQuestion, SearchModelConfig, + ServerChatSettings, SpeechToTextModelOptions, Subscription, TextToImageModelConfig, @@ -702,11 +703,36 @@ class ConversationAdapters: @staticmethod def get_default_conversation_config(): - return ChatModelOptions.objects.filter().first() + server_chat_settings = ServerChatSettings.objects.first() + if server_chat_settings is None or server_chat_settings.default_model is None: + return ChatModelOptions.objects.filter().first() + return server_chat_settings.default_model @staticmethod async def aget_default_conversation_config(): - return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() + server_chat_settings: ServerChatSettings = ( + await ServerChatSettings.objects.filter() + .prefetch_related("default_model", "default_model__openai_config") + .afirst() + ) + if server_chat_settings is None or server_chat_settings.default_model is None: + return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() + return server_chat_settings.default_model + + @staticmethod + async def aget_summarizer_conversation_config(): + server_chat_settings: ServerChatSettings = ( + await ServerChatSettings.objects.filter() + .prefetch_related( + "summarizer_model", "default_model", "default_model__openai_config", "summarizer_model__openai_config" + ) + .afirst() + ) + if server_chat_settings is None or ( + server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None + ): + return await ChatModelOptions.objects.filter().afirst() + return server_chat_settings.summarizer_model or server_chat_settings.default_model @staticmethod def create_conversation_from_public_conversation( diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 0b6e42dd..c7adb1ea 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -22,6 +22,7 @@ from khoj.database.models import ( ProcessLock, ReflectiveQuestion, SearchModelConfig, + ServerChatSettings, SpeechToTextModelOptions, Subscription, TextToImageModelConfig, @@ -88,6 +89,7 @@ admin.site.register(TextToImageModelConfig) admin.site.register(ClientApplication) admin.site.register(GithubConfig) admin.site.register(NotionConfig) +admin.site.register(ServerChatSettings) @admin.register(Agent) diff --git a/src/khoj/database/migrations/0042_serverchatsettings.py b/src/khoj/database/migrations/0042_serverchatsettings.py new file mode 100644 index 00000000..a58b9729 --- /dev/null +++ b/src/khoj/database/migrations/0042_serverchatsettings.py @@ -0,0 +1,46 @@ +# Generated by Django 4.2.10 on 2024-04-29 11:04 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0041_merge_20240505_1234"), + ] + + operations = [ + migrations.CreateModel( + name="ServerChatSettings", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "default_model", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="default_model", + to="database.chatmodeloptions", + ), + ), + ( + "summarizer_model", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="summarizer_model", + to="database.chatmodeloptions", + ), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 6921fcae..64741018 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -158,6 +158,15 @@ class GithubRepoConfig(BaseModel): github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig") +class ServerChatSettings(BaseModel): + default_model = models.ForeignKey( + ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="default_model" + ) + summarizer_model = models.ForeignKey( + ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="summarizer_model" + ) + + class LocalOrgConfig(BaseModel): input_files = models.JSONField(default=list, null=True) input_filter = models.JSONField(default=list, null=True) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 46120801..172ebba3 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -613,11 +613,17 @@ async def websocket_endpoint( if ConversationCommand.Webpage in conversation_commands: try: - online_results = await read_webpages(defiltered_query, meta_log, location, send_status_update) + direct_web_pages = await read_webpages(defiltered_query, meta_log, location, send_status_update) webpages = [] - for query in online_results: - for webpage in online_results[query]["webpages"]: + for query in direct_web_pages: + if online_results.get(query): + online_results[query]["webpages"] = direct_web_pages[query]["webpages"] + else: + online_results[query] = {"webpages": direct_web_pages[query]["webpages"]} + + for webpage in direct_web_pages[query]["webpages"]: webpages.append(webpage["link"]) + await send_status_update(f"**📚 Read web pages**: {webpages}") except ValueError as e: logger.warning( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 27bff722..5c1dc25d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -392,9 +392,13 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]: corpus=corpus.strip(), ) + summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config() + with timer("Chat actor: Extract relevant information from data", logger): response = await send_message_to_model_wrapper( - extract_relevant_information, prompts.system_prompt_extract_relevant_information + extract_relevant_information, + prompts.system_prompt_extract_relevant_information, + chat_model_option=summarizer_model, ) return response.strip() @@ -449,8 +453,11 @@ async def send_message_to_model_wrapper( message: str, system_message: str = "", response_type: str = "text", + chat_model_option: ChatModelOptions = None, ): - conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config() + conversation_config: ChatModelOptions = ( + chat_model_option or await ConversationAdapters.aget_default_conversation_config() + ) if conversation_config is None: raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")