mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Dynamically set default agent chat model to server > user > first chat model
Previously the chat model associated with the default agent was always the first chat model populated on the server. This doesn't match behavior of the rest of the system, where the server chat settings is preferred over the user chat settings over the first chat model. This change brings the default agent's chat model in line with the preference order used in the reset of the system.
This commit is contained in:
@@ -798,6 +798,30 @@ class AgentAdapters:
|
||||
async def aget_default_agent():
|
||||
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
|
||||
|
||||
@staticmethod
|
||||
def get_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]:
|
||||
"""
|
||||
Gets the appropriate chat model for an agent.
|
||||
For the default agent, it dynamically determines the model based on user/server settings.
|
||||
For other agents, it returns their statically assigned chat model.
|
||||
Requires the user context to determine the correct default model.
|
||||
"""
|
||||
if agent.slug == AgentAdapters.DEFAULT_AGENT_SLUG:
|
||||
# Dynamically get the default model based on context
|
||||
return ConversationAdapters.get_default_chat_model(user)
|
||||
elif agent.chat_model:
|
||||
# Return the model assigned directly to the specific agent
|
||||
# Ensure the related object is loaded if necessary (prefetching is recommended)
|
||||
return agent.chat_model
|
||||
else:
|
||||
# Fallback if agent has no unset chat_model. For example if chat_model associated with agent was deleted.
|
||||
logger.warning(f"Agent {agent.slug} has no chat_model or agent is None, returning overall default.")
|
||||
return ConversationAdapters.get_default_chat_model(user)
|
||||
|
||||
@staticmethod
|
||||
async def aget_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]:
|
||||
return await sync_to_async(AgentAdapters.get_agent_chat_model)(agent, user)
|
||||
|
||||
@staticmethod
|
||||
@arequire_valid_user
|
||||
async def aupdate_agent(
|
||||
|
||||
@@ -62,6 +62,7 @@ async def all_agents(
|
||||
for agent in agents:
|
||||
files = agent.fileobject_set.all()
|
||||
file_names = [file.file_name for file in files]
|
||||
agent_chat_model = await AgentAdapters.aget_agent_chat_model(default_agent, user)
|
||||
agent_packet = {
|
||||
"slug": agent.slug,
|
||||
"name": agent.name,
|
||||
@@ -71,7 +72,7 @@ async def all_agents(
|
||||
"color": agent.style_color,
|
||||
"icon": agent.style_icon,
|
||||
"privacy_level": agent.privacy_level,
|
||||
"chat_model": agent.chat_model.name,
|
||||
"chat_model": agent_chat_model.name,
|
||||
"files": file_names,
|
||||
"input_tools": agent.input_tools,
|
||||
"output_modes": agent.output_modes,
|
||||
@@ -125,6 +126,7 @@ async def get_agent_by_conversation(
|
||||
agent = await AgentAdapters.aget_default_agent()
|
||||
|
||||
has_files = agent.fileobject_set.exists()
|
||||
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
|
||||
|
||||
agents_packet = {
|
||||
"slug": agent.slug,
|
||||
@@ -194,6 +196,8 @@ async def get_agent(
|
||||
|
||||
files = agent.fileobject_set.all()
|
||||
file_names = [file.file_name for file in files]
|
||||
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
|
||||
|
||||
agents_packet = {
|
||||
"slug": agent.slug,
|
||||
"name": agent.name,
|
||||
@@ -265,6 +269,7 @@ async def update_hidden_agent(
|
||||
output_modes=body.output_modes,
|
||||
existing_agent=selected_agent,
|
||||
)
|
||||
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
|
||||
|
||||
agents_packet = {
|
||||
"slug": agent.slug,
|
||||
@@ -320,6 +325,7 @@ async def create_hidden_agent(
|
||||
output_modes=body.output_modes,
|
||||
existing_agent=None,
|
||||
)
|
||||
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
|
||||
|
||||
conversation.agent = agent
|
||||
await conversation.asave()
|
||||
@@ -374,6 +380,7 @@ async def create_agent(
|
||||
body.slug,
|
||||
body.is_hidden,
|
||||
)
|
||||
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
|
||||
|
||||
agents_packet = {
|
||||
"slug": agent.slug,
|
||||
@@ -439,6 +446,7 @@ async def update_agent(
|
||||
body.output_modes,
|
||||
body.slug,
|
||||
)
|
||||
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
|
||||
|
||||
agents_packet = {
|
||||
"slug": agent.slug,
|
||||
|
||||
@@ -403,7 +403,7 @@ async def aget_data_sources_and_output_format(
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
class PickTools(BaseModel):
|
||||
source: List[str] = Field(..., min_items=1)
|
||||
@@ -492,7 +492,7 @@ async def infer_webpage_urls(
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
class WebpageUrls(BaseModel):
|
||||
links: List[str] = Field(..., min_items=1, max_items=max_webpages)
|
||||
@@ -557,7 +557,7 @@ async def generate_online_subqueries(
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
class OnlineQueries(BaseModel):
|
||||
queries: List[str] = Field(..., min_items=1, max_items=max_queries)
|
||||
@@ -666,7 +666,7 @@ async def extract_relevant_info(
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
response = await send_message_to_model_wrapper(
|
||||
extract_relevant_information,
|
||||
@@ -707,7 +707,7 @@ async def extract_relevant_summary(
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
with timer("Chat actor: Extract relevant information from data", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
@@ -878,7 +878,7 @@ async def generate_better_diagram_description(
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
with timer("Chat actor: Generate better diagram description", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
@@ -911,7 +911,7 @@ async def generate_excalidraw_diagram_from_description(
|
||||
query=q,
|
||||
)
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
@@ -1029,7 +1029,7 @@ async def generate_better_mermaidjs_diagram_description(
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
@@ -1062,7 +1062,7 @@ async def generate_mermaidjs_diagram_from_description(
|
||||
query=q,
|
||||
)
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
with timer("Chat actor: Generate Mermaid.js diagram", logger):
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
@@ -1132,7 +1132,7 @@ async def generate_better_image_prompt(
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
|
||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
|
||||
@@ -8,7 +8,7 @@ import yaml
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from khoj.database.adapters import EntryAdapters
|
||||
from khoj.database.adapters import AgentAdapters, EntryAdapters
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
@@ -116,7 +116,7 @@ async def apick_next_tool(
|
||||
|
||||
today = datetime.today()
|
||||
location_data = f"{location}" if location else "Unknown"
|
||||
agent_chat_model = agent.chat_model if agent else None
|
||||
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user