From 4976b244a4824d049bc33c13c8dd1438f3dd3102 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 26 Aug 2025 14:44:51 -0700 Subject: [PATCH] Set fast, deep think models for intermediary steps via admin panel Overview Enable improving speed and cost of chat by setting fast, deep think models for intermediate steps and non user facing operations. Details - Allow decoupling default chat models from models used for intermediate steps by setting server chat settings on admin panel - Use deep think models for most intermediate steps like tool selection, subquery construction etc. in default and research mode - Use fast think models for webpage read, chat title setting etc. Faster webpage read should improve conversation latency --- src/khoj/database/adapters/__init__.py | 59 +++++++++++++++--- src/khoj/database/admin.py | 4 ++ ...erchatsettings_think_free_deep_and_more.py | 61 +++++++++++++++++++ src/khoj/database/models/__init__.py | 16 +++++ src/khoj/processor/tools/run_code.py | 1 + src/khoj/routers/helpers.py | 37 ++++++++--- src/khoj/routers/research.py | 1 + 7 files changed, 162 insertions(+), 17 deletions(-) create mode 100644 src/khoj/database/migrations/0094_serverchatsettings_think_free_deep_and_more.py diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index bf046e42..ea9c773f 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1293,32 +1293,71 @@ class ConversationAdapters: return ChatModel.objects.filter().first() @staticmethod - async def aget_default_chat_model(user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None): - """Get default conversation config. Prefer chat model by server admin > agent > user > first created chat model""" + async def aget_default_chat_model( + user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None, fast: Optional[bool] = None + ): + """ + Get the chat model to use. Prefer chat model by server admin > agent > user > first created chat model + + Fast is a trinary flag to indicate preference for fast, deep or default chat model configured by the server admin. + If fast is True, prefer fast models over deep models when both are configured. + If fast is False, prefer deep models over fast models when both are configured. + If fast is None, do not consider speed preference and use the default model selection logic. + + If fallback_chat_model is provided, it will be used as a fallback if server chat settings are not configured. + Else if user settings are found use that. + Otherwise the first chat model will be used. + """ # Get the server chat settings server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() .prefetch_related( - "chat_default", "chat_default__ai_model_api", "chat_advanced", "chat_advanced__ai_model_api" + "chat_default", + "chat_default__ai_model_api", + "chat_advanced", + "chat_advanced__ai_model_api", + "think_free_fast", + "think_free_fast__ai_model_api", + "think_free_deep", + "think_free_deep__ai_model_api", + "think_paid_fast", + "think_paid_fast__ai_model_api", + "think_paid_deep", + "think_paid_deep__ai_model_api", ) .afirst() ) is_subscribed = await ais_user_subscribed(user) if user else False if server_chat_settings: - # If the user is subscribed and the advanced model is enabled, return the advanced model - if is_subscribed and server_chat_settings.chat_advanced: - return server_chat_settings.chat_advanced - # If the default model is set, return it - if server_chat_settings.chat_default: - return server_chat_settings.chat_default + # If the user is subscribed + if is_subscribed: + # If fast is requested and fast paid model is available + if server_chat_settings.think_paid_fast and fast is True: + return server_chat_settings.think_paid_fast + # Else if fast is not requested and deep paid model is available + elif server_chat_settings.think_paid_deep and fast is not None: + return server_chat_settings.think_paid_deep + # Else if advanced model is available + elif server_chat_settings.chat_advanced: + return server_chat_settings.chat_advanced + else: + # If fast is requested and fast free model is available + if server_chat_settings.think_free_fast and fast: + return server_chat_settings.think_free_fast + # Else if fast is not requested and deep free model is available + elif server_chat_settings.think_free_deep: + return server_chat_settings.think_free_deep + # Else if default model is available + elif server_chat_settings.chat_default: + return server_chat_settings.chat_default # Revert to an explicit fallback model if the server chat settings are not set if fallback_chat_model: # The chat model may not be full loaded from the db, so explicitly load it here return await ChatModel.objects.filter(id=fallback_chat_model.id).prefetch_related("ai_model_api").afirst() - # Get the user's chat settings, if the server chat settings are not set + # Get the user's chat settings, if both the server chat settings and the fallback model are not set user_chat_settings = ( (await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst()) if user diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index ceccbfa1..5400c89e 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -278,6 +278,10 @@ class ServerChatSettingsAdmin(unfold_admin.ModelAdmin): list_display = ( "chat_default", "chat_advanced", + "think_free_fast", + "think_free_deep", + "think_paid_fast", + "think_paid_deep", "web_scraper", ) diff --git a/src/khoj/database/migrations/0094_serverchatsettings_think_free_deep_and_more.py b/src/khoj/database/migrations/0094_serverchatsettings_think_free_deep_and_more.py new file mode 100644 index 00000000..a27eeabb --- /dev/null +++ b/src/khoj/database/migrations/0094_serverchatsettings_think_free_deep_and_more.py @@ -0,0 +1,61 @@ +# Generated by Django 5.1.10 on 2025-08-26 20:47 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0093_remove_localorgconfig_user_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="serverchatsettings", + name="think_free_deep", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="think_free_deep", + to="database.chatmodel", + ), + ), + migrations.AddField( + model_name="serverchatsettings", + name="think_free_fast", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="think_free_fast", + to="database.chatmodel", + ), + ), + migrations.AddField( + model_name="serverchatsettings", + name="think_paid_deep", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="think_paid_deep", + to="database.chatmodel", + ), + ), + migrations.AddField( + model_name="serverchatsettings", + name="think_paid_fast", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="think_paid_fast", + to="database.chatmodel", + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 535fa7e7..90ed67a8 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -472,6 +472,18 @@ class ServerChatSettings(DbBaseModel): chat_advanced = models.ForeignKey( ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced" ) + think_free_fast = models.ForeignKey( + ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_free_fast" + ) + think_free_deep = models.ForeignKey( + ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_free_deep" + ) + think_paid_fast = models.ForeignKey( + ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_paid_fast" + ) + think_paid_deep = models.ForeignKey( + ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_paid_deep" + ) web_scraper = models.ForeignKey( WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper" ) @@ -480,6 +492,10 @@ class ServerChatSettings(DbBaseModel): error = {} if self.chat_default and self.chat_default.price_tier != PriceTier.FREE: error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model." + if self.think_free_fast and self.think_free_fast.price_tier != PriceTier.FREE: + error["think_free_fast"] = "Set the price tier of this chat model to free or use a free tier chat model." + if self.think_free_deep and self.think_free_deep.price_tier != PriceTier.FREE: + error["think_free_deep"] = "Set the price tier of this chat model to free or use a free tier chat model." if error: raise ValidationError(error) diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 47f6304f..59702e36 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -160,6 +160,7 @@ async def generate_python_code( query_files=query_files, user=user, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c7d8acdd..13d201d3 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -286,7 +286,7 @@ async def acreate_title_from_history( title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history) with timer("Chat actor: Generate title from conversation history", logger): - response = await send_message_to_model_wrapper(title_generation_prompt, user=user) + response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True) return response.text.strip() @@ -298,7 +298,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: title_generation_prompt = prompts.subject_generation.format(query=query) with timer("Chat actor: Generate title from query", logger): - response = await send_message_to_model_wrapper(title_generation_prompt, user=user) + response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True) return response.text.strip() @@ -321,7 +321,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: with timer("Chat actor: Check if safe prompt", logger): response = await send_message_to_model_wrapper( - safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck + safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck, fast_model=True ) response = response.text.strip() @@ -410,6 +410,7 @@ async def aget_data_sources_and_output_format( user=user, query_files=query_files, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) @@ -499,6 +500,7 @@ async def infer_webpage_urls( user=user, query_files=query_files, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) @@ -564,6 +566,7 @@ async def generate_online_subqueries( user=user, query_files=query_files, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) @@ -625,7 +628,12 @@ async def aschedule_query( ) raw_response = await send_message_to_model_wrapper( - crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer + crontime_prompt, + query_images=query_images, + response_type="json_object", + fast_model=False, + user=user, + tracer=tracer, ) # Validate that the response is a non-empty, JSON-serializable list @@ -666,6 +674,7 @@ async def extract_relevant_info( prompts.system_prompt_extract_relevant_information, user=user, agent_chat_model=agent_chat_model, + fast_model=True, tracer=tracer, ) return response.text.strip() @@ -709,6 +718,7 @@ async def extract_relevant_summary( user=user, query_images=query_images, agent_chat_model=agent_chat_model, + fast_model=True, tracer=tracer, ) return response.text.strip() @@ -880,6 +890,7 @@ async def generate_better_diagram_description( user=user, query_files=query_files, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) response = response.text.strip() @@ -908,7 +919,11 @@ async def generate_excalidraw_diagram_from_description( with timer("Chat actor: Generate excalidraw diagram", logger): raw_response = await send_message_to_model_wrapper( - query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer + query=excalidraw_diagram_generation, + user=user, + agent_chat_model=agent_chat_model, + fast_model=False, + tracer=tracer, ) raw_response_text = clean_json(raw_response.text) try: @@ -1031,6 +1046,7 @@ async def generate_better_mermaidjs_diagram_description( user=user, query_files=query_files, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) response_text = response.text.strip() @@ -1059,7 +1075,11 @@ async def generate_mermaidjs_diagram_from_description( with timer("Chat actor: Generate Mermaid.js diagram", logger): raw_response = await send_message_to_model_wrapper( - query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer + query=mermaidjs_diagram_generation, + user=user, + agent_chat_model=agent_chat_model, + fast_model=False, + tracer=tracer, ) return clean_mermaidjs(raw_response.text.strip()) @@ -1120,6 +1140,7 @@ async def generate_better_image_prompt( query_files=query_files, chat_history=conversation_history, agent_chat_model=agent_chat_model, + fast_model=False, user=user, response_type="json_object", response_schema=ImagePromptResponse, @@ -1302,6 +1323,7 @@ async def extract_questions( query_files=query_files, response_type="json_object", response_schema=DocumentQueries, + fast_model=False, user=user, tracer=tracer, ) @@ -1420,6 +1442,7 @@ async def send_message_to_model_wrapper( response_schema: BaseModel = None, tools: List[ToolDefinition] = None, deepthought: bool = False, + fast_model: Optional[bool] = None, user: KhojUser = None, query_images: List[str] = None, context: str = "", @@ -1428,7 +1451,7 @@ async def send_message_to_model_wrapper( agent_chat_model: ChatModel = None, tracer: dict = {}, ): - chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model) + chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model) vision_available = chat_model.vision_enabled if not vision_available and query_images: logger.warning(f"Vision is not enabled for default model: {chat_model.name}.") diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index d43593f6..1bd6ba71 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -163,6 +163,7 @@ async def apick_next_tool( chat_history=chat_and_research_history, tools=tools, deepthought=True, + fast_model=False, user=user, query_images=query_images, query_files=query_files,