Share context builder for chat final response across model types

The context building logic was nearly identical across all model
types.

This change extracts that logic into a shared function and calls it
once in `agenerate_chat_response`, the entrypoint to the converse
methods for all 3 model types.

The main differences handled are:
- The Gemini system prompt had additional verbosity instructions. Keep it.
- Pass the system message via the chatml messages list to the anthropic
  and gemini models as well (like the openai models), instead of passing
  it as a separate arg to the chat_completion_* funcs.

  The model specific message formatters for both already extract the
  system instruction from the messages list. So system messages will be
  automatically extracted within the chat_completion_* funcs and passed
  as the separate arg required by the anthropic and gemini API libraries.
This commit is contained in:
Debanjum
2025-08-26 23:32:17 -07:00
parent 02e220f5f5
commit 7a42042488
6 changed files with 145 additions and 313 deletions

View File

@@ -1,22 +1,16 @@
import logging
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
from typing import AsyncGenerator, List, Optional
from langchain_core.messages.chat import ChatMessage
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
anthropic_completion_with_backoff,
)
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
generate_chatml_messages_with_context,
messages_to_print,
)
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@@ -52,91 +46,17 @@ def anthropic_send_message_to_model(
async def converse_anthropic(
# Query
user_query: str,
# Context
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[OperatorRun]] = None,
query_images: Optional[list[str]] = None,
query_files: str = None,
program_execution_context: Optional[List[str]] = None,
generated_asset_results: Dict[str, Dict] = {},
location_data: LocationData = None,
user_name: str = None,
chat_history: List[ChatMessageModel] = [],
messages: List[ChatMessage],
# Model
model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
max_prompt_size=None,
tokenizer_name=None,
agent: Agent = None,
vision_available: bool = False,
deepthought: Optional[bool] = False,
tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]:
"""
Converse with user using Anthropic's Claude
"""
# Initialize Variables
current_date = datetime.now()
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
name=agent.name,
bio=agent.personality,
current_date=current_date.strftime("%Y-%m-%d"),
day_of_week=current_date.strftime("%A"),
)
else:
system_prompt = prompts.personality.format(
current_date=current_date.strftime("%Y-%m-%d"),
day_of_week=current_date.strftime("%A"),
)
if location_data:
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:
user_name_prompt = prompts.user_name.format(name=user_name)
system_prompt = f"{system_prompt}\n{user_name_prompt}"
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
if not is_none_or_empty(online_results):
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if not is_none_or_empty(code_results):
context_message += (
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]
context_message += (
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
user_query,
context_message=context_message,
chat_history=chat_history,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModel.ModelType.ANTHROPIC,
query_files=query_files,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
)
logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
# Get Response from Claude
@@ -146,8 +66,6 @@ async def converse_anthropic(
temperature=0.2,
api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt,
max_prompt_size=max_prompt_size,
deepthought=deepthought,
tracer=tracer,
):

View File

