Use the agent's chat model as a fallback chat model when available / necessary. Remove automation as an agent output option.

This commit is contained in:
sabaimran
2025-02-01 13:06:42 -08:00
parent 641f1bcd91
commit 0645af9b16
5 changed files with 74 additions and 9 deletions

View File

@@ -1008,7 +1008,9 @@ class ConversationAdapters:
if create_new:
return await ConversationAdapters.acreate_conversation_session(user, client_application)
query = Conversation.objects.filter(user=user, client=client_application).prefetch_related("agent")
query = Conversation.objects.filter(user=user, client=client_application).prefetch_related(
"agent", "agent__chat_model"
)
if conversation_id:
return await query.filter(id=conversation_id).afirst()
@@ -1017,7 +1019,7 @@ class ConversationAdapters:
conversation = await query.order_by("-updated_at").afirst()
return conversation or await Conversation.objects.prefetch_related("agent").acreate(
return conversation or await Conversation.objects.prefetch_related("agent", "agent__chat_model").acreate(
user=user, client=client_application
)
@@ -1147,7 +1149,7 @@ class ConversationAdapters:
return ChatModel.objects.filter().first()
@staticmethod
async def aget_default_chat_model(user: KhojUser = None):
async def aget_default_chat_model(user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None):
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings: ServerChatSettings = (
@@ -1167,12 +1169,18 @@ class ConversationAdapters:
if server_chat_settings.chat_default:
return server_chat_settings.chat_default
# Revert to an explicit fallback model if the server chat settings are not set
if fallback_chat_model:
# The chat model may not be fully loaded from the db, so explicitly load it here
return await ChatModel.objects.filter(id=fallback_chat_model.id).prefetch_related("ai_model_api").afirst()
# Get the user's chat settings, if the server chat settings are not set
user_chat_settings = (
(await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst())
if user
else None
)
if user_chat_settings is not None and user_chat_settings.setting is not None:
return user_chat_settings.setting

View File

@@ -0,0 +1,24 @@
# Generated by Django 5.0.10 on 2025-02-01 20:10
import django.contrib.postgres.fields
from django.db import migrations, models
class Migration(migrations.Migration):
    """Remove the 'automation' choice from Agent.output_modes.

    Narrows the ArrayField's base CharField choices to 'image' and
    'diagram' only. NOTE(review): AlterField changes the schema/validation
    metadata; presumably existing rows holding 'automation' are left as-is —
    confirm whether a data migration is needed.
    """

    # Must run after the input_tools alteration introduced in 0084.
    dependencies = [
        ("database", "0084_alter_agent_input_tools"),
    ]

    operations = [
        migrations.AlterField(
            model_name="agent",
            name="output_modes",
            field=django.contrib.postgres.fields.ArrayField(
                base_field=models.CharField(choices=[("image", "Image"), ("diagram", "Diagram")], max_length=200),
                blank=True,
                # default=list avoids sharing one mutable list across rows
                default=list,
                null=True,
                size=None,
            ),
        ),
    ]

View File

@@ -286,7 +286,6 @@ class Agent(DbBaseModel):
class OutputModeOptions(models.TextChoices):
    """Output modes an agent may be configured to respond with."""

    # These map to various ConversationCommand types
    IMAGE = "image"
    AUTOMATION = "automation"
    DIAGRAM = "diagram"
creator = models.ForeignKey(

View File

@@ -393,12 +393,15 @@ async def aget_data_sources_and_output_format(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
relevant_tools_prompt,
response_type="json_object",
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
tracer=tracer,
)
@@ -472,6 +475,8 @@ async def infer_webpage_urls(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt,
@@ -479,6 +484,7 @@ async def infer_webpage_urls(
response_type="json_object",
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
tracer=tracer,
)
@@ -528,6 +534,8 @@ async def generate_online_subqueries(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt,
@@ -535,6 +543,7 @@ async def generate_online_subqueries(
response_type="json_object",
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
tracer=tracer,
)
@@ -628,10 +637,13 @@ async def extract_relevant_info(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
user=user,
agent_chat_model=agent_chat_model,
tracer=tracer,
)
return response.strip()
@@ -666,12 +678,15 @@ async def extract_relevant_summary(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_summary,
user=user,
query_images=query_images,
agent_chat_model=agent_chat_model,
tracer=tracer,
)
return response.strip()
@@ -834,12 +849,15 @@ async def generate_better_diagram_description(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper(
improve_diagram_description_prompt,
query_images=query_images,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
tracer=tracer,
)
response = response.strip()
@@ -864,9 +882,11 @@ async def generate_excalidraw_diagram_from_description(
query=q,
)
agent_chat_model = agent.chat_model if agent else None
with timer("Chat actor: Generate excalidraw diagram", logger):
raw_response = await send_message_to_model_wrapper(
query=excalidraw_diagram_generation, user=user, tracer=tracer
query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
)
raw_response = clean_json(raw_response)
try:
@@ -980,12 +1000,15 @@ async def generate_better_mermaidjs_diagram_description(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
response = await send_message_to_model_wrapper(
improve_diagram_description_prompt,
query_images=query_images,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
tracer=tracer,
)
response = response.strip()
@@ -1010,8 +1033,12 @@ async def generate_mermaidjs_diagram_from_description(
query=q,
)
agent_chat_model = agent.chat_model if agent else None
with timer("Chat actor: Generate Mermaid.js diagram", logger):
raw_response = await send_message_to_model_wrapper(query=mermaidjs_diagram_generation, user=user, tracer=tracer)
raw_response = await send_message_to_model_wrapper(
query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
)
return clean_mermaidjs(raw_response.strip())
@@ -1072,9 +1099,16 @@ async def generate_better_image_prompt(
personality_context=personality_context,
)
agent_chat_model = agent.chat_model if agent else None
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(
image_prompt, query_images=query_images, user=user, query_files=query_files, tracer=tracer
image_prompt,
query_images=query_images,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
tracer=tracer,
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@@ -1091,9 +1125,10 @@ async def send_message_to_model_wrapper(
query_images: List[str] = None,
context: str = "",
query_files: str = None,
agent_chat_model: ChatModel = None,
tracer: dict = {},
):
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user)
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
vision_available = chat_model.vision_enabled
if not vision_available and query_images:
logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")

View File

@@ -386,7 +386,6 @@ mode_descriptions_for_llm = {
# Human/LLM-readable descriptions of each output mode, used when describing
# an agent's capabilities. Keys are ConversationCommand members.
mode_descriptions_for_agent = {
    # Fixed double negative: "cannot not use" -> "cannot use". The intent is
    # that image generation is not suitable for charts/graphs.
    ConversationCommand.Image: "Agent can generate images in response. It cannot use this to generate charts and graphs.",
    ConversationCommand.Automation: "Agent can schedule a task to run at a scheduled date, time and frequency in response.",
    ConversationCommand.Diagram: "Agent can generate a visual representation that requires primitives like lines, rectangles, and text.",
}