From 0645af9b168ed5a0461beac43add257b7ded9e79 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 1 Feb 2025 13:06:42 -0800 Subject: [PATCH] Use the agent CM as a backup chat model when available / necessary. Remove automation as agent option. --- src/khoj/database/adapters/__init__.py | 14 ++++-- .../0085_alter_agent_output_modes.py | 24 +++++++++++ src/khoj/database/models/__init__.py | 1 - src/khoj/routers/helpers.py | 43 +++++++++++++++++-- src/khoj/utils/helpers.py | 1 - 5 files changed, 74 insertions(+), 9 deletions(-) create mode 100644 src/khoj/database/migrations/0085_alter_agent_output_modes.py diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index e50ea29a..953171db 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1008,7 +1008,9 @@ class ConversationAdapters: if create_new: return await ConversationAdapters.acreate_conversation_session(user, client_application) - query = Conversation.objects.filter(user=user, client=client_application).prefetch_related("agent") + query = Conversation.objects.filter(user=user, client=client_application).prefetch_related( + "agent", "agent__chat_model" + ) if conversation_id: return await query.filter(id=conversation_id).afirst() @@ -1017,7 +1019,7 @@ class ConversationAdapters: conversation = await query.order_by("-updated_at").afirst() - return conversation or await Conversation.objects.prefetch_related("agent").acreate( + return conversation or await Conversation.objects.prefetch_related("agent", "agent__chat_model").acreate( user=user, client=client_application ) @@ -1147,7 +1149,7 @@ class ConversationAdapters: return ChatModel.objects.filter().first() @staticmethod - async def aget_default_chat_model(user: KhojUser = None): + async def aget_default_chat_model(user: KhojUser = None, fallback_chat_model: Optional[ChatModel] = None): """Get default conversation config. 
Prefer chat model by server admin > user > first created chat model""" # Get the server chat settings server_chat_settings: ServerChatSettings = ( @@ -1167,12 +1169,18 @@ class ConversationAdapters: if server_chat_settings.chat_default: return server_chat_settings.chat_default + # Revert to an explicit fallback model if the server chat settings are not set + if fallback_chat_model: + # The chat model may not be fully loaded from the db, so explicitly load it here + return await ChatModel.objects.filter(id=fallback_chat_model.id).prefetch_related("ai_model_api").afirst() + # Get the user's chat settings, if the server chat settings are not set user_chat_settings = ( (await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__ai_model_api").afirst()) if user else None ) + if user_chat_settings is not None and user_chat_settings.setting is not None: return user_chat_settings.setting diff --git a/src/khoj/database/migrations/0085_alter_agent_output_modes.py b/src/khoj/database/migrations/0085_alter_agent_output_modes.py new file mode 100644 index 00000000..ccf88c28 --- /dev/null +++ b/src/khoj/database/migrations/0085_alter_agent_output_modes.py @@ -0,0 +1,24 @@ +# Generated by Django 5.0.10 on 2025-02-01 20:10 + +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0084_alter_agent_input_tools"), + ] + + operations = [ + migrations.AlterField( + model_name="agent", + name="output_modes", + field=django.contrib.postgres.fields.ArrayField( + base_field=models.CharField(choices=[("image", "Image"), ("diagram", "Diagram")], max_length=200), + blank=True, + default=list, + null=True, + size=None, + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 51bc8e27..f366c15a 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -286,7 +286,6 @@ class 
Agent(DbBaseModel): class OutputModeOptions(models.TextChoices): # These map to various ConversationCommand types IMAGE = "image" - AUTOMATION = "automation" DIAGRAM = "diagram" creator = models.ForeignKey( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 8339323e..3c65bce3 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -393,12 +393,15 @@ async def aget_data_sources_and_output_format( personality_context=personality_context, ) + agent_chat_model = agent.chat_model if agent else None + with timer("Chat actor: Infer information sources to refer", logger): response = await send_message_to_model_wrapper( relevant_tools_prompt, response_type="json_object", user=user, query_files=query_files, + agent_chat_model=agent_chat_model, tracer=tracer, ) @@ -472,6 +475,8 @@ async def infer_webpage_urls( personality_context=personality_context, ) + agent_chat_model = agent.chat_model if agent else None + with timer("Chat actor: Infer webpage urls to read", logger): response = await send_message_to_model_wrapper( online_queries_prompt, @@ -479,6 +484,7 @@ async def infer_webpage_urls( response_type="json_object", user=user, query_files=query_files, + agent_chat_model=agent_chat_model, tracer=tracer, ) @@ -528,6 +534,8 @@ async def generate_online_subqueries( personality_context=personality_context, ) + agent_chat_model = agent.chat_model if agent else None + with timer("Chat actor: Generate online search subqueries", logger): response = await send_message_to_model_wrapper( online_queries_prompt, @@ -535,6 +543,7 @@ async def generate_online_subqueries( response_type="json_object", user=user, query_files=query_files, + agent_chat_model=agent_chat_model, tracer=tracer, ) @@ -628,10 +637,13 @@ async def extract_relevant_info( personality_context=personality_context, ) + agent_chat_model = agent.chat_model if agent else None + response = await send_message_to_model_wrapper( extract_relevant_information, 
prompts.system_prompt_extract_relevant_information, user=user, + agent_chat_model=agent_chat_model, tracer=tracer, ) return response.strip() @@ -666,12 +678,15 @@ async def extract_relevant_summary( personality_context=personality_context, ) + agent_chat_model = agent.chat_model if agent else None + with timer("Chat actor: Extract relevant information from data", logger): response = await send_message_to_model_wrapper( extract_relevant_information, prompts.system_prompt_extract_relevant_summary, user=user, query_images=query_images, + agent_chat_model=agent_chat_model, tracer=tracer, ) return response.strip() @@ -834,12 +849,15 @@ async def generate_better_diagram_description( personality_context=personality_context, ) + agent_chat_model = agent.chat_model if agent else None + with timer("Chat actor: Generate better diagram description", logger): response = await send_message_to_model_wrapper( improve_diagram_description_prompt, query_images=query_images, user=user, query_files=query_files, + agent_chat_model=agent_chat_model, tracer=tracer, ) response = response.strip() @@ -864,9 +882,11 @@ async def generate_excalidraw_diagram_from_description( query=q, ) + agent_chat_model = agent.chat_model if agent else None + with timer("Chat actor: Generate excalidraw diagram", logger): raw_response = await send_message_to_model_wrapper( - query=excalidraw_diagram_generation, user=user, tracer=tracer + query=excalidraw_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer ) raw_response = clean_json(raw_response) try: @@ -980,12 +1000,15 @@ async def generate_better_mermaidjs_diagram_description( personality_context=personality_context, ) + agent_chat_model = agent.chat_model if agent else None + with timer("Chat actor: Generate better Mermaid.js diagram description", logger): response = await send_message_to_model_wrapper( improve_diagram_description_prompt, query_images=query_images, user=user, query_files=query_files, + 
agent_chat_model=agent_chat_model, tracer=tracer, ) response = response.strip() @@ -1010,8 +1033,12 @@ async def generate_mermaidjs_diagram_from_description( query=q, ) + agent_chat_model = agent.chat_model if agent else None + with timer("Chat actor: Generate Mermaid.js diagram", logger): - raw_response = await send_message_to_model_wrapper(query=mermaidjs_diagram_generation, user=user, tracer=tracer) + raw_response = await send_message_to_model_wrapper( + query=mermaidjs_diagram_generation, user=user, agent_chat_model=agent_chat_model, tracer=tracer + ) return clean_mermaidjs(raw_response.strip()) @@ -1072,9 +1099,16 @@ async def generate_better_image_prompt( personality_context=personality_context, ) + agent_chat_model = agent.chat_model if agent else None + with timer("Chat actor: Generate contextual image prompt", logger): response = await send_message_to_model_wrapper( - image_prompt, query_images=query_images, user=user, query_files=query_files, tracer=tracer + image_prompt, + query_images=query_images, + user=user, + query_files=query_files, + agent_chat_model=agent_chat_model, + tracer=tracer, ) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): @@ -1091,9 +1125,10 @@ async def send_message_to_model_wrapper( query_images: List[str] = None, context: str = "", query_files: str = None, + agent_chat_model: ChatModel = None, tracer: dict = {}, ): - chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user) + chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model) vision_available = chat_model.vision_enabled if not vision_available and query_images: logger.warning(f"Vision is not enabled for default model: {chat_model.name}.") diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 32290657..b48436c6 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -386,7 +386,6 @@ mode_descriptions_for_llm = { 
mode_descriptions_for_agent = { ConversationCommand.Image: "Agent can generate images in response. It cannot not use this to generate charts and graphs.", - ConversationCommand.Automation: "Agent can schedule a task to run at a scheduled date, time and frequency in response.", ConversationCommand.Diagram: "Agent can generate a visual representation that requires primitives like lines, rectangles, and text.", }