From 7a42042488dfe95a7be413651d1bd62fabe8d44e Mon Sep 17 00:00:00 2001 From: Debanjum Date: Tue, 26 Aug 2025 23:32:17 -0700 Subject: [PATCH] Share context builder for chat final response across model types The context building logic was nearly identical across all model types. This change extracts that logic into a shared function and calls it once in the `agenerate_chat_response`, the entrypoint to the converse methods for all 3 model types. Main differences handled are - Gemini system prompt had additional verbosity instructions. Keep it - Pass system message via chatml messages list to anthropic, gemini models as well (like openai models) instead of passing it as separate arg to chat_completion_* funcs. The model specific message formatters for both already extract system instruction from the messages list. So system messages will be automatically extracted from the chat_completion_* funcs to pass as separate arg required by anthropic, gemini api libraries. --- .../conversation/anthropic/anthropic_chat.py | 90 +-------- .../processor/conversation/anthropic/utils.py | 3 +- .../conversation/google/gemini_chat.py | 90 +-------- .../processor/conversation/google/utils.py | 2 +- src/khoj/processor/conversation/openai/gpt.py | 89 +-------- src/khoj/routers/helpers.py | 184 +++++++++++++----- 6 files changed, 145 insertions(+), 313 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 87aed7cd..30259bc0 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -1,22 +1,16 @@ import logging -from datetime import datetime -from typing import AsyncGenerator, Dict, List, Optional +from typing import AsyncGenerator, List, Optional + +from langchain_core.messages.chat import ChatMessage -from khoj.database.models import Agent, ChatMessageModel, ChatModel -from khoj.processor.conversation import 
prompts from khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, anthropic_completion_with_backoff, ) from khoj.processor.conversation.utils import ( - OperatorRun, ResponseWithThought, - generate_chatml_messages_with_context, messages_to_print, ) -from khoj.utils.helpers import is_none_or_empty, truncate_code_context -from khoj.utils.rawconfig import LocationData -from khoj.utils.yaml import yaml_dump logger = logging.getLogger(__name__) @@ -52,91 +46,17 @@ def anthropic_send_message_to_model( async def converse_anthropic( # Query - user_query: str, - # Context - references: list[dict], - online_results: Optional[Dict[str, Dict]] = None, - code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[List[OperatorRun]] = None, - query_images: Optional[list[str]] = None, - query_files: str = None, - program_execution_context: Optional[List[str]] = None, - generated_asset_results: Dict[str, Dict] = {}, - location_data: LocationData = None, - user_name: str = None, - chat_history: List[ChatMessageModel] = [], + messages: List[ChatMessage], # Model model: Optional[str] = "claude-3-7-sonnet-latest", api_key: Optional[str] = None, api_base_url: Optional[str] = None, - max_prompt_size=None, - tokenizer_name=None, - agent: Agent = None, - vision_available: bool = False, deepthought: Optional[bool] = False, tracer: dict = {}, ) -> AsyncGenerator[ResponseWithThought, None]: """ Converse with user using Anthropic's Claude """ - # Initialize Variables - current_date = datetime.now() - - if agent and agent.personality: - system_prompt = prompts.custom_personality.format( - name=agent.name, - bio=agent.personality, - current_date=current_date.strftime("%Y-%m-%d"), - day_of_week=current_date.strftime("%A"), - ) - else: - system_prompt = prompts.personality.format( - current_date=current_date.strftime("%Y-%m-%d"), - day_of_week=current_date.strftime("%A"), - ) - - if location_data: - location_prompt = 
prompts.user_location.format(location=f"{location_data}") - system_prompt = f"{system_prompt}\n{location_prompt}" - - if user_name: - user_name_prompt = prompts.user_name.format(name=user_name) - system_prompt = f"{system_prompt}\n{user_name_prompt}" - - context_message = "" - if not is_none_or_empty(references): - context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n" - if not is_none_or_empty(online_results): - context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n" - if not is_none_or_empty(code_results): - context_message += ( - f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" - ) - if not is_none_or_empty(operator_results): - operator_content = [ - {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results - ] - context_message += ( - f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n" - ) - context_message = context_message.strip() - - # Setup Prompt with Primer or Conversation History - messages = generate_chatml_messages_with_context( - user_query, - context_message=context_message, - chat_history=chat_history, - model_name=model, - max_prompt_size=max_prompt_size, - tokenizer_name=tokenizer_name, - query_images=query_images, - vision_enabled=vision_available, - model_type=ChatModel.ModelType.ANTHROPIC, - query_files=query_files, - generated_asset_results=generated_asset_results, - program_execution_context=program_execution_context, - ) - logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}") # Get Response from Claude @@ -146,8 +66,6 @@ async def converse_anthropic( temperature=0.2, api_key=api_key, api_base_url=api_base_url, - system_prompt=system_prompt, - max_prompt_size=max_prompt_size, deepthought=deepthought, tracer=tracer, ): diff --git 
a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index ab20f8aa..139b27bc 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -184,8 +184,7 @@ async def anthropic_chat_completion_with_backoff( temperature: float, api_key: str | None, api_base_url: str, - system_prompt: str, - max_prompt_size: int | None = None, + system_prompt: str = "", deepthought: bool = False, model_kwargs: dict | None = None, tracer: dict = {}, diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 44245e73..71e414f5 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -1,22 +1,16 @@ import logging -from datetime import datetime -from typing import AsyncGenerator, Dict, List, Optional +from typing import AsyncGenerator, List, Optional + +from langchain_core.messages.chat import ChatMessage -from khoj.database.models import Agent, ChatMessageModel, ChatModel -from khoj.processor.conversation import prompts from khoj.processor.conversation.google.utils import ( gemini_chat_completion_with_backoff, gemini_completion_with_backoff, ) from khoj.processor.conversation.utils import ( - OperatorRun, ResponseWithThought, - generate_chatml_messages_with_context, messages_to_print, ) -from khoj.utils.helpers import is_none_or_empty, truncate_code_context -from khoj.utils.rawconfig import LocationData -from khoj.utils.yaml import yaml_dump logger = logging.getLogger(__name__) @@ -61,93 +55,18 @@ def gemini_send_message_to_model( async def converse_gemini( # Query - user_query: str, - # Context - references: list[dict], - online_results: Optional[Dict[str, Dict]] = None, - code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[List[OperatorRun]] = None, - query_images: Optional[list[str]] = None, - query_files: 
str = None, - generated_asset_results: Dict[str, Dict] = {}, - program_execution_context: List[str] = None, - location_data: LocationData = None, - user_name: str = None, - chat_history: List[ChatMessageModel] = [], + messages: List[ChatMessage], # Model model: Optional[str] = "gemini-2.5-flash", api_key: Optional[str] = None, api_base_url: Optional[str] = None, temperature: float = 1.0, - max_prompt_size=None, - tokenizer_name=None, - agent: Agent = None, - vision_available: bool = False, deepthought: Optional[bool] = False, tracer={}, ) -> AsyncGenerator[ResponseWithThought, None]: """ Converse with user using Google's Gemini """ - # Initialize Variables - current_date = datetime.now() - - if agent and agent.personality: - system_prompt = prompts.custom_personality.format( - name=agent.name, - bio=agent.personality, - current_date=current_date.strftime("%Y-%m-%d"), - day_of_week=current_date.strftime("%A"), - ) - else: - system_prompt = prompts.personality.format( - current_date=current_date.strftime("%Y-%m-%d"), - day_of_week=current_date.strftime("%A"), - ) - - system_prompt += f"\n\n{prompts.gemini_verbose_language_personality}" - if location_data: - location_prompt = prompts.user_location.format(location=f"{location_data}") - system_prompt += f"\n{location_prompt}" - - if user_name: - user_name_prompt = prompts.user_name.format(name=user_name) - system_prompt += f"\n{user_name_prompt}" - - context_message = "" - if not is_none_or_empty(references): - context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n" - if not is_none_or_empty(online_results): - context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n" - if not is_none_or_empty(code_results): - context_message += ( - f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" - ) - if not is_none_or_empty(operator_results): - operator_content = [ - {"query": 
oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results - ] - context_message += ( - f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n" - ) - context_message = context_message.strip() - - # Setup Prompt with Primer or Conversation History - messages = generate_chatml_messages_with_context( - user_query, - context_message=context_message, - chat_history=chat_history, - model_name=model, - max_prompt_size=max_prompt_size, - tokenizer_name=tokenizer_name, - query_images=query_images, - vision_enabled=vision_available, - model_type=ChatModel.ModelType.GOOGLE, - query_files=query_files, - generated_asset_results=generated_asset_results, - program_execution_context=program_execution_context, - ) - logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}") # Get Response from Google AI @@ -157,7 +76,6 @@ async def converse_gemini( temperature=temperature, api_key=api_key, api_base_url=api_base_url, - system_prompt=system_prompt, deepthought=deepthought, tracer=tracer, ): diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 877ee664..de902894 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -308,7 +308,7 @@ async def gemini_chat_completion_with_backoff( temperature: float, api_key: str, api_base_url: str, - system_prompt: str, + system_prompt: str = "", model_kwargs=None, deepthought=False, tracer: dict = {}, diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index d00437bb..0435bf7f 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -1,9 +1,8 @@ import logging -from datetime import datetime from typing import Any, AsyncGenerator, Dict, List, Optional -from khoj.database.models import Agent, ChatMessageModel, ChatModel -from 
khoj.processor.conversation import prompts +from langchain_core.messages.chat import ChatMessage + from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, clean_response_schema, @@ -15,15 +14,11 @@ from khoj.processor.conversation.openai.utils import ( to_openai_tools, ) from khoj.processor.conversation.utils import ( - OperatorRun, ResponseWithThought, StructuredOutputSupport, - generate_chatml_messages_with_context, messages_to_print, ) -from khoj.utils.helpers import ToolDefinition, is_none_or_empty, truncate_code_context -from khoj.utils.rawconfig import LocationData -from khoj.utils.yaml import yaml_dump +from khoj.utils.helpers import ToolDefinition logger = logging.getLogger(__name__) @@ -96,92 +91,18 @@ def send_message_to_model( async def converse_openai( # Query - user_query: str, - # Context - references: list[dict], - online_results: Optional[Dict[str, Dict]] = None, - code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[List[OperatorRun]] = None, - query_images: Optional[list[str]] = None, - query_files: str = None, - generated_asset_results: Dict[str, Dict] = {}, - program_execution_context: List[str] = None, - location_data: LocationData = None, - chat_history: list[ChatMessageModel] = [], + messages: List[ChatMessage], + # Model model: str = "gpt-4.1-mini", api_key: Optional[str] = None, api_base_url: Optional[str] = None, temperature: float = 0.6, - max_prompt_size=None, - tokenizer_name=None, - user_name: str = None, - agent: Agent = None, - vision_available: bool = False, deepthought: Optional[bool] = False, tracer: dict = {}, ) -> AsyncGenerator[ResponseWithThought, None]: """ Converse with user using OpenAI's ChatGPT """ - # Initialize Variables - current_date = datetime.now() - - if agent and agent.personality: - system_prompt = prompts.custom_personality.format( - name=agent.name, - bio=agent.personality, - current_date=current_date.strftime("%Y-%m-%d"), - 
day_of_week=current_date.strftime("%A"), - ) - else: - system_prompt = prompts.personality.format( - current_date=current_date.strftime("%Y-%m-%d"), - day_of_week=current_date.strftime("%A"), - ) - - if location_data: - location_prompt = prompts.user_location.format(location=f"{location_data}") - system_prompt = f"{system_prompt}\n{location_prompt}" - - if user_name: - user_name_prompt = prompts.user_name.format(name=user_name) - system_prompt = f"{system_prompt}\n{user_name_prompt}" - - context_message = "" - if not is_none_or_empty(references): - context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n" - if not is_none_or_empty(online_results): - context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n" - if not is_none_or_empty(code_results): - context_message += ( - f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" - ) - if not is_none_or_empty(operator_results): - operator_content = [ - {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results - ] - context_message += ( - f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n" - ) - - context_message = context_message.strip() - - # Setup Prompt with Primer or Conversation History - messages = generate_chatml_messages_with_context( - user_query, - system_prompt, - chat_history, - context_message=context_message, - model_name=model, - max_prompt_size=max_prompt_size, - tokenizer_name=tokenizer_name, - query_images=query_images, - vision_enabled=vision_available, - model_type=ChatModel.ModelType.OPENAI, - query_files=query_files, - generated_asset_results=generated_asset_results, - program_execution_context=program_execution_context, - ) logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") # Get Response from GPT diff --git a/src/khoj/routers/helpers.py 
b/src/khoj/routers/helpers.py index 80f9fbdd..1f17ffe8 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -33,6 +33,7 @@ from apscheduler.triggers.cron import CronTrigger from asgiref.sync import sync_to_async from django.utils import timezone as django_timezone from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket +from langchain_core.messages.chat import ChatMessage from pydantic import BaseModel, EmailStr, Field from starlette.authentication import has_required_scope from starlette.requests import URL @@ -124,6 +125,7 @@ from khoj.utils.helpers import ( mode_descriptions_for_llm, timer, tool_descriptions_for_llm, + truncate_code_context, ) from khoj.utils.rawconfig import ( ChatRequestBody, @@ -132,6 +134,7 @@ from khoj.utils.rawconfig import ( SearchResponse, ) from khoj.utils.state import SearchType +from khoj.utils.yaml import yaml_dump logger = logging.getLogger(__name__) @@ -1598,6 +1601,105 @@ def send_message_to_model_wrapper_sync( raise HTTPException(status_code=500, detail="Invalid conversation config") +def build_conversation_context( + # Query and Context + user_query: str, + references: List[Dict], + online_results: Dict[str, Dict], + code_results: Dict[str, Dict], + operator_results: List[OperatorRun], + query_files: str = None, + query_images: Optional[List[str]] = None, + generated_asset_results: Dict[str, Dict] = {}, + program_execution_context: List[str] = None, + chat_history: List[ChatMessageModel] = [], + location_data: LocationData = None, + user_name: str = None, + # Model config + agent: Agent = None, + model_name: str = None, + model_type: ChatModel.ModelType = None, + max_prompt_size: int = None, + tokenizer_name: str = None, + vision_available: bool = False, +) -> List[ChatMessage]: + """ + Construct system, context and chatml messages for chat response. + Share common logic across different model types. 
+ + Returns: + List of ChatMessages with context + """ + # Initialize Variables + current_date = datetime.now() + + # Build system prompt + if agent and agent.personality: + system_prompt = prompts.custom_personality.format( + name=agent.name, + bio=agent.personality, + current_date=current_date.strftime("%Y-%m-%d"), + day_of_week=current_date.strftime("%A"), + ) + else: + system_prompt = prompts.personality.format( + current_date=current_date.strftime("%Y-%m-%d"), + day_of_week=current_date.strftime("%A"), + ) + + # Add Gemini-specific personality enhancement + if model_type == ChatModel.ModelType.GOOGLE: + system_prompt += f"\n\n{prompts.gemini_verbose_language_personality}" + + # Add location context if available + if location_data: + location_prompt = prompts.user_location.format(location=f"{location_data}") + system_prompt += f"\n{location_prompt}" + + # Add user name context if available + if user_name: + user_name_prompt = prompts.user_name.format(name=user_name) + system_prompt += f"\n{user_name_prompt}" + + # Build context message + context_message = "" + if not is_none_or_empty(references): + context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n" + if not is_none_or_empty(online_results): + context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n" + if not is_none_or_empty(code_results): + context_message += ( + f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n" + ) + if not is_none_or_empty(operator_results): + operator_content = [ + {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results + ] + context_message += ( + f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n" + ) + context_message = context_message.strip() + + # Generate the chatml messages + messages = generate_chatml_messages_with_context( + user_message=user_query, + 
query_files=query_files, + query_images=query_images, + context_message=context_message, + generated_asset_results=generated_asset_results, + program_execution_context=program_execution_context, + chat_history=chat_history, + system_message=system_prompt, + model_name=model_name, + model_type=model_type, + max_prompt_size=max_prompt_size, + tokenizer_name=tokenizer_name, + vision_enabled=vision_available, + ) + + return messages + + async def agenerate_chat_response( q: str, chat_history: List[ChatMessageModel], @@ -1645,33 +1747,39 @@ async def agenerate_chat_response( chat_model = vision_enabled_config vision_available = True + # Build shared conversation context and generate chatml messages + messages = build_conversation_context( + user_query=query_to_run, + references=compiled_references, + online_results=online_results, + code_results=code_results, + operator_results=operator_results, + query_files=query_files, + query_images=query_images, + generated_asset_results=generated_asset_results, + program_execution_context=program_execution_context, + chat_history=chat_history, + location_data=location_data, + user_name=user_name, + agent=agent, + model_type=chat_model.model_type, + model_name=chat_model.name, + max_prompt_size=max_prompt_size, + tokenizer_name=chat_model.tokenizer, + vision_available=vision_available, + ) + if chat_model.model_type == ChatModel.ModelType.OPENAI: openai_chat_config = chat_model.ai_model_api api_key = openai_chat_config.api_key chat_model_name = chat_model.name chat_response_generator = converse_openai( - # Query - query_to_run, - # Context - references=compiled_references, - online_results=online_results, - code_results=code_results, - operator_results=operator_results, - query_images=query_images, - query_files=query_files, - generated_asset_results=generated_asset_results, - program_execution_context=program_execution_context, - location_data=location_data, - user_name=user_name, - chat_history=chat_history, + # Query + Context 
Messages + messages, # Model model=chat_model_name, api_key=api_key, api_base_url=openai_chat_config.api_base_url, - max_prompt_size=max_prompt_size, - tokenizer_name=chat_model.tokenizer, - agent=agent, - vision_available=vision_available, deepthought=deepthought, tracer=tracer, ) @@ -1680,28 +1788,12 @@ async def agenerate_chat_response( api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url chat_response_generator = converse_anthropic( - # Query - query_to_run, - # Context - references=compiled_references, - online_results=online_results, - code_results=code_results, - operator_results=operator_results, - query_images=query_images, - query_files=query_files, - generated_asset_results=generated_asset_results, - program_execution_context=program_execution_context, - location_data=location_data, - user_name=user_name, - chat_history=chat_history, + # Query + Context Messages + messages, # Model model=chat_model.name, api_key=api_key, api_base_url=api_base_url, - max_prompt_size=max_prompt_size, - tokenizer_name=chat_model.tokenizer, - agent=agent, - vision_available=vision_available, deepthought=deepthought, tracer=tracer, ) @@ -1709,28 +1801,12 @@ async def agenerate_chat_response( api_key = chat_model.ai_model_api.api_key api_base_url = chat_model.ai_model_api.api_base_url chat_response_generator = converse_gemini( - # Query - query_to_run, - # Context - references=compiled_references, - online_results=online_results, - code_results=code_results, - operator_results=operator_results, - query_images=query_images, - query_files=query_files, - generated_asset_results=generated_asset_results, - program_execution_context=program_execution_context, - location_data=location_data, - user_name=user_name, - chat_history=chat_history, + # Query + Context Messages + messages, # Model model=chat_model.name, api_key=api_key, api_base_url=api_base_url, - max_prompt_size=max_prompt_size, - tokenizer_name=chat_model.tokenizer, - agent=agent, 
- vision_available=vision_available, deepthought=deepthought, tracer=tracer, )