@@ -184,8 +184,7 @@ async def anthropic_chat_completion_with_backoff(
temperature: float,
api_key: str | None,
api_base_url: str,
system_prompt: str,
max_prompt_size: int | None = None,
system_prompt: str = "",
deepthought: bool = False,
model_kwargs: dict | None = None,
tracer: dict = {},

View File

@@ -1,22 +1,16 @@
import logging
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional
from typing import AsyncGenerator, List, Optional
from langchain_core.messages.chat import ChatMessage
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
gemini_chat_completion_with_backoff,
gemini_completion_with_backoff,
)
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
generate_chatml_messages_with_context,
messages_to_print,
)
from khoj.utils.helpers import is_none_or_empty, truncate_code_context
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@@ -61,93 +55,18 @@ def gemini_send_message_to_model(
async def converse_gemini(
# Query
user_query: str,
# Context
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[OperatorRun]] = None,
query_images: Optional[list[str]] = None,
query_files: str = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
location_data: LocationData = None,
user_name: str = None,
chat_history: List[ChatMessageModel] = [],
messages: List[ChatMessage],
# Model
model: Optional[str] = "gemini-2.5-flash",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
temperature: float = 1.0,
max_prompt_size=None,
tokenizer_name=None,
agent: Agent = None,
vision_available: bool = False,
deepthought: Optional[bool] = False,
tracer={},
) -> AsyncGenerator[ResponseWithThought, None]:
"""
Converse with user using Google's Gemini
"""
# Initialize Variables
current_date = datetime.now()
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
name=agent.name,
bio=agent.personality,
current_date=current_date.strftime("%Y-%m-%d"),
day_of_week=current_date.strftime("%A"),
)
else:
system_prompt = prompts.personality.format(
current_date=current_date.strftime("%Y-%m-%d"),
day_of_week=current_date.strftime("%A"),
)
system_prompt += f"\n\n{prompts.gemini_verbose_language_personality}"
if location_data:
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt += f"\n{location_prompt}"
if user_name:
user_name_prompt = prompts.user_name.format(name=user_name)
system_prompt += f"\n{user_name_prompt}"
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
if not is_none_or_empty(online_results):
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if not is_none_or_empty(code_results):
context_message += (
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]
context_message += (
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
user_query,
context_message=context_message,
chat_history=chat_history,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModel.ModelType.GOOGLE,
query_files=query_files,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
)
logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
# Get Response from Google AI
@@ -157,7 +76,6 @@ async def converse_gemini(
temperature=temperature,
api_key=api_key,
api_base_url=api_base_url,
system_prompt=system_prompt,
deepthought=deepthought,
tracer=tracer,
):

View File

@@ -308,7 +308,7 @@ async def gemini_chat_completion_with_backoff(
temperature: float,
api_key: str,
api_base_url: str,
system_prompt: str,
system_prompt: str = "",
model_kwargs=None,
deepthought=False,
tracer: dict = {},

View File

@@ -1,9 +1,8 @@
import logging
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional
from khoj.database.models import Agent, ChatMessageModel, ChatModel
from khoj.processor.conversation import prompts
from langchain_core.messages.chat import ChatMessage
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
clean_response_schema,
@@ -15,15 +14,11 @@ from khoj.processor.conversation.openai.utils import (
to_openai_tools,
)
from khoj.processor.conversation.utils import (
OperatorRun,
ResponseWithThought,
StructuredOutputSupport,
generate_chatml_messages_with_context,
messages_to_print,
)
from khoj.utils.helpers import ToolDefinition, is_none_or_empty, truncate_code_context
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
from khoj.utils.helpers import ToolDefinition
logger = logging.getLogger(__name__)
@@ -96,92 +91,18 @@ def send_message_to_model(
async def converse_openai(
# Query
user_query: str,
# Context
references: list[dict],
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[OperatorRun]] = None,
query_images: Optional[list[str]] = None,
query_files: str = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
location_data: LocationData = None,
chat_history: list[ChatMessageModel] = [],
messages: List[ChatMessage],
# Model
model: str = "gpt-4.1-mini",
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
temperature: float = 0.6,
max_prompt_size=None,
tokenizer_name=None,
user_name: str = None,
agent: Agent = None,
vision_available: bool = False,
deepthought: Optional[bool] = False,
tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]:
"""
Converse with user using OpenAI's ChatGPT
"""
# Initialize Variables
current_date = datetime.now()
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
name=agent.name,
bio=agent.personality,
current_date=current_date.strftime("%Y-%m-%d"),
day_of_week=current_date.strftime("%A"),
)
else:
system_prompt = prompts.personality.format(
current_date=current_date.strftime("%Y-%m-%d"),
day_of_week=current_date.strftime("%A"),
)
if location_data:
location_prompt = prompts.user_location.format(location=f"{location_data}")
system_prompt = f"{system_prompt}\n{location_prompt}"
if user_name:
user_name_prompt = prompts.user_name.format(name=user_name)
system_prompt = f"{system_prompt}\n{user_name_prompt}"
context_message = ""
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
if not is_none_or_empty(online_results):
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if not is_none_or_empty(code_results):
context_message += (
f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
)
if not is_none_or_empty(operator_results):
operator_content = [
{"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
]
context_message += (
f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
)
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
user_query,
system_prompt,
chat_history,
context_message=context_message,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModel.ModelType.OPENAI,
query_files=query_files,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
)
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
# Get Response from GPT

