Set fast, deep think models for intermediate steps via admin panel

Overview
Improve the speed and cost of chat by setting fast and deep think
models for intermediate steps and non-user-facing operations.

Details
- Allow decoupling default chat models from models used for
  intermediate steps by setting server chat settings on admin panel
- Use deep think models for most intermediate steps, like tool
  selection and subquery construction, in default and research mode
- Use fast think models for webpage reading, chat title setting, etc.
  Faster webpage reads should improve conversation latency
This commit is contained in:
Debanjum
2025-08-26 14:44:51 -07:00
parent a99eb841ff
commit 4976b244a4
7 changed files with 162 additions and 17 deletions

View File

@@ -1293,32 +1293,71 @@ class ConversationAdapters:
return ChatModel.objects.filter().first()
@staticmethod
async def aget_default_chat_model(user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None):
"""Get default conversation config. Prefer chat model by server admin > agent > user > first created chat model"""
async def aget_default_chat_model(
user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None, fast: Optional[bool] = None
):
"""
Get the chat model to use. Prefer chat model by server admin > agent > user > first created chat model
Fast is a ternary flag to indicate preference for fast, deep or default chat model configured by the server admin.
If fast is True, prefer fast models over deep models when both are configured.
If fast is False, prefer deep models over fast models when both are configured.
If fast is None, do not consider speed preference and use the default model selection logic.
If fallback_chat_model is provided, it will be used as a fallback if server chat settings are not configured.
Else if user settings are found use that.
Otherwise the first chat model will be used.
"""
# Get the server chat settings
server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter()
.prefetch_related(
"chat_default", "chat_default__ai_model_api", "chat_advanced", "chat_advanced__ai_model_api"
"chat_default",
"chat_default__ai_model_api",
"chat_advanced",
"chat_advanced__ai_model_api",
"think_free_fast",
"think_free_fast__ai_model_api",
"think_free_deep",
"think_free_deep__ai_model_api",
"think_paid_fast",
"think_paid_fast__ai_model_api",
"think_paid_deep",
"think_paid_deep__ai_model_api",
)
.afirst()
)
is_subscribed = await ais_user_subscribed(user) if user else False
if server_chat_settings:
# If the user is subscribed and the advanced model is enabled, return the advanced model
if is_subscribed and server_chat_settings.chat_advanced:
return server_chat_settings.chat_advanced
# If the default model is set, return it
if server_chat_settings.chat_default:
return server_chat_settings.chat_default
# If the user is subscribed
if is_subscribed:
# If fast is requested and fast paid model is available
if server_chat_settings.think_paid_fast and fast is True:
return server_chat_settings.think_paid_fast
# Else if fast is not requested and deep paid model is available
elif server_chat_settings.think_paid_deep and fast is not None:
return server_chat_settings.think_paid_deep
# Else if advanced model is available
elif server_chat_settings.chat_advanced:
return server_chat_settings.chat_advanced
else:
# If fast is requested and fast free model is available
if server_chat_settings.think_free_fast and fast:
return server_chat_settings.think_free_fast
# Else if deep free model is available. NOTE(review): unlike the paid branch above,
# this does not check `fast is not None`, so free users get the deep model even when
# no speed preference was passed (fast is None) — confirm this asymmetry is intended
elif server_chat_settings.think_free_deep:
return server_chat_settings.think_free_deep
# Else if default model is available
elif server_chat_settings.chat_default:
return server_chat_settings.chat_default
# Revert to an explicit fallback model if the server chat settings are not set
if fallback_chat_model:
# The chat model may not be fully loaded from the db, so explicitly load it here
return await ChatModel.objects.filter(id=fallback_chat_model.id).prefetch_related("ai_model_api").afirst()
# Get the user's chat settings, if the server chat settings are not set
# Get the user's chat settings, if both the server chat settings and the fallback model are not set
user_chat_settings = (
(await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst())
if user

View File

@@ -278,6 +278,10 @@ class ServerChatSettingsAdmin(unfold_admin.ModelAdmin):
list_display = (
"chat_default",
"chat_advanced",
"think_free_fast",
"think_free_deep",
"think_paid_fast",
"think_paid_deep",
"web_scraper",
)

View File

@@ -0,0 +1,61 @@
# Generated by Django 5.1.10 on 2025-08-26 20:47
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
    # Adds four optional "think" chat model foreign keys to ServerChatSettings so
    # the server admin can configure fast/deep models for intermediate steps, per
    # price tier (free/paid), independently of the default user-facing chat models.

    dependencies = [
        ("database", "0093_remove_localorgconfig_user_and_more"),
    ]

    operations = [
        # Each field is nullable with default=None, so existing rows need no backfill.
        # NOTE(review): on_delete=CASCADE deletes the entire ServerChatSettings row
        # when the referenced ChatModel is deleted; this matches the existing
        # chat_advanced field, but SET_NULL may be the safer intent — confirm.
        migrations.AddField(
            model_name="serverchatsettings",
            name="think_free_deep",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                related_name="think_free_deep",
                to="database.chatmodel",
            ),
        ),
        migrations.AddField(
            model_name="serverchatsettings",
            name="think_free_fast",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                related_name="think_free_fast",
                to="database.chatmodel",
            ),
        ),
        migrations.AddField(
            model_name="serverchatsettings",
            name="think_paid_deep",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                related_name="think_paid_deep",
                to="database.chatmodel",
            ),
        ),
        migrations.AddField(
            model_name="serverchatsettings",
            name="think_paid_fast",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                related_name="think_paid_fast",
                to="database.chatmodel",
            ),
        ),
    ]

