mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 05:39:11 +00:00
Use the agent CM as a backup chat model when available / necessary. Remove automation as agent option.
This commit is contained in:
@@ -1008,7 +1008,9 @@ class ConversationAdapters:
|
|||||||
if create_new:
|
if create_new:
|
||||||
return await ConversationAdapters.acreate_conversation_session(user, client_application)
|
return await ConversationAdapters.acreate_conversation_session(user, client_application)
|
||||||
|
|
||||||
query = Conversation.objects.filter(user=user, client=client_application).prefetch_related("agent")
|
query = Conversation.objects.filter(user=user, client=client_application).prefetch_related(
|
||||||
|
"agent", "agent__chat_model"
|
||||||
|
)
|
||||||
|
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
return await query.filter(id=conversation_id).afirst()
|
return await query.filter(id=conversation_id).afirst()
|
||||||
@@ -1017,7 +1019,7 @@ class ConversationAdapters:
|
|||||||
|
|
||||||
conversation = await query.order_by("-updated_at").afirst()
|
conversation = await query.order_by("-updated_at").afirst()
|
||||||
|
|
||||||
return conversation or await Conversation.objects.prefetch_related("agent").acreate(
|
return conversation or await Conversation.objects.prefetch_related("agent", "agent__chat_model").acreate(
|
||||||
user=user, client=client_application
|
user=user, client=client_application
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1147,7 +1149,7 @@ class ConversationAdapters:
|
|||||||
return ChatModel.objects.filter().first()
|
return ChatModel.objects.filter().first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_default_chat_model(user: KhojUser = None):
|
async def aget_default_chat_model(user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None):
|
||||||
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||||
# Get the server chat settings
|
# Get the server chat settings
|
||||||
server_chat_settings: ServerChatSettings = (
|
server_chat_settings: ServerChatSettings = (
|
||||||
@@ -1167,12 +1169,18 @@ class ConversationAdapters:
|
|||||||
if server_chat_settings.chat_default:
|
if server_chat_settings.chat_default:
|
||||||
return server_chat_settings.chat_default
|
return server_chat_settings.chat_default
|
||||||
|
|
||||||
|
# Revert to an explicit fallback model if the server chat settings are not set
|
||||||
|
if fallback_chat_model:
|
||||||
|
# The chat model may not be full loaded from the db, so explicitly load it here
|
||||||
|
return await ChatModel.objects.filter(id=fallback_chat_model.id).prefetch_related("ai_model_api").afirst()
|
||||||
|
|
||||||
# Get the user's chat settings, if the server chat settings are not set
|
# Get the user's chat settings, if the server chat settings are not set
|
||||||
user_chat_settings = (
|
user_chat_settings = (
|
||||||
(await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst())
|
(await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst())
|
||||||
if user
|
if user
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_chat_settings is not None and user_chat_settings.setting is not None:
|
if user_chat_settings is not None and user_chat_settings.setting is not None:
|
||||||
return user_chat_settings.setting
|
return user_chat_settings.setting
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
# Generated by Django 5.0.10 on 2025-02-01 20:10
|
||||||
|
|
||||||
|
import django.contrib.postgres.fields
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0084_alter_agent_input_tools"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="agent",
|
||||||
|
name="output_modes",
|
||||||
|
field=django.contrib.postgres.fields.ArrayField(
|
||||||
|
base_field=models.CharField(choices=[("image", "Image"), ("diagram", "Diagram")], max_length=200),
|
||||||
|
blank=True,
|
||||||
|
default=list,
|
||||||
|
null=True,
|
||||||
|
size=None,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -286,7 +286,6 @@ class Agent(DbBaseModel):
|
|||||||
class OutputModeOptions(models.TextChoices):
|
class OutputModeOptions(models.TextChoices):
|
||||||
# These map to various ConversationCommand types
|
# These map to various ConversationCommand types
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
AUTOMATION = "automation"
|
|
||||||
DIAGRAM = "diagram"
|
DIAGRAM = "diagram"
|
||||||
|
|
||||||
creator = models.ForeignKey(
|
creator = models.ForeignKey(
|
||||||
|
|||||||
@@ -393,12 +393,15 @@ async def aget_data_sources_and_output_format(
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
with timer("Chat actor: Infer information sources to refer", logger):
|
with timer("Chat actor: Infer information sources to refer", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
relevant_tools_prompt,
|
relevant_tools_prompt,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
user=user,
|
user=user,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -472,6 +475,8 @@ async def infer_webpage_urls(
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt,
|
online_queries_prompt,
|
||||||
@@ -479,6 +484,7 @@ async def infer_webpage_urls(
|
|||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
user=user,
|
user=user,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -528,6 +534,8 @@ async def generate_online_subqueries(
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
with timer("Chat actor: Generate online search subqueries", logger):
|
with timer("Chat actor: Generate online search subqueries", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt,
|
online_queries_prompt,
|
||||||
@@ -535,6 +543,7 @@ async def generate_online_subqueries(
|
|||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
user=user,
|
user=user,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -628,10 +637,13 @@ async def extract_relevant_info(
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_information,
|
prompts.system_prompt_extract_relevant_information,
|
||||||
user=user,
|
user=user,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
@@ -666,12 +678,15 @@ async def extract_relevant_summary(
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
with timer("Chat actor: Extract relevant information from data", logger):
|
with timer("Chat actor: Extract relevant information from data", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_summary,
|
prompts.system_prompt_extract_relevant_summary,
|
||||||
user=user,
|
user=user,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
@@ -834,12 +849,15 @@ async def generate_better_diagram_description(
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
with timer("Chat actor: Generate better diagram description", logger):
|
with timer("Chat actor: Generate better diagram description", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
improve_diagram_description_prompt,
|
improve_diagram_description_prompt,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
user=user,
|
user=user,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
@@ -864,9 +882,11 @@ async def generate_excalidraw_diagram_from_description(
|
|||||||
query=q,
|
query=q,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
with timer("Chat actor: Generate excalidraw diagram", logger):
|
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
query=excalidraw_diagram_generation, user=user, tracer=tracer
|
query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
|
||||||
)
|
)
|
||||||
raw_response = clean_json(raw_response)
|
raw_response = clean_json(raw_response)
|
||||||
try:
|
try:
|
||||||
@@ -980,12 +1000,15 @@ async def generate_better_mermaidjs_diagram_description(
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
|
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
improve_diagram_description_prompt,
|
improve_diagram_description_prompt,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
user=user,
|
user=user,
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
@@ -1010,8 +1033,12 @@ async def generate_mermaidjs_diagram_from_description(
|
|||||||
query=q,
|
query=q,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
with timer("Chat actor: Generate Mermaid.js diagram", logger):
|
with timer("Chat actor: Generate Mermaid.js diagram", logger):
|
||||||
raw_response = await send_message_to_model_wrapper(query=mermaidjs_diagram_generation, user=user, tracer=tracer)
|
raw_response = await send_message_to_model_wrapper(
|
||||||
|
query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer
|
||||||
|
)
|
||||||
return clean_mermaidjs(raw_response.strip())
|
return clean_mermaidjs(raw_response.strip())
|
||||||
|
|
||||||
|
|
||||||
@@ -1072,9 +1099,16 @@ async def generate_better_image_prompt(
|
|||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_chat_model = agent.chat_model if agent else None
|
||||||
|
|
||||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
image_prompt, query_images=query_images, user=user, query_files=query_files, tracer=tracer
|
image_prompt,
|
||||||
|
query_images=query_images,
|
||||||
|
user=user,
|
||||||
|
query_files=query_files,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
@@ -1091,9 +1125,10 @@ async def send_message_to_model_wrapper(
|
|||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
context: str = "",
|
context: str = "",
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
|
agent_chat_model: ChatModel = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user)
|
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
|
||||||
vision_available = chat_model.vision_enabled
|
vision_available = chat_model.vision_enabled
|
||||||
if not vision_available and query_images:
|
if not vision_available and query_images:
|
||||||
logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
|
logger.warning(f"Vision is not enabled for default model: {chat_model.name}.")
|
||||||
|
|||||||
@@ -386,7 +386,6 @@ mode_descriptions_for_llm = {
|
|||||||
|
|
||||||
mode_descriptions_for_agent = {
|
mode_descriptions_for_agent = {
|
||||||
ConversationCommand.Image: "Agent can generate images in response. It cannot not use this to generate charts and graphs.",
|
ConversationCommand.Image: "Agent can generate images in response. It cannot not use this to generate charts and graphs.",
|
||||||
ConversationCommand.Automation: "Agent can schedule a task to run at a scheduled date, time and frequency in response.",
|
|
||||||
ConversationCommand.Diagram: "Agent can generate a visual representation that requires primitives like lines, rectangles, and text.",
|
ConversationCommand.Diagram: "Agent can generate a visual representation that requires primitives like lines, rectangles, and text.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user