@staticmethod
async def aget_chat_models_with_fallbacks(slot: ServerChatSettings.ChatModelSlot) -> list[ChatModel]:
    """Return the chat models configured for *slot* across all ServerChatSettings rows.

    Rows are visited in ascending ``priority`` order (lower runs first), so the
    returned list is the fallback order to try when a chat model fails.
    A model configured on several rows appears once, at its highest-priority
    position.

    Args:
        slot: Which chat-model slot to read from each settings row
            (e.g. THINK_FREE_FAST, CHAT_DEFAULT).

    Returns:
        ChatModel objects ordered by ServerChatSettings priority (lower first).
    """
    field_name = slot.value
    queryset = (
        ServerChatSettings.objects.filter()
        .prefetch_related(field_name, f"{field_name}__ai_model_api")
        .order_by("priority")
    )

    # Keyed by model id; dicts preserve insertion order, so this both
    # de-duplicates and keeps the priority ordering in one pass.
    unique_models: dict[int, ChatModel] = {}
    async for settings in queryset.aiterator():
        model = getattr(settings, field_name, None)
        if model is not None and model.id not in unique_models:
            unique_models[model.id] = model

    return list(unique_models.values())
@staticmethod
async def aget_chat_model_slot(user: KhojUser = None, fast: Optional[bool] = None):
    """Pick the ServerChatSettings chat-model slot for this request.

    Args:
        user: The user making the request; subscription status is looked up when set.
        fast: Trinary speed preference (True=fast, False=deep, None=default).

    Returns:
        The matching ServerChatSettings.ChatModelSlot value.
    """
    is_subscribed = bool(user) and await ais_user_subscribed(user)

    Slot = ServerChatSettings.ChatModelSlot
    # (subscribed, speed) -> slot. The speed key is normalized so that only
    # literal True/False select the think slots, mirroring `fast is True` /
    # `fast is False`; anything else falls through to the default slot.
    slot_by_preference = {
        (True, True): Slot.THINK_PAID_FAST,
        (True, False): Slot.THINK_PAID_DEEP,
        (True, None): Slot.CHAT_ADVANCED,
        (False, True): Slot.THINK_FREE_FAST,
        (False, False): Slot.THINK_FREE_DEEP,
        (False, None): Slot.CHAT_DEFAULT,
    }
    speed = fast if isinstance(fast, bool) else None
    return slot_by_preference[(is_subscribed, speed)]