View File

@@ -472,6 +472,18 @@ class ServerChatSettings(DbBaseModel):
chat_advanced = models.ForeignKey(
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
)
think_free_fast = models.ForeignKey(
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_free_fast"
)
think_free_deep = models.ForeignKey(
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_free_deep"
)
think_paid_fast = models.ForeignKey(
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_paid_fast"
)
think_paid_deep = models.ForeignKey(
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_paid_deep"
)
web_scraper = models.ForeignKey(
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
)
@@ -480,6 +492,10 @@ class ServerChatSettings(DbBaseModel):
error = {}
if self.chat_default and self.chat_default.price_tier != PriceTier.FREE:
error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model."
if self.think_free_fast and self.think_free_fast.price_tier != PriceTier.FREE:
error["think_free_fast"] = "Set the price tier of this chat model to free or use a free tier chat model."
if self.think_free_deep and self.think_free_deep.price_tier != PriceTier.FREE:
error["think_free_deep"] = "Set the price tier of this chat model to free or use a free tier chat model."
if error:
raise ValidationError(error)

View File

@@ -160,6 +160,7 @@ async def generate_python_code(
query_files=query_files,
user=user,
agent_chat_model=agent_chat_model,
fast_model=False,
tracer=tracer,
)

View File

@@ -286,7 +286,7 @@ async def acreate_title_from_history(
title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
with timer("Chat actor: Generate title from conversation history", logger):
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True)
return response.text.strip()
@@ -298,7 +298,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
title_generation_prompt = prompts.subject_generation.format(query=query)
with timer("Chat actor: Generate title from query", logger):
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True)
return response.text.strip()
@@ -321,7 +321,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
with timer("Chat actor: Check if safe prompt", logger):
response = await send_message_to_model_wrapper(
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck, fast_model=True
)
response = response.text.strip()
@@ -410,6 +410,7 @@ async def aget_data_sources_and_output_format(
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
fast_model=False,
tracer=tracer,
)
@@ -499,6 +500,7 @@ async def infer_webpage_urls(
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
fast_model=False,
tracer=tracer,
)
@@ -564,6 +566,7 @@ async def generate_online_subqueries(
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
fast_model=False,
tracer=tracer,
)
@@ -625,7 +628,12 @@ async def aschedule_query(
)
raw_response = await send_message_to_model_wrapper(
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
crontime_prompt,
query_images=query_images,
response_type="json_object",
fast_model=False,
user=user,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list
@@ -666,6 +674,7 @@ async def extract_relevant_info(
prompts.system_prompt_extract_relevant_information,
user=user,
agent_chat_model=agent_chat_model,
fast_model=True,
tracer=tracer,
)
return response.text.strip()
@@ -709,6 +718,7 @@ async def extract_relevant_summary(
user=user,
query_images=query_images,
agent_chat_model=agent_chat_model,
fast_model=True,
tracer=tracer,
)
return response.text.strip()
@@ -880,6 +890,7 @@ async def generate_better_diagram_description(
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
fast_model=False,
tracer=tracer,
)
response = response.text.strip()
@@ -908,7 +919,11 @@ async def generate_excalidraw_diagram_from_description(
with timer("Chat actor: Generate excalidraw diagram", logger):
raw_response = await send_message_to_model_wrapper(
query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
query=excalidraw_diagram_generation,
user=user,
agent_chat_model=agent_chat_model,
fast_model=False,
tracer=tracer,
)
raw_response_text = clean_json(raw_response.text)
try:
@@ -1031,6 +1046,7 @@ async def generate_better_mermaidjs_diagram_description(
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
fast_model=False,
tracer=tracer,
)
response_text = response.text.strip()
@@ -1059,7 +1075,11 @@ async def generate_mermaidjs_diagram_from_description(
with timer("Chat actor: Generate Mermaid.js diagram", logger):
raw_response = await send_message_to_model_wrapper(
query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
query=mermaidjs_diagram_generation,
user=user,
agent_chat_model=agent_chat_model,
fast_model=False,
tracer=tracer,
)
return clean_mermaidjs(raw_response.text.strip())
@@ -1120,6 +1140,7 @@ async def generate_better_image_prompt(
query_files=query_files,
chat_history=conversation_history,
agent_chat_model=agent_chat_model,
fast_model=False,
user=user,
response_type="json_object",
response_schema=ImagePromptResponse,
@@ -1302,6 +1323,7 @@ async def extract_questions(
query_files=query_files,
response_type="json_object",
response_schema=DocumentQueries,
fast_model=False,
user=user,
tracer=tracer,
)
@@ -1420,6 +1442,7 @@ async def send_message_to_model_wrapper(
response_schema: BaseModel = None,
tools: List[ToolDefinition] = None,
deepthought: bool = False,
fast_model: Optional[bool] = None,
user: KhojUser = None,
query_images: List[str] = None,
context: str = "",
@@ -1428,7 +1451,7 @@ async def send_message_to_model_wrapper(
agent_chat_model: ChatModel = None,
tracer: dict = {},
):
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model)
vision_available = chat_model.vision_enabled
if not vision_available and query_images:
logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")

View File

@@ -163,6 +163,7 @@ async def apick_next_tool(
chat_history=chat_and_research_history,
tools=tools,
deepthought=True,
fast_model=False,
user=user,
query_images=query_images,
query_files=query_files,