Dynamically set default agent chat model to server > user > first chat model

Previously the chat model associated with the default agent was always
the first chat model populated on the server. This doesn't match
behavior of the rest of the system, where the server chat settings is
preferred over the user chat settings over the first chat model.

This change brings the default agent's chat model in line with the
preference order used in the reset of the system.
This commit is contained in:
Debanjum
2025-04-09 18:13:45 +05:30
parent 1eb092010c
commit 33665dee50
4 changed files with 45 additions and 13 deletions

View File

@@ -798,6 +798,30 @@ class AgentAdapters:
async def aget_default_agent():
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
@staticmethod
def get_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]:
"""
Gets the appropriate chat model for an agent.
For the default agent, it dynamically determines the model based on user/server settings.
For other agents, it returns their statically assigned chat model.
Requires the user context to determine the correct default model.
"""
if agent.slug == AgentAdapters.DEFAULT_AGENT_SLUG:
# Dynamically get the default model based on context
return ConversationAdapters.get_default_chat_model(user)
elif agent.chat_model:
# Return the model assigned directly to the specific agent
# Ensure the related object is loaded if necessary (prefetching is recommended)
return agent.chat_model
else:
# Fallback if agent has no unset chat_model. For example if chat_model associated with agent was deleted.
logger.warning(f"Agent {agent.slug} has no chat_model or agent is None, returning overall default.")
return ConversationAdapters.get_default_chat_model(user)
@staticmethod
async def aget_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]:
return await sync_to_async(AgentAdapters.get_agent_chat_model)(agent, user)
@staticmethod
@arequire_valid_user
async def aupdate_agent(

View File

@@ -62,6 +62,7 @@ async def all_agents(
for agent in agents:
files = agent.fileobject_set.all()
file_names = [file.file_name for file in files]
agent_chat_model = await AgentAdapters.aget_agent_chat_model(default_agent, user)
agent_packet = {
"slug": agent.slug,
"name": agent.name,
@@ -71,7 +72,7 @@ async def all_agents(
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.name,
"chat_model": agent_chat_model.name,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
@@ -125,6 +126,7 @@ async def get_agent_by_conversation(
agent = await AgentAdapters.aget_default_agent()
has_files = agent.fileobject_set.exists()
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
agents_packet = {
"slug": agent.slug,
@@ -194,6 +196,8 @@ async def get_agent(
files = agent.fileobject_set.all()
file_names = [file.file_name for file in files]
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
agents_packet = {
"slug": agent.slug,
"name": agent.name,
@@ -265,6 +269,7 @@ async def update_hidden_agent(
output_modes=body.output_modes,
existing_agent=selected_agent,
)
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
agents_packet = {
"slug": agent.slug,
@@ -320,6 +325,7 @@ async def create_hidden_agent(
output_modes=body.output_modes,
existing_agent=None,
)
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
conversation.agent = agent
await conversation.asave()
@@ -374,6 +380,7 @@ async def create_agent(
body.slug,
body.is_hidden,
)
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
agents_packet = {
"slug": agent.slug,
@@ -439,6 +446,7 @@ async def update_agent(
body.output_modes,
body.slug,
)
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
agents_packet = {
"slug": agent.slug,

View File

@@ -403,7 +403,7 @@ async def aget_data_sources_and_output_format(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
class PickTools(BaseModel):
source: List[str] = Field(..., min_items=1)
@@ -492,7 +492,7 @@ async def infer_webpage_urls(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
class WebpageUrls(BaseModel):
links: List[str] = Field(..., min_items=1, max_items=max_webpages)
@@ -557,7 +557,7 @@ async def generate_online_subqueries(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
class OnlineQueries(BaseModel):
queries: List[str] = Field(..., min_items=1, max_items=max_queries)
@@ -666,7 +666,7 @@ async def extract_relevant_info(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
response = await send_message_to_model_wrapper(
extract_relevant_information,
@@ -707,7 +707,7 @@ async def extract_relevant_summary(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
@@ -878,7 +878,7 @@ async def generate_better_diagram_description(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper(
@@ -911,7 +911,7 @@ async def generate_excalidraw_diagram_from_description(
query=q,
)
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
with timer("Chat actor: Generate excalidraw diagram", logger):
raw_response = await send_message_to_model_wrapper(
@@ -1029,7 +1029,7 @@ async def generate_better_mermaidjs_diagram_description(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
response = await send_message_to_model_wrapper(
@@ -1062,7 +1062,7 @@ async def generate_mermaidjs_diagram_from_description(
query=q,
)
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
with timer("Chat actor: Generate Mermaid.js diagram", logger):
raw_response = await send_message_to_model_wrapper(
@@ -1132,7 +1132,7 @@ async def generate_better_image_prompt(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(

View File

@@ -8,7 +8,7 @@ import yaml
from fastapi import Request
from pydantic import BaseModel, Field
from khoj.database.adapters import EntryAdapters
from khoj.database.adapters import AgentAdapters, EntryAdapters
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
@@ -116,7 +116,7 @@ async def apick_next_tool(
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
agent_chat_model = agent.chat_model if agent else None
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)