diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 2280ea60..7ffbf67a 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -798,6 +798,30 @@ class AgentAdapters: async def aget_default_agent(): return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst() + @staticmethod + def get_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]: + """ + Gets the appropriate chat model for an agent. + For the default agent, it dynamically determines the model based on user/server settings. + For other agents, it returns their statically assigned chat model. + Requires the user context to determine the correct default model. + """ + if agent.slug == AgentAdapters.DEFAULT_AGENT_SLUG: + # Dynamically get the default model based on context + return ConversationAdapters.get_default_chat_model(user) + elif agent.chat_model: + # Return the model assigned directly to the specific agent + # Ensure the related object is loaded if necessary (prefetching is recommended) + return agent.chat_model + else: + # Fallback if agent has no unset chat_model. For example if chat_model associated with agent was deleted. + logger.warning(f"Agent {agent.slug} has no chat_model or agent is None, returning overall default.") + return ConversationAdapters.get_default_chat_model(user) + + @staticmethod + async def aget_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]: + return await sync_to_async(AgentAdapters.get_agent_chat_model)(agent, user) + @staticmethod @arequire_valid_user async def aupdate_agent( diff --git a/src/khoj/routers/api_agents.py b/src/khoj/routers/api_agents.py index 71514025..f97ee584 100644 --- a/src/khoj/routers/api_agents.py +++ b/src/khoj/routers/api_agents.py @@ -62,6 +62,7 @@ async def all_agents( for agent in agents: files = agent.fileobject_set.all() file_names = [file.file_name for file in files] + agent_chat_model = await AgentAdapters.aget_agent_chat_model(default_agent, user) agent_packet = { "slug": agent.slug, "name": agent.name, @@ -71,7 +72,7 @@ async def all_agents( "color": agent.style_color, "icon": agent.style_icon, "privacy_level": agent.privacy_level, - "chat_model": agent.chat_model.name, + "chat_model": agent_chat_model.name, "files": file_names, "input_tools": agent.input_tools, "output_modes": agent.output_modes, @@ -125,6 +126,7 @@ async def get_agent_by_conversation( agent = await AgentAdapters.aget_default_agent() has_files = agent.fileobject_set.exists() + agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) agents_packet = { "slug": agent.slug, @@ -194,6 +196,8 @@ async def get_agent( files = agent.fileobject_set.all() file_names = [file.file_name for file in files] + agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) + agents_packet = { "slug": agent.slug, "name": agent.name, @@ -265,6 +269,7 @@ async def update_hidden_agent( output_modes=body.output_modes, existing_agent=selected_agent, ) + agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) agents_packet = { "slug": agent.slug, @@ -320,6 +325,7 @@ async def create_hidden_agent( output_modes=body.output_modes, existing_agent=None, ) + agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) conversation.agent = agent await conversation.asave() @@ -374,6 +380,7 @@ async def create_agent( body.slug, body.is_hidden, ) + agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) agents_packet = { "slug": agent.slug, @@ -439,6 +446,7 @@ async def update_agent( body.output_modes, body.slug, ) + agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) agents_packet = { "slug": agent.slug, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index ccf3a7b4..d15c0021 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -403,7 +403,7 @@ async def aget_data_sources_and_output_format( personality_context=personality_context, ) - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None class PickTools(BaseModel): source: List[str] = Field(..., min_items=1) @@ -492,7 +492,7 @@ async def infer_webpage_urls( personality_context=personality_context, ) - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None class WebpageUrls(BaseModel): links: List[str] = Field(..., min_items=1, max_items=max_webpages) @@ -557,7 +557,7 @@ async def generate_online_subqueries( personality_context=personality_context, ) - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None class OnlineQueries(BaseModel): queries: List[str] = Field(..., min_items=1, max_items=max_queries) @@ -666,7 +666,7 @@ async def extract_relevant_info( personality_context=personality_context, ) - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None response = await send_message_to_model_wrapper( extract_relevant_information, @@ -707,7 +707,7 @@ async def extract_relevant_summary( personality_context=personality_context, ) - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None with timer("Chat actor: Extract relevant information from data", logger): response = await send_message_to_model_wrapper( @@ -878,7 +878,7 @@ async def generate_better_diagram_description( personality_context=personality_context, ) - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None with timer("Chat actor: Generate better diagram description", logger): response = await send_message_to_model_wrapper( @@ -911,7 +911,7 @@ async def generate_excalidraw_diagram_from_description( query=q, ) - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None with timer("Chat actor: Generate excalidraw diagram", logger): raw_response = await send_message_to_model_wrapper( @@ -1029,7 +1029,7 @@ async def generate_better_mermaidjs_diagram_description( personality_context=personality_context, ) - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None with timer("Chat actor: Generate better Mermaid.js diagram description", logger): response = await send_message_to_model_wrapper( @@ -1062,7 +1062,7 @@ async def generate_mermaidjs_diagram_from_description( query=q, ) - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None with timer("Chat actor: Generate Mermaid.js diagram", logger): raw_response = await send_message_to_model_wrapper( @@ -1132,7 +1132,7 @@ async def generate_better_image_prompt( personality_context=personality_context, ) - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None with timer("Chat actor: Generate contextual image prompt", logger): response = await send_message_to_model_wrapper( diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index cc73686e..057d7c9e 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -8,7 +8,7 @@ import yaml from fastapi import Request from pydantic import BaseModel, Field -from khoj.database.adapters import EntryAdapters +from khoj.database.adapters import AgentAdapters, EntryAdapters from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( @@ -116,7 +116,7 @@ async def apick_next_tool( today = datetime.today() location_data = f"{location}" if location else "Unknown" - agent_chat_model = agent.chat_model if agent else None + agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None personality_context = ( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" )