mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Set fast, deep think models for intermediary steps via admin panel
Overview Enable improving speed and cost of chat by setting fast, deep think models for intermediate steps and non user facing operations. Details - Allow decoupling default chat models from models used for intermediate steps by setting server chat settings on admin panel - Use deep think models for most intermediate steps like tool selection, subquery construction etc. in default and research mode - Use fast think models for webpage read, chat title setting etc. Faster webpage read should improve conversation latency
This commit is contained in:
@@ -1293,32 +1293,71 @@ class ConversationAdapters:
|
||||
return ChatModel.objects.filter().first()
|
||||
|
||||
@staticmethod
|
||||
async def aget_default_chat_model(user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None):
|
||||
"""Get default conversation config. Prefer chat model by server admin > agent > user > first created chat model"""
|
||||
async def aget_default_chat_model(
|
||||
user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None, fast: Optional[bool] = None
|
||||
):
|
||||
"""
|
||||
Get the chat model to use. Prefer chat model by server admin > agent > user > first created chat model
|
||||
|
||||
Fast is a trinary flag to indicate preference for fast, deep or default chat model configured by the server admin.
|
||||
If fast is True, prefer fast models over deep models when both are configured.
|
||||
If fast is False, prefer deep models over fast models when both are configured.
|
||||
If fast is None, do not consider speed preference and use the default model selection logic.
|
||||
|
||||
If fallback_chat_model is provided, it will be used as a fallback if server chat settings are not configured.
|
||||
Else if user settings are found use that.
|
||||
Otherwise the first chat model will be used.
|
||||
"""
|
||||
# Get the server chat settings
|
||||
server_chat_settings: ServerChatSettings = (
|
||||
await ServerChatSettings.objects.filter()
|
||||
.prefetch_related(
|
||||
"chat_default", "chat_default__ai_model_api", "chat_advanced", "chat_advanced__ai_model_api"
|
||||
"chat_default",
|
||||
"chat_default__ai_model_api",
|
||||
"chat_advanced",
|
||||
"chat_advanced__ai_model_api",
|
||||
"think_free_fast",
|
||||
"think_free_fast__ai_model_api",
|
||||
"think_free_deep",
|
||||
"think_free_deep__ai_model_api",
|
||||
"think_paid_fast",
|
||||
"think_paid_fast__ai_model_api",
|
||||
"think_paid_deep",
|
||||
"think_paid_deep__ai_model_api",
|
||||
)
|
||||
.afirst()
|
||||
)
|
||||
is_subscribed = await ais_user_subscribed(user) if user else False
|
||||
|
||||
if server_chat_settings:
|
||||
# If the user is subscribed and the advanced model is enabled, return the advanced model
|
||||
if is_subscribed and server_chat_settings.chat_advanced:
|
||||
return server_chat_settings.chat_advanced
|
||||
# If the default model is set, return it
|
||||
if server_chat_settings.chat_default:
|
||||
return server_chat_settings.chat_default
|
||||
# If the user is subscribed
|
||||
if is_subscribed:
|
||||
# If fast is requested and fast paid model is available
|
||||
if server_chat_settings.think_paid_fast and fast is True:
|
||||
return server_chat_settings.think_paid_fast
|
||||
# Else if fast is not requested and deep paid model is available
|
||||
elif server_chat_settings.think_paid_deep and fast is not None:
|
||||
return server_chat_settings.think_paid_deep
|
||||
# Else if advanced model is available
|
||||
elif server_chat_settings.chat_advanced:
|
||||
return server_chat_settings.chat_advanced
|
||||
else:
|
||||
# If fast is requested and fast free model is available
|
||||
if server_chat_settings.think_free_fast and fast:
|
||||
return server_chat_settings.think_free_fast
|
||||
# Else if fast is not requested and deep free model is available
|
||||
elif server_chat_settings.think_free_deep:
|
||||
return server_chat_settings.think_free_deep
|
||||
# Else if default model is available
|
||||
elif server_chat_settings.chat_default:
|
||||
return server_chat_settings.chat_default
|
||||
|
||||
# Revert to an explicit fallback model if the server chat settings are not set
|
||||
if fallback_chat_model:
|
||||
# The chat model may not be full loaded from the db, so explicitly load it here
|
||||
return await ChatModel.objects.filter(id=fallback_chat_model.id).prefetch_related("ai_model_api").afirst()
|
||||
|
||||
# Get the user's chat settings, if the server chat settings are not set
|
||||
# Get the user's chat settings, if both the server chat settings and the fallback model are not set
|
||||
user_chat_settings = (
|
||||
(await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst())
|
||||
if user
|
||||
|
||||
@@ -278,6 +278,10 @@ class ServerChatSettingsAdmin(unfold_admin.ModelAdmin):
|
||||
list_display = (
|
||||
"chat_default",
|
||||
"chat_advanced",
|
||||
"think_free_fast",
|
||||
"think_free_deep",
|
||||
"think_paid_fast",
|
||||
"think_paid_deep",
|
||||
"web_scraper",
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
# Generated by Django 5.1.10 on 2025-08-26 20:47
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0093_remove_localorgconfig_user_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="serverchatsettings",
|
||||
name="think_free_deep",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="think_free_deep",
|
||||
to="database.chatmodel",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="serverchatsettings",
|
||||
name="think_free_fast",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="think_free_fast",
|
||||
to="database.chatmodel",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="serverchatsettings",
|
||||
name="think_paid_deep",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="think_paid_deep",
|
||||
to="database.chatmodel",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="serverchatsettings",
|
||||
name="think_paid_fast",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="think_paid_fast",
|
||||
to="database.chatmodel",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -472,6 +472,18 @@ class ServerChatSettings(DbBaseModel):
|
||||
chat_advanced = models.ForeignKey(
|
||||
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
|
||||
)
|
||||
think_free_fast = models.ForeignKey(
|
||||
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_free_fast"
|
||||
)
|
||||
think_free_deep = models.ForeignKey(
|
||||
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_free_deep"
|
||||
)
|
||||
think_paid_fast = models.ForeignKey(
|
||||
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_paid_fast"
|
||||
)
|
||||
think_paid_deep = models.ForeignKey(
|
||||
ChatModel, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="think_paid_deep"
|
||||
)
|
||||
web_scraper = models.ForeignKey(
|
||||
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
|
||||
)
|
||||
@@ -480,6 +492,10 @@ class ServerChatSettings(DbBaseModel):
|
||||
error = {}
|
||||
if self.chat_default and self.chat_default.price_tier != PriceTier.FREE:
|
||||
error["chat_default"] = "Set the price tier of this chat model to free or use a free tier chat model."
|
||||
if self.think_free_fast and self.think_free_fast.price_tier != PriceTier.FREE:
|
||||
error["think_free_fast"] = "Set the price tier of this chat model to free or use a free tier chat model."
|
||||
if self.think_free_deep and self.think_free_deep.price_tier != PriceTier.FREE:
|
||||
error["think_free_deep"] = "Set the price tier of this chat model to free or use a free tier chat model."
|
||||
if error:
|
||||
raise ValidationError(error)
|
||||
|
||||
|
||||
@@ -160,6 +160,7 @@ async def generate_python_code(
|
||||
query_files=query_files,
|
||||
user=user,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=False,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
@@ -286,7 +286,7 @@ async def acreate_title_from_history(
|
||||
title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
|
||||
|
||||
with timer("Chat actor: Generate title from conversation history", logger):
|
||||
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
|
||||
response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True)
|
||||
|
||||
return response.text.strip()
|
||||
|
||||
@@ -298,7 +298,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
||||
title_generation_prompt = prompts.subject_generation.format(query=query)
|
||||
|
||||
with timer("Chat actor: Generate title from query", logger):
|
||||
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
|
||||
response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True)
|
||||
|
||||
return response.text.strip()
|
||||
|
||||
@@ -321,7 +321,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
|
||||
|
||||
with timer("Chat actor: Check if safe prompt", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck
|
||||
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck, fast_model=True
|
||||
)
|
||||
|
||||
response = response.text.strip()
|
||||
@@ -410,6 +410,7 @@ async def aget_data_sources_and_output_format(
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=False,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -499,6 +500,7 @@ async def infer_webpage_urls(
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=False,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -564,6 +566,7 @@ async def generate_online_subqueries(
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=False,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
@@ -625,7 +628,12 @@ async def aschedule_query(
|
||||
)
|
||||
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
|
||||
crontime_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
fast_model=False,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
@@ -666,6 +674,7 @@ async def extract_relevant_info(
|
||||
prompts.system_prompt_extract_relevant_information,
|
||||
user=user,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=True,
|
||||
tracer=tracer,
|
||||
)
|
||||
return response.text.strip()
|
||||
@@ -709,6 +718,7 @@ async def extract_relevant_summary(
|
||||
user=user,
|
||||
query_images=query_images,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=True,
|
||||
tracer=tracer,
|
||||
)
|
||||
return response.text.strip()
|
||||
@@ -880,6 +890,7 @@ async def generate_better_diagram_description(
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=False,
|
||||
tracer=tracer,
|
||||
)
|
||||
response = response.text.strip()
|
||||
@@ -908,7 +919,11 @@ async def generate_excalidraw_diagram_from_description(
|
||||
|
||||
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
|
||||
query=excalidraw_diagram_generation,
|
||||
user=user,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=False,
|
||||
tracer=tracer,
|
||||
)
|
||||
raw_response_text = clean_json(raw_response.text)
|
||||
try:
|
||||
@@ -1031,6 +1046,7 @@ async def generate_better_mermaidjs_diagram_description(
|
||||
user=user,
|
||||
query_files=query_files,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=False,
|
||||
tracer=tracer,
|
||||
)
|
||||
response_text = response.text.strip()
|
||||
@@ -1059,7 +1075,11 @@ async def generate_mermaidjs_diagram_from_description(
|
||||
|
||||
with timer("Chat actor: Generate Mermaid.js diagram", logger):
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
|
||||
query=mermaidjs_diagram_generation,
|
||||
user=user,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=False,
|
||||
tracer=tracer,
|
||||
)
|
||||
return clean_mermaidjs(raw_response.text.strip())
|
||||
|
||||
@@ -1120,6 +1140,7 @@ async def generate_better_image_prompt(
|
||||
query_files=query_files,
|
||||
chat_history=conversation_history,
|
||||
agent_chat_model=agent_chat_model,
|
||||
fast_model=False,
|
||||
user=user,
|
||||
response_type="json_object",
|
||||
response_schema=ImagePromptResponse,
|
||||
@@ -1302,6 +1323,7 @@ async def extract_questions(
|
||||
query_files=query_files,
|
||||
response_type="json_object",
|
||||
response_schema=DocumentQueries,
|
||||
fast_model=False,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
@@ -1420,6 +1442,7 @@ async def send_message_to_model_wrapper(
|
||||
response_schema: BaseModel = None,
|
||||
tools: List[ToolDefinition] = None,
|
||||
deepthought: bool = False,
|
||||
fast_model: Optional[bool] = None,
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
context: str = "",
|
||||
@@ -1428,7 +1451,7 @@ async def send_message_to_model_wrapper(
|
||||
agent_chat_model: ChatModel = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
|
||||
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model)
|
||||
vision_available = chat_model.vision_enabled
|
||||
if not vision_available and query_images:
|
||||
logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
|
||||
|
||||
@@ -163,6 +163,7 @@ async def apick_next_tool(
|
||||
chat_history=chat_and_research_history,
|
||||
tools=tools,
|
||||
deepthought=True,
|
||||
fast_model=False,
|
||||
user=user,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
|
||||
Reference in New Issue
Block a user