# Generated by Django 5.1.14 on 2025-11-27 01:09

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add the nullable, unique ``priority`` field to ServerChatSettings.

    ``null=True``/``default=None`` keeps existing rows valid; ``unique=True``
    guarantees an unambiguous fallback ordering across settings rows.
    """

    dependencies = [
        ("database", "0096_mcpserver"),
    ]

    operations = [
        migrations.AddField(
            model_name="serverchatsettings",
            name="priority",
            field=models.IntegerField(
                blank=True,
                default=None,
                help_text="Priority of the server chat settings. Lower numbers run first.",
                null=True,
                unique=True,
            ),
        ),
    ]
def save(self, *args, **kwargs):
    """Validate, auto-assign a priority when unset, then persist the row.

    When ``priority`` is None the row is appended after the current maximum
    priority, or given priority 1 when it is the first row.
    """
    self.clean()

    if self.priority is None:
        # NOTE(review): read-then-increment is racy under concurrent saves;
        # the unique constraint on priority will surface a collision as an
        # IntegrityError — confirm single-writer (admin-only) usage is assumed.
        max_priority = ServerChatSettings.objects.aggregate(models.Max("priority"))["priority__max"]
        # Explicit `is not None` instead of truthiness: a stored maximum of 0
        # only produced the right answer by coincidence, and any falsy-adjacent
        # refactor (e.g. switching to a tuple result) would silently reset to 1.
        self.priority = max_priority + 1 if max_priority is not None else 1

    super().save(*args, **kwargs)
class RetryableModelError(Exception):
    """Raised when a chat model fails with a transient, retryable error.

    Signals the caller to fall back to the next model in the priority list.
    Wraps provider-specific retryable failures such as:
      - OpenAI: RateLimitError, InternalServerError, APITimeoutError
      - Anthropic: RateLimitError, APIError
      - Google/Gemini: API errors with codes 429, 502, 503, 504
    """

    def __init__(self, message: str, original_exception: Exception = None, model_name: str = None):
        super().__init__(message)
        # Keep the underlying provider error and failing model for diagnostics.
        self.original_exception = original_exception
        self.model_name = model_name

    def __str__(self):
        base = super().__str__()
        if not self.model_name:
            return base
        return f"{base} (model: {self.model_name})"
+ """ + # OpenAI exceptions + if hasattr(exception, "__module__") and exception.__module__ and "openai" in exception.__module__: + import openai + + if isinstance( + exception, + ( + openai._exceptions.APITimeoutError, + openai._exceptions.RateLimitError, + openai._exceptions.InternalServerError, + ), + ): + return True + + # Anthropic exceptions + if isinstance(exception, (AnthropicRateLimitError, AnthropicAPIError)): + return True + + # Google/Gemini exceptions + if isinstance(exception, (gerrors.APIError, gerrors.ClientError)): + # Check for specific error codes that are retryable + if hasattr(exception, "code") and exception.code in [429, 502, 503, 504]: + return True + + # Network errors + if isinstance(exception, (httpx.TimeoutException, httpx.NetworkError)): + return True + + # Empty or no response by model over API results in ValueError + if isinstance(exception, ValueError): + return True + + return False + + class AgentMessage(BaseModel): role: Literal["user", "assistant", "system", "environment"] content: Union[str, List] diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 8e64c1ea..e5d2de57 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -89,19 +89,21 @@ from khoj.processor.conversation.google.gemini_chat import ( ) from khoj.processor.conversation.openai.gpt import ( converse_openai, - send_message_to_model, + openai_send_message_to_model, ) from khoj.processor.conversation.utils import ( ChatEvent, OperatorRun, ResearchIteration, ResponseWithThought, + RetryableModelError, clean_json, clean_mermaidjs, construct_chat_history, construct_question_history, defilter_query, generate_chatml_messages_with_context, + is_retryable_exception, ) from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled from khoj.routers.email import is_resend_enabled, send_task_email @@ -1452,61 +1454,26 @@ async def execute_search( return results -async def send_message_to_model_wrapper( - # Context - query: 
str, - query_files: str = None, - query_images: List[str] = None, - context: str = "", - chat_history: list[ChatMessageModel] = [], - system_message: str = "", - # Model Config - response_type: str = "text", - response_schema: BaseModel = None, - tools: List[ToolDefinition] = None, - deepthought: bool = False, - fast_model: Optional[bool] = None, - agent_chat_model: ChatModel = None, - # User - user: KhojUser = None, - # Tracer - tracer: dict = {}, -): - chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model) - vision_available = chat_model.vision_enabled - if not vision_available and query_images: - logger.warning(f"Vision is not enabled for default model: {chat_model.name}.") - vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() - if vision_enabled_config: - chat_model = vision_enabled_config - vision_available = True - if vision_available and query_images: - logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.") - - max_tokens = await ConversationAdapters.aget_max_context_size(chat_model, user) - chat_model_name = chat_model.name - tokenizer = chat_model.tokenizer +def send_message_to_model( + chat_model: ChatModel, + truncated_messages: List[ChatMessage], + response_type: str, + response_schema: BaseModel, + tools: List[ToolDefinition], + deepthought: bool, + tracer: dict, +) -> ResponseWithThought: + """ + Call a specific chat model with the given parameters. + This is a helper function used by send_message_to_model_wrapper for the fallback loop. 
+ """ model_type = chat_model.model_type - vision_available = chat_model.vision_enabled + chat_model_name = chat_model.name api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url - truncated_messages = generate_chatml_messages_with_context( - user_message=query, - query_files=query_files, - query_images=query_images, - context_message=context, - chat_history=chat_history, - system_message=system_message, - model_name=chat_model_name, - model_type=model_type, - tokenizer_name=tokenizer, - max_prompt_size=max_tokens, - vision_enabled=vision_available, - ) - if model_type == ChatModel.ModelType.OPENAI: - return send_message_to_model( + return openai_send_message_to_model( messages=truncated_messages, api_key=api_key, model=chat_model_name, @@ -1542,7 +1509,101 @@ async def send_message_to_model_wrapper( tracer=tracer, ) else: - raise HTTPException(status_code=500, detail="Invalid conversation config") + raise HTTPException(status_code=500, detail=f"Invalid model type: {model_type}") + + +async def send_message_to_model_wrapper( + # Context + query: str, + query_files: str = None, + query_images: List[str] = None, + context: str = "", + chat_history: list[ChatMessageModel] = [], + system_message: str = "", + # Model Config + response_type: str = "text", + response_schema: BaseModel = None, + tools: List[ToolDefinition] = None, + deepthought: bool = False, + fast_model: Optional[bool] = None, + agent_chat_model: ChatModel = None, + # User + user: KhojUser = None, + # Tracer + tracer: dict = {}, +): + # Get primary chat model + primary_chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model( + user, agent_chat_model, fast=fast_model + ) + vision_available = primary_chat_model.vision_enabled + + # Handle vision model override if needed + if not vision_available and query_images: + logger.warning(f"Vision is not enabled for default model: {primary_chat_model.name}.") + vision_enabled_config = await 
ConversationAdapters.aget_vision_enabled_config() + if vision_enabled_config: + primary_chat_model = vision_enabled_config + vision_available = True + if vision_available and query_images: + logger.info(f"Using {primary_chat_model.name} model to understand {len(query_images)} images.") + + # Get fallback models for the appropriate slot + slot = await ConversationAdapters.aget_chat_model_slot(user, fast=fast_model) + fallback_models = await ConversationAdapters.aget_chat_models_with_fallbacks(slot) + # Filter out the primary model from fallbacks to avoid duplicate attempts + fallback_models = [m for m in fallback_models if m.id != primary_chat_model.id] + + # Build list of models to try: primary first, then fallbacks + models_to_try = [primary_chat_model] + fallback_models + + last_exception: Optional[Exception] = None + for i, chat_model in enumerate(models_to_try): + is_last_model = i == len(models_to_try) - 1 + + # Prepare messages for this specific model + max_tokens = await ConversationAdapters.aget_max_context_size(chat_model, user) + truncated_messages = generate_chatml_messages_with_context( + user_message=query, + query_files=query_files, + query_images=query_images, + context_message=context, + chat_history=chat_history, + system_message=system_message, + model_name=chat_model.name, + model_type=chat_model.model_type, + tokenizer_name=chat_model.tokenizer, + max_prompt_size=max_tokens, + vision_enabled=chat_model.vision_enabled if not query_images else vision_available, + ) + + try: + return send_message_to_model( + chat_model=chat_model, + truncated_messages=truncated_messages, + response_type=response_type, + response_schema=response_schema, + tools=tools, + deepthought=deepthought, + tracer=tracer, + ) + except Exception as e: + last_exception = e + if is_retryable_exception(e): + if is_last_model: + logger.error(f"All chat models failed. 
Last error from {chat_model.name}: {e}") + else: + logger.warning(f"Chat model {chat_model.name} failed with retryable error: {e}. Trying next model.") + continue + # Non-retryable errors should be raised immediately + raise + + # If we get here, all models failed with retryable errors + raise RetryableModelError( + message=f"All {len(models_to_try)} chat models failed", + original_exception=last_exception, + model_name=models_to_try[-1].name if models_to_try else None, + ) def send_message_to_model_wrapper_sync( @@ -1581,7 +1642,7 @@ def send_message_to_model_wrapper_sync( ) if model_type == ChatModel.ModelType.OPENAI: - return send_message_to_model( + return openai_send_message_to_model( messages=truncated_messages, api_key=api_key, api_base_url=api_base_url,