Support fallback deep, fast chat models via server chat settings

Overview
---
This change enables specifying fallback chat models for each task
type (fast, deep, default) and user type (free, paid).

Previously, we did not fall back to other chat models when the chat
model assigned to a task failed.

Details
---
You can now specify multiple ServerChatSettings via the Admin Panel,
each with a usage priority. If the highest-priority chat model for a
given task and user type fails, the task is reassigned to a
lower-priority chat model configured for the same user and task type.
A short sketch of this priority ordering follows the list below.

This change also reduces the retry attempts for OpenAI chat actor
models from 3 to 2 because:
- Multiple fallback server chat settings can now be created, so
  reducing retries against the same model lowers latency.
- 2 attempts is in line with the retry attempts for other model
  types (Gemini, Anthropic).
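
To illustrate, here is a minimal, self-contained sketch of the priority ordering, with hypothetical model names:

# Settings rows sorted by priority (lower first) yield the fallback order for a slot.
from dataclasses import dataclass

@dataclass
class Row:
    priority: int
    think_free_fast: str  # model configured for the free-tier fast slot (names hypothetical)

rows = [Row(priority=2, think_free_fast="gpt-4o-mini"), Row(priority=1, think_free_fast="gemini-2.0-flash")]
fallback_order = [r.think_free_fast for r in sorted(rows, key=lambda r: r.priority)]
print(fallback_order)  # ['gemini-2.0-flash', 'gpt-4o-mini']: the first entry is tried first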
Debanjum
2025-11-26 17:46:29 -08:00
parent 99f16df7e2
commit 731700ac43
8 changed files with 293 additions and 56 deletions


@@ -1419,6 +1419,71 @@ class ConversationAdapters:
            max_tokens = chat_model.max_prompt_size
        return max_tokens

    @staticmethod
    async def aget_chat_models_with_fallbacks(slot: ServerChatSettings.ChatModelSlot) -> list[ChatModel]:
        """
        Get the chat models for a specific subscription and speed preference from all ServerChatSettings, ordered by priority.
        Used for fallback logic when a chat model fails.

        Args:
            slot: The chat model slot to fetch, based on user subscription and speed preference (e.g., THINK_FREE_FAST, CHAT_DEFAULT)

        Returns:
            List of ChatModel objects ordered by ServerChatSettings priority (lower first)
        """
        # Map slot enum to field name and prefetch related
        slot_field = slot.value
        prefetch_fields = [slot_field, f"{slot_field}__ai_model_api"]

        # Get all server chat settings ordered by priority
        all_settings = [
            settings
            async for settings in ServerChatSettings.objects.filter()
            .prefetch_related(*prefetch_fields)
            .order_by("priority")
            .aiterator()
        ]

        # Extract the chat model for the requested slot from each settings row
        chat_models: list[ChatModel] = []
        seen_model_ids: set[int] = set()
        for settings in all_settings:
            chat_model = getattr(settings, slot_field, None)
            if chat_model and chat_model.id not in seen_model_ids:
                chat_models.append(chat_model)
                seen_model_ids.add(chat_model.id)
        return chat_models

    @staticmethod
    async def aget_chat_model_slot(user: KhojUser = None, fast: Optional[bool] = None):
        """
        Determine which chat model slot to use based on user subscription and speed preference.

        Args:
            user: The user making the request
            fast: Tri-state flag for speed preference (True=fast, False=deep, None=default)

        Returns:
            The appropriate ChatModelSlot enum value, or None if no slot matches
        """
        is_subscribed = await ais_user_subscribed(user) if user else False
        if is_subscribed:
            if fast is True:
                return ServerChatSettings.ChatModelSlot.THINK_PAID_FAST
            elif fast is False:
                return ServerChatSettings.ChatModelSlot.THINK_PAID_DEEP
            else:
                return ServerChatSettings.ChatModelSlot.CHAT_ADVANCED
        else:
            if fast is True:
                return ServerChatSettings.ChatModelSlot.THINK_FREE_FAST
            elif fast is False:
                return ServerChatSettings.ChatModelSlot.THINK_FREE_DEEP
            else:
                return ServerChatSettings.ChatModelSlot.CHAT_DEFAULT

    @staticmethod
    async def aget_server_webscraper():
        server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst()
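
A brief usage sketch of the two adapters above, assuming an async caller with a KhojUser in scope:

slot = await ConversationAdapters.aget_chat_model_slot(user, fast=True)
models = await ConversationAdapters.aget_chat_models_with_fallbacks(slot)
# models[0] is the highest-priority ChatModel for this user and speed preference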


@@ -287,6 +287,7 @@ class SearchModelConfigAdmin(unfold_admin.ModelAdmin):
@admin.register(ServerChatSettings)
class ServerChatSettingsAdmin(unfold_admin.ModelAdmin):
    list_display = (
        "priority",
        "chat_default",
        "chat_advanced",
        "think_free_fast",
@@ -295,6 +296,7 @@ class ServerChatSettingsAdmin(unfold_admin.ModelAdmin):
        "think_paid_deep",
        "web_scraper",
    )
    ordering = ("priority",)


@admin.register(WebScraper)


@@ -0,0 +1,23 @@
# Generated by Django 5.1.14 on 2025-11-27 01:09
from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0096_mcpserver"),
    ]

    operations = [
        migrations.AddField(
            model_name="serverchatsettings",
            name="priority",
            field=models.IntegerField(
                blank=True,
                default=None,
                help_text="Priority of the server chat settings. Lower numbers run first.",
                null=True,
                unique=True,
            ),
        ),
    ]


@@ -468,6 +468,16 @@ class WebScraper(DbBaseModel):
class ServerChatSettings(DbBaseModel):
    class ChatModelSlot(models.TextChoices):
        """Enum for the different chat model slots in ServerChatSettings"""

        CHAT_DEFAULT = "chat_default"
        CHAT_ADVANCED = "chat_advanced"
        THINK_FREE_FAST = "think_free_fast"
        THINK_FREE_DEEP = "think_free_deep"
        THINK_PAID_FAST = "think_paid_fast"
        THINK_PAID_DEEP = "think_paid_deep"

    chat_default = models.ForeignKey(
        ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
    )
@@ -489,6 +499,13 @@ class ServerChatSettings(DbBaseModel):
    web_scraper = models.ForeignKey(
        WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
    )
    priority = models.IntegerField(
        default=None,
        null=True,
        blank=True,
        unique=True,
        help_text="Priority of the server chat settings. Lower numbers run first.",
    )

    def clean(self):
        error = {}
@@ -503,6 +520,11 @@ class ServerChatSettings(DbBaseModel):
    def save(self, *args, **kwargs):
        self.clean()

        if self.priority is None:
            max_priority = ServerChatSettings.objects.aggregate(models.Max("priority"))["priority__max"]
            self.priority = max_priority + 1 if max_priority else 1

        super().save(*args, **kwargs)
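
The auto-assignment in save() can be pictured with a small standalone sketch (the next_priority helper is hypothetical):

def next_priority(existing: list) -> int:
    # Approximates the save() logic above: unset priorities are ignored by Max();
    # the first row gets priority 1, later rows append after the current maximum.
    priorities = [p for p in existing if p is not None]
    return max(priorities) + 1 if priorities else 1

assert next_priority([]) == 1      # first ServerChatSettings row
assert next_priority([1, 2]) == 3  # later rows slot in last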


@@ -24,7 +24,7 @@ from khoj.utils.helpers import ToolDefinition
logger = logging.getLogger(__name__)


def send_message_to_model(
def openai_send_message_to_model(
    messages,
    api_key,
    model: str,


@@ -84,7 +84,7 @@ def _extract_text_for_instructions(content: Union[str, List, Dict, None]) -> str
        | retry_if_exception_type(ValueError)
    ),
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(3),
    stop=stop_after_attempt(2),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
@@ -438,7 +438,7 @@ async def chat_completion_with_backoff(
        | retry_if_exception_type(ValueError)
    ),
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(3),
    stop=stop_after_attempt(2),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
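
For reference, a minimal sketch of a tenacity retry policy with the new attempt budget (the wrapped function is a placeholder):

from tenacity import retry, stop_after_attempt, wait_random_exponential

@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), reraise=True)
def call_openai_model():
    ...  # one initial attempt plus one retry, then the error propagates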


@@ -11,10 +11,14 @@ from enum import Enum
from io import BytesIO
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import httpx
import PIL.Image
import pyjson5
import requests
import yaml
from anthropic import APIError as AnthropicAPIError
from anthropic import RateLimitError as AnthropicRateLimitError
from google.genai import errors as gerrors
from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, ConfigDict, ValidationError
@@ -93,6 +97,66 @@ model_to_prompt_size = {
model_to_tokenizer: Dict[str, str] = {}
class RetryableModelError(Exception):
    """
    Exception raised when a chat model fails with a retryable error.
    This is used to trigger fallback to the next model in the priority list.

    Wraps provider-specific retryable errors like:
    - OpenAI: RateLimitError, InternalServerError, APITimeoutError
    - Anthropic: RateLimitError, APIError
    - Google/Gemini: API errors with codes 429, 502, 503, 504
    """

    def __init__(self, message: str, original_exception: Exception = None, model_name: str = None):
        super().__init__(message)
        self.original_exception = original_exception
        self.model_name = model_name

    def __str__(self):
        model_info = f" (model: {self.model_name})" if self.model_name else ""
        return f"{super().__str__()}{model_info}"


def is_retryable_exception(exception: BaseException) -> bool:
    """
    Check if an exception is retryable and should trigger fallback to another model.
    """
    # OpenAI exceptions
    if hasattr(exception, "__module__") and exception.__module__ and "openai" in exception.__module__:
        import openai

        if isinstance(
            exception,
            (
                openai._exceptions.APITimeoutError,
                openai._exceptions.RateLimitError,
                openai._exceptions.InternalServerError,
            ),
        ):
            return True

    # Anthropic exceptions
    if isinstance(exception, (AnthropicRateLimitError, AnthropicAPIError)):
        return True

    # Google/Gemini exceptions
    if isinstance(exception, (gerrors.APIError, gerrors.ClientError)):
        # Check for specific error codes that are retryable
        if hasattr(exception, "code") and exception.code in [429, 502, 503, 504]:
            return True

    # Network errors
    if isinstance(exception, (httpx.TimeoutException, httpx.NetworkError)):
        return True

    # An empty or missing model response over the API surfaces as a ValueError
    if isinstance(exception, ValueError):
        return True

    return False


class AgentMessage(BaseModel):
    role: Literal["user", "assistant", "system", "environment"]
    content: Union[str, List]
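
A quick sanity check of the classifier above, assuming is_retryable_exception is imported from khoj.processor.conversation.utils:

import httpx

assert is_retryable_exception(ValueError("empty response from model"))
assert is_retryable_exception(httpx.TimeoutException("request timed out"))
assert not is_retryable_exception(KeyError("config"))  # non-retryable: raised immediately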


@@ -89,19 +89,21 @@ from khoj.processor.conversation.google.gemini_chat import (
)
from khoj.processor.conversation.openai.gpt import (
    converse_openai,
    send_message_to_model,
    openai_send_message_to_model,
)
from khoj.processor.conversation.utils import (
    ChatEvent,
    OperatorRun,
    ResearchIteration,
    ResponseWithThought,
    RetryableModelError,
    clean_json,
    clean_mermaidjs,
    construct_chat_history,
    construct_question_history,
    defilter_query,
    generate_chatml_messages_with_context,
    is_retryable_exception,
)
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
from khoj.routers.email import is_resend_enabled, send_task_email
@@ -1452,61 +1454,26 @@ async def execute_search(
    return results


async def send_message_to_model_wrapper(
    # Context
    query: str,
    query_files: str = None,
    query_images: List[str] = None,
    context: str = "",
    chat_history: list[ChatMessageModel] = [],
    system_message: str = "",
    # Model Config
    response_type: str = "text",
    response_schema: BaseModel = None,
    tools: List[ToolDefinition] = None,
    deepthought: bool = False,
    fast_model: Optional[bool] = None,
    agent_chat_model: ChatModel = None,
    # User
    user: KhojUser = None,
    # Tracer
    tracer: dict = {},
):
    chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model)
    vision_available = chat_model.vision_enabled
    if not vision_available and query_images:
        logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
        vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
        if vision_enabled_config:
            chat_model = vision_enabled_config
            vision_available = True
    if vision_available and query_images:
        logger.info(f"Using {chat_model.name} model to understand {len(query_images)} images.")
    max_tokens = await ConversationAdapters.aget_max_context_size(chat_model, user)
    chat_model_name = chat_model.name
    tokenizer = chat_model.tokenizer


def send_message_to_model(
    chat_model: ChatModel,
    truncated_messages: List[ChatMessage],
    response_type: str,
    response_schema: BaseModel,
    tools: List[ToolDefinition],
    deepthought: bool,
    tracer: dict,
) -> ResponseWithThought:
    """
    Call a specific chat model with the given parameters.
    This is a helper function used by send_message_to_model_wrapper for the fallback loop.
    """
    model_type = chat_model.model_type
    vision_available = chat_model.vision_enabled
    chat_model_name = chat_model.name
    api_key = chat_model.ai_model_api.api_key
    api_base_url = chat_model.ai_model_api.api_base_url
    truncated_messages = generate_chatml_messages_with_context(
        user_message=query,
        query_files=query_files,
        query_images=query_images,
        context_message=context,
        chat_history=chat_history,
        system_message=system_message,
        model_name=chat_model_name,
        model_type=model_type,
        tokenizer_name=tokenizer,
        max_prompt_size=max_tokens,
        vision_enabled=vision_available,
    )
    if model_type == ChatModel.ModelType.OPENAI:
        return send_message_to_model(
        return openai_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            model=chat_model_name,
@@ -1542,7 +1509,101 @@
            tracer=tracer,
        )
    else:
        raise HTTPException(status_code=500, detail="Invalid conversation config")
        raise HTTPException(status_code=500, detail=f"Invalid model type: {model_type}")

async def send_message_to_model_wrapper(
    # Context
    query: str,
    query_files: str = None,
    query_images: List[str] = None,
    context: str = "",
    chat_history: list[ChatMessageModel] = [],
    system_message: str = "",
    # Model Config
    response_type: str = "text",
    response_schema: BaseModel = None,
    tools: List[ToolDefinition] = None,
    deepthought: bool = False,
    fast_model: Optional[bool] = None,
    agent_chat_model: ChatModel = None,
    # User
    user: KhojUser = None,
    # Tracer
    tracer: dict = {},
):
    # Get primary chat model
    primary_chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(
        user, agent_chat_model, fast=fast_model
    )
    vision_available = primary_chat_model.vision_enabled

    # Handle vision model override if needed
    if not vision_available and query_images:
        logger.warning(f"Vision is not enabled for default model: {primary_chat_model.name}.")
        vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
        if vision_enabled_config:
            primary_chat_model = vision_enabled_config
            vision_available = True
    if vision_available and query_images:
        logger.info(f"Using {primary_chat_model.name} model to understand {len(query_images)} images.")

    # Get fallback models for the appropriate slot
    slot = await ConversationAdapters.aget_chat_model_slot(user, fast=fast_model)
    fallback_models = await ConversationAdapters.aget_chat_models_with_fallbacks(slot)

    # Filter out the primary model from fallbacks to avoid duplicate attempts
    fallback_models = [m for m in fallback_models if m.id != primary_chat_model.id]

    # Build list of models to try: primary first, then fallbacks
    models_to_try = [primary_chat_model] + fallback_models

    last_exception: Optional[Exception] = None
    for i, chat_model in enumerate(models_to_try):
        is_last_model = i == len(models_to_try) - 1

        # Prepare messages for this specific model
        max_tokens = await ConversationAdapters.aget_max_context_size(chat_model, user)
        truncated_messages = generate_chatml_messages_with_context(
            user_message=query,
            query_files=query_files,
            query_images=query_images,
            context_message=context,
            chat_history=chat_history,
            system_message=system_message,
            model_name=chat_model.name,
            model_type=chat_model.model_type,
            tokenizer_name=chat_model.tokenizer,
            max_prompt_size=max_tokens,
            vision_enabled=chat_model.vision_enabled if not query_images else vision_available,
        )

        try:
            return send_message_to_model(
                chat_model=chat_model,
                truncated_messages=truncated_messages,
                response_type=response_type,
                response_schema=response_schema,
                tools=tools,
                deepthought=deepthought,
                tracer=tracer,
            )
        except Exception as e:
            last_exception = e
            if is_retryable_exception(e):
                if is_last_model:
                    logger.error(f"All chat models failed. Last error from {chat_model.name}: {e}")
                else:
                    logger.warning(f"Chat model {chat_model.name} failed with retryable error: {e}. Trying next model.")
                continue
            # Non-retryable errors should be raised immediately
            raise

    # If we get here, all models failed with retryable errors
    raise RetryableModelError(
        message=f"All {len(models_to_try)} chat models failed",
        original_exception=last_exception,
        model_name=models_to_try[-1].name if models_to_try else None,
    )
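
A hypothetical call site, showing that the fallback loop stays transparent to callers:

async def summarize_notes(user):
    # Tries the primary model for this user and speed preference first, then each
    # lower-priority fallback; raises RetryableModelError only if every model fails.
    return await send_message_to_model_wrapper(
        query="Summarize my notes from today",
        user=user,
        fast_model=True,
    )
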
def send_message_to_model_wrapper_sync(
@@ -1581,7 +1642,7 @@ def send_message_to_model_wrapper_sync(
    )

    if model_type == ChatModel.ModelType.OPENAI:
        return send_message_to_model(
        return openai_send_message_to_model(
            messages=truncated_messages,
            api_key=api_key,
            api_base_url=api_base_url,