View File

@@ -33,6 +33,7 @@ from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from django.utils import timezone as django_timezone
from fastapi import Depends, Header, HTTPException, Request, UploadFile, WebSocket
from langchain_core.messages.chat import ChatMessage
from pydantic import BaseModel, EmailStr, Field
from starlette.authentication import has_required_scope
from starlette.requests import URL
@@ -124,6 +125,7 @@ from khoj.utils.helpers import (
mode_descriptions_for_llm,
timer,
tool_descriptions_for_llm,
truncate_code_context,
)
from khoj.utils.rawconfig import (
ChatRequestBody,
@@ -132,6 +134,7 @@ from khoj.utils.rawconfig import (
SearchResponse,
)
from khoj.utils.state import SearchType
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@@ -1598,6 +1601,105 @@ def send_message_to_model_wrapper_sync(
raise HTTPException(status_code=500, detail="Invalid conversation config")
def build_conversation_context(
    # Query and Context
    user_query: str,
    references: List[Dict],
    online_results: Dict[str, Dict],
    code_results: Dict[str, Dict],
    operator_results: List[OperatorRun],
    query_files: str = None,
    query_images: Optional[List[str]] = None,
    generated_asset_results: Dict[str, Dict] = {},  # NOTE(review): mutable default arg — safe only if never mutated; consider None + in-body default
    program_execution_context: List[str] = None,
    chat_history: List[ChatMessageModel] = [],  # NOTE(review): mutable default arg — same caveat as above
    location_data: LocationData = None,
    user_name: str = None,
    # Model config
    agent: Agent = None,
    model_name: str = None,
    model_type: ChatModel.ModelType = None,
    max_prompt_size: int = None,
    tokenizer_name: str = None,
    vision_available: bool = False,
) -> List[ChatMessage]:
    """
    Construct system, context and chatml messages for chat response.
    Share common logic across different model types.

    Args:
        user_query: The user's current question.
        references: Notes/document excerpts retrieved for the query.
        online_results: Online search results keyed by query.
        code_results: Outputs of executed code, truncated before inclusion.
        operator_results: Browser/operator runs; only query, response and
            webpages fields are forwarded into the prompt.
        query_files, query_images: File text and image attachments for the query.
        generated_asset_results: Previously generated assets to surface to the model.
        program_execution_context: Extra execution context strings.
        chat_history: Prior conversation turns.
        location_data, user_name: Optional user context appended to the system prompt.
        agent: Optional agent whose personality overrides the default system prompt.
        model_name, model_type: Target chat model; model_type selects
            model-specific prompt tweaks and message formatting.
        max_prompt_size, tokenizer_name: Prompt truncation configuration.
        vision_available: Whether image content may be included.

    Returns:
        List of ChatMessages with context (system message included in the list;
        model-specific formatters extract it downstream where needed).
    """
    # Initialize Variables
    current_date = datetime.now()

    # Build system prompt: agent personality takes precedence over the default
    if agent and agent.personality:
        system_prompt = prompts.custom_personality.format(
            name=agent.name,
            bio=agent.personality,
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )
    else:
        system_prompt = prompts.personality.format(
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )

    # Add Gemini-specific personality enhancement (verbosity instructions)
    if model_type == ChatModel.ModelType.GOOGLE:
        system_prompt += f"\n\n{prompts.gemini_verbose_language_personality}"

    # Add location context if available
    if location_data:
        location_prompt = prompts.user_location.format(location=f"{location_data}")
        system_prompt += f"\n{location_prompt}"

    # Add user name context if available
    if user_name:
        user_name_prompt = prompts.user_name.format(name=user_name)
        system_prompt += f"\n{user_name_prompt}"

    # Build context message by concatenating each non-empty context source
    context_message = ""
    if not is_none_or_empty(references):
        # NOTE(review): query= is no longer passed to notes_conversation (the old
        # anthropic/gemini paths did) — confirm the template no longer needs it
        context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
    if not is_none_or_empty(online_results):
        context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
    if not is_none_or_empty(code_results):
        context_message += (
            f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
        )
    if not is_none_or_empty(operator_results):
        # Forward only the fields the prompt template consumes
        operator_content = [
            {"query": oc.query, "response": oc.response, "webpages": oc.webpages} for oc in operator_results
        ]
        context_message += (
            f"{prompts.operator_execution_context.format(operator_results=yaml_dump(operator_content))}\n\n"
        )
    context_message = context_message.strip()

    # Generate the chatml messages, truncated to fit the model's prompt window
    messages = generate_chatml_messages_with_context(
        user_message=user_query,
        query_files=query_files,
        query_images=query_images,
        context_message=context_message,
        generated_asset_results=generated_asset_results,
        program_execution_context=program_execution_context,
        chat_history=chat_history,
        system_message=system_prompt,
        model_name=model_name,
        model_type=model_type,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
        vision_enabled=vision_available,
    )
    return messages
async def agenerate_chat_response(
q: str,
chat_history: List[ChatMessageModel],
@@ -1645,33 +1747,39 @@ async def agenerate_chat_response(
chat_model = vision_enabled_config
vision_available = True
# Build shared conversation context and generate chatml messages
messages = build_conversation_context(
user_query=query_to_run,
references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
query_files=query_files,
query_images=query_images,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
chat_history=chat_history,
location_data=location_data,
user_name=user_name,
agent=agent,
model_type=chat_model.model_type,
model_name=chat_model.name,
max_prompt_size=max_prompt_size,
tokenizer_name=chat_model.tokenizer,
vision_available=vision_available,
)
if chat_model.model_type == ChatModel.ModelType.OPENAI:
openai_chat_config = chat_model.ai_model_api
api_key = openai_chat_config.api_key
chat_model_name = chat_model.name
chat_response_generator = converse_openai(
# Query
query_to_run,
# Context
references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
query_images=query_images,
query_files=query_files,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
location_data=location_data,
user_name=user_name,
chat_history=chat_history,
# Query + Context Messages
messages,
# Model
model=chat_model_name,
api_key=api_key,
api_base_url=openai_chat_config.api_base_url,
max_prompt_size=max_prompt_size,
tokenizer_name=chat_model.tokenizer,
agent=agent,
vision_available=vision_available,
deepthought=deepthought,
tracer=tracer,
)
@@ -1680,28 +1788,12 @@ async def agenerate_chat_response(
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response_generator = converse_anthropic(
# Query
query_to_run,
# Context
references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
query_images=query_images,
query_files=query_files,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
location_data=location_data,
user_name=user_name,
chat_history=chat_history,
# Query + Context Messages
messages,
# Model
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
max_prompt_size=max_prompt_size,
tokenizer_name=chat_model.tokenizer,
agent=agent,
vision_available=vision_available,
deepthought=deepthought,
tracer=tracer,
)
@@ -1709,28 +1801,12 @@ async def agenerate_chat_response(
api_key = chat_model.ai_model_api.api_key
api_base_url = chat_model.ai_model_api.api_base_url
chat_response_generator = converse_gemini(
# Query
query_to_run,
# Context
references=compiled_references,
online_results=online_results,
code_results=code_results,
operator_results=operator_results,
query_images=query_images,
query_files=query_files,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
location_data=location_data,
user_name=user_name,
chat_history=chat_history,
# Query + Context Messages
messages,
# Model
model=chat_model.name,
api_key=api_key,
api_base_url=api_base_url,
max_prompt_size=max_prompt_size,
tokenizer_name=chat_model.tokenizer,
agent=agent,
vision_available=vision_available,
deepthought=deepthought,
tracer=tracer,
)