From 731700ac437f426054032d21327b924d4ccca1a0 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 26 Nov 2025 17:46:29 -0800 Subject: [PATCH] Support fallback deep, fast chat models via server chat settings Overview --- This change enables specifying fallback chat models for each task type (fast, deep, default) and user type (free, paid). Previously we did not fallback to other chat models if the chat model assigned for a task failed. Details --- You can now specify multiple ServerChatSettings via the Admin Panel with their usage priority. If the highest priority chat model for the task, user type fails, the task is assigned to a lower priority chat model configured for the current user and task type. This change also reduces the retry attempts for openai chat actor models from 3 to 2 as: - multiple fallback server chat settings can now be created. So reducing retries with same model reduces latency. - 2 attempts is inline with retry attempts with other model types (gemini, anthropic) --- src/khoj/database/adapters/__init__.py | 65 +++++++ src/khoj/database/admin.py | 2 + .../0097_serverchatsettings_priority.py | 23 +++ src/khoj/database/models/__init__.py | 22 +++ src/khoj/processor/conversation/openai/gpt.py | 2 +- .../processor/conversation/openai/utils.py | 4 +- src/khoj/processor/conversation/utils.py | 64 +++++++ src/khoj/routers/helpers.py | 167 ++++++++++++------ 8 files changed, 293 insertions(+), 56 deletions(-) create mode 100644 src/khoj/database/migrations/0097_serverchatsettings_priority.py diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 5d286b32..6257de7b 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1419,6 +1419,71 @@ class ConversationAdapters: max_tokens = chat_model.max_prompt_size return max_tokens + @staticmethod + async def aget_chat_models_with_fallbacks(slot: ServerChatSettings.ChatModelSlot) -> list[ChatModel]: + """ + Get chat models for a specific subscription, speed preference from all ServerChatSettings, ordered by priority. + Used for fallback logic when a chat model fails. + + Args: + slot: The chat model slot to get based on user subscription, speed preference (e.g., THINK_FREE_FAST, CHAT_DEFAULT) + + Returns: + List of ChatModel objects ordered by ServerChatSettings priority (lower first) + """ + # Map slot enum to field name and prefetch related + slot_field = slot.value + prefetch_fields = [slot_field, f"{slot_field}__ai_model_api"] + + # Get all server chat settings ordered by priority + all_settings = [ + settings + async for settings in ServerChatSettings.objects.filter() + .prefetch_related(*prefetch_fields) + .order_by("priority") + .aiterator() + ] + + # Extract the chat model for the requested slot from each settings + chat_models: list[ChatModel] = [] + seen_model_ids: set[int] = set() + for settings in all_settings: + chat_model = getattr(settings, slot_field, None) + if chat_model and chat_model.id not in seen_model_ids: + chat_models.append(chat_model) + seen_model_ids.add(chat_model.id) + + return chat_models + + @staticmethod + async def aget_chat_model_slot(user: KhojUser = None, fast: Optional[bool] = None): + """ + Determine which chat model slot to use based on user subscription and speed preference. + + Args: + user: The user making the request + fast: Trinary flag for speed preference (True=fast, False=deep, None=default) + + Returns: + The appropriate ChatModelSlot enum value, or None if no slot matches + """ + is_subscribed = await ais_user_subscribed(user) if user else False + + if is_subscribed: + if fast is True: + return ServerChatSettings.ChatModelSlot.THINK_PAID_FAST + elif fast is False: + return ServerChatSettings.ChatModelSlot.THINK_PAID_DEEP + else: + return ServerChatSettings.ChatModelSlot.CHAT_ADVANCED + else: + if fast is True: + return ServerChatSettings.ChatModelSlot.THINK_FREE_FAST + elif fast is False: + return ServerChatSettings.ChatModelSlot.THINK_FREE_DEEP + else: + return ServerChatSettings.ChatModelSlot.CHAT_DEFAULT + @staticmethod async def aget_server_webscraper(): server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst() diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 04ab3410..872c57ca 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -287,6 +287,7 @@ class SearchModelConfigAdmin(unfold_admin.ModelAdmin): @admin.register(ServerChatSettings) class ServerChatSettingsAdmin(unfold_admin.ModelAdmin): list_display = ( + "priority", "chat_default", "chat_advanced", "think_free_fast", @@ -295,6 +296,7 @@ class ServerChatSettingsAdmin(unfold_admin.ModelAdmin): "think_paid_deep", "web_scraper", ) + ordering = ("priority",) @admin.register(WebScraper) diff --git a/src/khoj/database/migrations/0097_serverchatsettings_priority.py b/src/khoj/database/migrations/0097_serverchatsettings_priority.py new file mode 100644 index 00000000..abdafe24 --- /dev/null +++ b/src/khoj/database/migrations/0097_serverchatsettings_priority.py @@ -0,0 +1,23 @@ +# Generated by Django 5.1.14 on 2025-11-27 01:09 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0096_mcpserver"), + ] + + operations = [ + migrations.AddField( + model_name="serverchatsettings", + name="priority", + field=models.IntegerField( + blank=True, + default=None, + help_text="Priority of the server chat settings. Lower numbers run first.", + null=True, + unique=True, + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 4d759506..b7f869a5 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -468,6 +468,16 @@ class WebScraper(DbBaseModel): class ServerChatSettings(DbBaseModel): + class ChatModelSlot(models.TextChoices): + """Enum for the different chat model slots in ServerChatSettings""" + + CHAT_DEFAULT = "chat_default" + CHAT_ADVANCED = "chat_advanced" + THINK_FREE_FAST = "think_free_fast" + THINK_FREE_DEEP = "think_free_deep" + THINK_PAID_FAST = "think_paid_fast" + THINK_PAID_DEEP = "think_paid_deep" + chat_default = models.ForeignKey( ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default" ) @@ -489,6 +499,13 @@ class ServerChatSettings(DbBaseModel): web_scraper = models.ForeignKey( WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper" ) + priority = models.IntegerField( + default=None, + null=True, + blank=True, + unique=True, + help_text="Priority of the server chat settings. Lower numbers run first.", + ) def clean(self): error = {} @@ -503,6 +520,11 @@ class ServerChatSettings(DbBaseModel): def save(self, *args, **kwargs): self.clean() + + if self.priority is None: + max_priority = ServerChatSettings.objects.aggregate(models.Max("priority"))["priority__max"] + self.priority = max_priority + 1 if max_priority else 1 + super().save(*args, **kwargs) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 62d6aaba..82482fb0 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -24,7 +24,7 @@ from khoj.utils.helpers import ToolDefinition logger = logging.getLogger(__name__) -def send_message_to_model( +def openai_send_message_to_model( messages, api_key, model: str, diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 9218227d..4f4bd358 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -84,7 +84,7 @@ def _extract_text_for_instructions(content: Union[str, List, Dict, None]) -> str | retry_if_exception_type(ValueError) ), wait=wait_random_exponential(min=1, max=10), - stop=stop_after_attempt(3), + stop=stop_after_attempt(2), before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) @@ -438,7 +438,7 @@ async def chat_completion_with_backoff( | retry_if_exception_type(ValueError) ), wait=wait_random_exponential(min=1, max=10), - stop=stop_after_attempt(3), + stop=stop_after_attempt(2), before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index e63d2e71..5672134f 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -11,10 +11,14 @@ from enum import Enum from io import BytesIO from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +import httpx import PIL.Image import pyjson5 import requests import yaml +from anthropic import APIError as AnthropicAPIError +from anthropic import RateLimitError as AnthropicRateLimitError +from google.genai import errors as gerrors from langchain_core.messages.chat import ChatMessage from pydantic import BaseModel, ConfigDict, ValidationError @@ -93,6 +97,66 @@ model_to_prompt_size = { model_to_tokenizer: Dict[str, str] = {} +class RetryableModelError(Exception): + """ + Exception raised when a chat model fails with a retryable error. + This is used to trigger fallback to the next model in the priority list. + + Wraps provider-specific retryable errors like: + - OpenAI: RateLimitError, InternalServerError, APITimeoutError + - Anthropic: RateLimitError, APIError + - Google/Gemini: API errors with codes 429, 502, 503, 504 + """ + + def __init__(self, message: str, original_exception: Exception = None, model_name: str = None): + super().__init__(message) + self.original_exception = original_exception + self.model_name = model_name + + def __str__(self): + model_info = f" (model: {self.model_name})" if self.model_name else "" + return f"{super().__str__()}{model_info}" + + +def is_retryable_exception(exception: BaseException) -> bool: + """ + Check if an exception is retryable and should trigger fallback to another model. + """ + # OpenAI exceptions + if hasattr(exception, "__module__") and exception.__module__ and "openai" in exception.__module__: + import openai + + if isinstance( + exception, + ( + openai._exceptions.APITimeoutError, + openai._exceptions.RateLimitError, + openai._exceptions.InternalServerError, + ), + ): + return True + + # Anthropic exceptions + if isinstance(exception, (AnthropicRateLimitError, AnthropicAPIError)): + return True + + # Google/Gemini exceptions + if isinstance(exception, (gerrors.APIError, gerrors.ClientError)): + # Check for specific error codes that are retryable + if hasattr(exception, "code") and exception.code in [429, 502, 503, 504]: + return True + + # Network errors + if isinstance(exception, (httpx.TimeoutException, httpx.NetworkError)): + return True + + # Empty or no response by model over API results in ValueError + if isinstance(exception, ValueError): + return True + + return False + + class AgentMessage(BaseModel): role: Literal["user", "assistant", "system", "environment"] content: Union[str, List] diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 8e64c1ea..e5d2de57 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -89,19 +89,21 @@ from khoj.processor.conversation.google.gemini_chat import ( ) from khoj.processor.conversation.openai.gpt import ( converse_openai, - send_message_to_model, + openai_send_message_to_model, ) from khoj.processor.conversation.utils import ( ChatEvent, OperatorRun, ResearchIteration, ResponseWithThought, + RetryableModelError, clean_json, clean_mermaidjs, construct_chat_history, construct_question_history, defilter_query, generate_chatml_messages_with_context, + is_retryable_exception, ) from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled from khoj.routers.email import is_resend_enabled, send_task_email @@ -1452,61 +1454,26 @@ async def execute_search( return results -async def send_message_to_model_wrapper( - # Context - query: str, - query_files: str = None, - query_images: List[str] = None, - context: str = "", - chat_history: list[ChatMessageModel] = [], - system_message: str = "", - # Model Config - response_type: str = "text", - response_schema: BaseModel = None, - tools: List[ToolDefinition] = None, - deepthought: bool = False, - fast_model: Optional[bool] = None, - agent_chat_model: ChatModel = None, - # User - user: KhojUser = None, - # Tracer - tracer: dict = {}, -): - chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model) - vision_available = chat_model.vision_enabled - if not vision_available and query_images: - logger.warning(f"Vision is not enabled for default model: {chat_model.name}.") - vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() - if vision_enabled_config: - chat_model = vision_enabled_config - vision_available = True - if vision_available and query_images: - logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.") - - max_tokens = await ConversationAdapters.aget_max_context_size(chat_model, user) - chat_model_name = chat_model.name - tokenizer = chat_model.tokenizer +def send_message_to_model( + chat_model: ChatModel, + truncated_messages: List[ChatMessage], + response_type: str, + response_schema: BaseModel, + tools: List[ToolDefinition], + deepthought: bool, + tracer: dict, +) -> ResponseWithThought: + """ + Call a specific chat model with the given parameters. + This is a helper function used by send_message_to_model_wrapper for the fallback loop. + """ model_type = chat_model.model_type - vision_available = chat_model.vision_enabled + chat_model_name = chat_model.name api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - truncated_messages = generate_chatml_messages_with_context( - user_message=query, - query_files=query_files, - query_images=query_images, - context_message=context, - chat_history=chat_history, - system_message=system_message, - model_name=chat_model_name, - model_type=model_type, - tokenizer_name=tokenizer, - max_prompt_size=max_tokens, - vision_enabled=vision_available, - ) - if model_type == ChatModel.ModelType.OPENAI: - return send_message_to_model( + return openai_send_message_to_model( messages=truncated_messages, api_key=api_key, model=chat_model_name, @@ -1542,7 +1509,101 @@ async def send_message_to_model_wrapper( tracer=tracer, ) else: - raise HTTPException(status_code=500, detail="Invalid conversation config") + raise HTTPException(status_code=500, detail=f"Invalid model type: {model_type}") + + +async def send_message_to_model_wrapper( + # Context + query: str, + query_files: str = None, + query_images: List[str] = None, + context: str = "", + chat_history: list[ChatMessageModel] = [], + system_message: str = "", + # Model Config + response_type: str = "text", + response_schema: BaseModel = None, + tools: List[ToolDefinition] = None, + deepthought: bool = False, + fast_model: Optional[bool] = None, + agent_chat_model: ChatModel = None, + # User + user: KhojUser = None, + # Tracer + tracer: dict = {}, +): + # Get primary chat model + primary_chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model( + user, agent_chat_model, fast=fast_model + ) + vision_available = primary_chat_model.vision_enabled + + # Handle vision model override if needed + if not vision_available and query_images: + logger.warning(f"Vision is not enabled for default model: {primary_chat_model.name}.") + vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() + if vision_enabled_config: + primary_chat_model = vision_enabled_config + vision_available = True + if vision_available and query_images: + logger.info(f"Using {primary_chat_model.name} model to understand {len(query_images)} images.") + + # Get fallback models for the appropriate slot + slot = await ConversationAdapters.aget_chat_model_slot(user, fast=fast_model) + fallback_models = await ConversationAdapters.aget_chat_models_with_fallbacks(slot) + # Filter out the primary model from fallbacks to avoid duplicate attempts + fallback_models = [m for m in fallback_models if m.id != primary_chat_model.id] + + # Build list of models to try: primary first, then fallbacks + models_to_try = [primary_chat_model] + fallback_models + + last_exception: Optional[Exception] = None + for i, chat_model in enumerate(models_to_try): + is_last_model = i == len(models_to_try) - 1 + + # Prepare messages for this specific model + max_tokens = await ConversationAdapters.aget_max_context_size(chat_model, user) + truncated_messages = generate_chatml_messages_with_context( + user_message=query, + query_files=query_files, + query_images=query_images, + context_message=context, + chat_history=chat_history, + system_message=system_message, + model_name=chat_model.name, + model_type=chat_model.model_type, + tokenizer_name=chat_model.tokenizer, + max_prompt_size=max_tokens, + vision_enabled=chat_model.vision_enabled if not query_images else vision_available, + ) + + try: + return send_message_to_model( + chat_model=chat_model, + truncated_messages=truncated_messages, + response_type=response_type, + response_schema=response_schema, + tools=tools, + deepthought=deepthought, + tracer=tracer, + ) + except Exception as e: + last_exception = e + if is_retryable_exception(e): + if is_last_model: + logger.error(f"All chat models failed. Last error from {chat_model.name}: {e}") + else: + logger.warning(f"Chat model {chat_model.name} failed with retryable error: {e}. Trying next model.") + continue + # Non-retryable errors should be raised immediately + raise + + # If we get here, all models failed with retryable errors + raise RetryableModelError( + message=f"All {len(models_to_try)} chat models failed", + original_exception=last_exception, + model_name=models_to_try[-1].name if models_to_try else None, + ) def send_message_to_model_wrapper_sync( @@ -1581,7 +1642,7 @@ def send_message_to_model_wrapper_sync( ) if model_type == ChatModel.ModelType.OPENAI: - return send_message_to_model( + return openai_send_message_to_model( messages=truncated_messages, api_key=api_key, api_base_url=api_base_url,