diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index bf046e42..ea9c773f 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1293,32 +1293,71 @@ class ConversationAdapters: return ChatModel.objects.filter().first() @staticmethod - async def aget_default_chat_model(user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None): - """Get default conversation config. Prefer chat model by server admin > agent > user > first created chat model""" + async def aget_default_chat_model( + user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None, fast: Optional[bool] = None + ): + """ + Get the chat model to use. Prefer chat model by server admin > agent > user > first created chat model + + Fast is a trinary flag to indicate preference for fast, deep or default chat model configured by the server admin. + If fast is True, prefer fast models over deep models when both are configured. + If fast is False, prefer deep models over fast models when both are configured. + If fast is None, do not consider speed preference and use the default model selection logic. + + If fallback_chat_model is provided, it will be used as a fallback if server chat settings are not configured. + Else if user settings are found use that. + Otherwise the first chat model will be used. + """ # Get the server chat settings server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() .prefetch_related( - "chat_default", "chat_default__ai_model_api", "chat_advanced", "chat_advanced__ai_model_api" + "chat_default", + "chat_default__ai_model_api", + "chat_advanced", + "chat_advanced__ai_model_api", + "think_free_fast", + "think_free_fast__ai_model_api", + "think_free_deep", + "think_free_deep__ai_model_api", + "think_paid_fast", + "think_paid_fast__ai_model_api", + "think_paid_deep", + "think_paid_deep__ai_model_api", ) .afirst() ) is_subscribed = await ais_user_subscribed(user) if user else False if server_chat_settings: - # If the user is subscribed and the advanced model is enabled, return the advanced model - if is_subscribed and server_chat_settings.chat_advanced: - return server_chat_settings.chat_advanced - # If the default model is set, return it - if server_chat_settings.chat_default: - return server_chat_settings.chat_default + # If the user is subscribed + if is_subscribed: + # If fast is requested and fast paid model is available + if server_chat_settings.think_paid_fast and fast is True: + return server_chat_settings.think_paid_fast + # Else if fast is not requested and deep paid model is available + elif server_chat_settings.think_paid_deep and fast is not None: + return server_chat_settings.think_paid_deep + # Else if advanced model is available + elif server_chat_settings.chat_advanced: + return server_chat_settings.chat_advanced + else: + # If fast is requested and fast free model is available + if server_chat_settings.think_free_fast and fast: + return server_chat_settings.think_free_fast + # Else if fast is not requested and deep free model is available + elif server_chat_settings.think_free_deep: + return server_chat_settings.think_free_deep + # Else if default model is available + elif server_chat_settings.chat_default: + return server_chat_settings.chat_default # Revert to an explicit fallback model if the server chat settings are not set if fallback_chat_model: # The chat model may not be full loaded from the db, so explicitly load it here return await ChatModel.objects.filter(id=fallback_chat_model.id).prefetch_related("ai_model_api").afirst() - # Get the user's chat settings, if the server chat settings are not set + # Get the user's chat settings, if both the server chat settings and the fallback model are not set user_chat_settings = ( (await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst()) if user diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index ceccbfa1..5400c89e 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -278,6 +278,10 @@ class ServerChatSettingsAdmin(unfold_admin.ModelAdmin): list_display = ( "chat_default", "chat_advanced", + "think_free_fast", + "think_free_deep", + "think_paid_fast", + "think_paid_deep", "web_scraper", ) diff --git a/src/khoj/database/migrations/0094_serverchatsettings_think_free_deep_and_more.py b/src/khoj/database/migrations/0094_serverchatsettings_think_free_deep_and_more.py new file mode 100644 index 00000000..a27eeabb --- /dev/null +++ b/src/khoj/database/migrations/0094_serverchatsettings_think_free_deep_and_more.py @@ -0,0 +1,61 @@ +# Generated by Django 5.1.10 on 2025-08-26 20:47 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0093_remove_localorgconfig_user_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="serverchatsettings", + name="think_free_deep", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="think_free_deep", + to="database.chatmodel", + ), + ), + migrations.AddField( + model_name="serverchatsettings", + name="think_free_fast", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="think_free_fast", + to="database.chatmodel", + ), + ), + migrations.AddField( + model_name="serverchatsettings", + name="think_paid_deep", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="think_paid_deep", + to="database.chatmodel", + ), + ), + migrations.AddField( + model_name="serverchatsettings", + name="think_paid_fast", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="think_paid_fast", + to="database.chatmodel", + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 535fa7e7..90ed67a8 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -472,6 +472,18 @@ class ServerChatSettings(DbBaseModel): chat_advanced = models.ForeignKey( ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced" ) + think_free_fast = models.ForeignKey( + ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_free_fast" + ) + think_free_deep = models.ForeignKey( + ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_free_deep" + ) + think_paid_fast = models.ForeignKey( + ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_paid_fast" + ) + think_paid_deep = models.ForeignKey( + ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_paid_deep" + ) web_scraper = models.ForeignKey( WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper" ) @@ -480,6 +492,10 @@ class ServerChatSettings(DbBaseModel): error = {} if self.chat_default and self.chat_default.price_tier != PriceTier.FREE: error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model." + if self.think_free_fast and self.think_free_fast.price_tier != PriceTier.FREE: + error["think_free_fast"] = "Set the price tier of this chat model to free or use a free tier chat model." + if self.think_free_deep and self.think_free_deep.price_tier != PriceTier.FREE: + error["think_free_deep"] = "Set the price tier of this chat model to free or use a free tier chat model." if error: raise ValidationError(error) diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 47f6304f..59702e36 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -160,6 +160,7 @@ async def generate_python_code( query_files=query_files, user=user, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c7d8acdd..13d201d3 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -286,7 +286,7 @@ async def acreate_title_from_history( title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history) with timer("Chat actor: Generate title from conversation history", logger): - response = await send_message_to_model_wrapper(title_generation_prompt, user=user) + response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True) return response.text.strip() @@ -298,7 +298,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: title_generation_prompt = prompts.subject_generation.format(query=query) with timer("Chat actor: Generate title from query", logger): - response = await send_message_to_model_wrapper(title_generation_prompt, user=user) + response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True) return response.text.strip() @@ -321,7 +321,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: with timer("Chat actor: Check if safe prompt", logger): response = await send_message_to_model_wrapper( - safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck + safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck, fast_model=True ) response = response.text.strip() @@ -410,6 +410,7 @@ async def aget_data_sources_and_output_format( user=user, query_files=query_files, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) @@ -499,6 +500,7 @@ async def infer_webpage_urls( user=user, query_files=query_files, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) @@ -564,6 +566,7 @@ async def generate_online_subqueries( user=user, query_files=query_files, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) @@ -625,7 +628,12 @@ async def aschedule_query( ) raw_response = await send_message_to_model_wrapper( - crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer + crontime_prompt, + query_images=query_images, + response_type="json_object", + fast_model=False, + user=user, + tracer=tracer, ) # Validate that the response is a non-empty, JSON-serializable list @@ -666,6 +674,7 @@ async def extract_relevant_info( prompts.system_prompt_extract_relevant_information, user=user, agent_chat_model=agent_chat_model, + fast_model=True, tracer=tracer, ) return response.text.strip() @@ -709,6 +718,7 @@ async def extract_relevant_summary( user=user, query_images=query_images, agent_chat_model=agent_chat_model, + fast_model=True, tracer=tracer, ) return response.text.strip() @@ -880,6 +890,7 @@ async def generate_better_diagram_description( user=user, query_files=query_files, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) response = response.text.strip() @@ -908,7 +919,11 @@ async def generate_excalidraw_diagram_from_description( with timer("Chat actor: Generate excalidraw diagram", logger): raw_response = await send_message_to_model_wrapper( - query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer + query=excalidraw_diagram_generation, + user=user, + agent_chat_model=agent_chat_model, + fast_model=False, + tracer=tracer, ) raw_response_text = clean_json(raw_response.text) try: @@ -1031,6 +1046,7 @@ async def generate_better_mermaidjs_diagram_description( user=user, query_files=query_files, agent_chat_model=agent_chat_model, + fast_model=False, tracer=tracer, ) response_text = response.text.strip() @@ -1059,7 +1075,11 @@ async def generate_mermaidjs_diagram_from_description( with timer("Chat actor: Generate Mermaid.js diagram", logger): raw_response = await send_message_to_model_wrapper( - query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer + query=mermaidjs_diagram_generation, + user=user, + agent_chat_model=agent_chat_model, + fast_model=False, + tracer=tracer, ) return clean_mermaidjs(raw_response.text.strip()) @@ -1120,6 +1140,7 @@ async def generate_better_image_prompt( query_files=query_files, chat_history=conversation_history, agent_chat_model=agent_chat_model, + fast_model=False, user=user, response_type="json_object", response_schema=ImagePromptResponse, @@ -1302,6 +1323,7 @@ async def extract_questions( query_files=query_files, response_type="json_object", response_schema=DocumentQueries, + fast_model=False, user=user, tracer=tracer, ) @@ -1420,6 +1442,7 @@ async def send_message_to_model_wrapper( response_schema: BaseModel = None, tools: List[ToolDefinition] = None, deepthought: bool = False, + fast_model: Optional[bool] = None, user: KhojUser = None, query_images: List[str] = None, context: str = "", @@ -1428,7 +1451,7 @@ async def send_message_to_model_wrapper( agent_chat_model: ChatModel = None, tracer: dict = {}, ): - chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model) + chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model) vision_available = chat_model.vision_enabled if not vision_available and query_images: logger.warning(f"Vision is not enabled for default model: {chat_model.name}.") diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index d43593f6..1bd6ba71 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -163,6 +163,7 @@ async def apick_next_tool( chat_history=chat_and_research_history, tools=tools, deepthought=True, + fast_model=False, user=user, query_images=query_images, query_files=query_files,