Pass args to context builder funcs grouped consistently

Put context params together, followed by model params
Use consistent ordering to improve readability
This commit is contained in:
Debanjum
2025-08-27 13:30:01 -07:00
parent 4976b244a4
commit 02e220f5f5
4 changed files with 76 additions and 70 deletions

View File

@@ -556,19 +556,21 @@ def gather_raw_query_files(
def generate_chatml_messages_with_context(
# Context
user_message: str,
system_message: str = None,
chat_history: list[ChatMessageModel] = [],
model_name="gpt-4o-mini",
max_prompt_size=None,
tokenizer_name=None,
query_images=None,
vision_enabled=False,
model_type="",
context_message="",
query_files: str = None,
query_images=None,
context_message="",
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = [],
chat_history: list[ChatMessageModel] = [],
system_message: str = None,
# Model Config
model_name="gpt-4o-mini",
model_type="",
max_prompt_size=None,
tokenizer_name=None,
vision_enabled=False,
):
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
# Set max prompt size from user config or based on pre-configured for model and machine specs

View File

@@ -156,11 +156,11 @@ async def generate_python_code(
response = await send_message_to_model_wrapper(
code_generation_prompt,
query_images=query_images,
query_files=query_files,
user=user,
agent_chat_model=agent_chat_model,
query_images=query_images,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)

View File

@@ -286,7 +286,7 @@ async def acreate_title_from_history(
title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
with timer("Chat actor: Generate title from conversation history", logger):
response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True)
response = await send_message_to_model_wrapper(title_generation_prompt, fast_model=True, user=user)
return response.text.strip()
@@ -298,7 +298,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
title_generation_prompt = prompts.subject_generation.format(query=query)
with timer("Chat actor: Generate title from query", logger):
response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True)
response = await send_message_to_model_wrapper(title_generation_prompt, fast_model=True, user=user)
return response.text.strip()
@@ -321,7 +321,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
with timer("Chat actor: Check if safe prompt", logger):
response = await send_message_to_model_wrapper(
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck, fast_model=True
safe_prompt_check, response_type="json_object", response_schema=SafetyCheck, fast_model=True, user=user
)
response = response.text.strip()
@@ -405,12 +405,12 @@ async def aget_data_sources_and_output_format(
with timer("Chat actor: Infer information sources to refer", logger):
raw_response = await send_message_to_model_wrapper(
relevant_tools_prompt,
query_files=query_files,
response_type="json_object",
response_schema=PickTools,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
@@ -494,13 +494,13 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger):
raw_response = await send_message_to_model_wrapper(
online_queries_prompt,
query_files=query_files,
query_images=query_images,
response_type="json_object",
response_schema=WebpageUrls,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
@@ -560,13 +560,13 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger):
raw_response = await send_message_to_model_wrapper(
online_queries_prompt,
query_files=query_files,
query_images=query_images,
response_type="json_object",
response_schema=OnlineQueries,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
@@ -671,10 +671,10 @@ async def extract_relevant_info(
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
user=user,
agent_chat_model=agent_chat_model,
system_message=prompts.system_prompt_extract_relevant_information,
fast_model=True,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
return response.text.strip()
@@ -714,11 +714,11 @@ async def extract_relevant_summary(
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_summary,
user=user,
query_images=query_images,
agent_chat_model=agent_chat_model,
system_message=prompts.system_prompt_extract_relevant_summary,
fast_model=True,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
return response.text.strip()
@@ -887,10 +887,10 @@ async def generate_better_diagram_description(
response = await send_message_to_model_wrapper(
improve_diagram_description_prompt,
query_images=query_images,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
response = response.text.strip()
@@ -920,9 +920,9 @@ async def generate_excalidraw_diagram_from_description(
with timer("Chat actor: Generate excalidraw diagram", logger):
raw_response = await send_message_to_model_wrapper(
query=excalidraw_diagram_generation,
user=user,
agent_chat_model=agent_chat_model,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
raw_response_text = clean_json(raw_response.text)
@@ -1042,11 +1042,11 @@ async def generate_better_mermaidjs_diagram_description(
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
response = await send_message_to_model_wrapper(
improve_diagram_description_prompt,
query_images=query_images,
user=user,
query_files=query_files,
agent_chat_model=agent_chat_model,
query_images=query_images,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
response_text = response.text.strip()
@@ -1076,9 +1076,9 @@ async def generate_mermaidjs_diagram_from_description(
with timer("Chat actor: Generate Mermaid.js diagram", logger):
raw_response = await send_message_to_model_wrapper(
query=mermaidjs_diagram_generation,
user=user,
agent_chat_model=agent_chat_model,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
return clean_mermaidjs(raw_response.text.strip())
@@ -1135,15 +1135,15 @@ async def generate_better_image_prompt(
with timer("Chat actor: Generate contextual image prompt", logger):
raw_response = await send_message_to_model_wrapper(
q,
system_message=enhance_image_system_message,
query_images=query_images,
query_files=query_files,
query_images=query_images,
chat_history=conversation_history,
agent_chat_model=agent_chat_model,
fast_model=False,
user=user,
system_message=enhance_image_system_message,
response_type="json_object",
response_schema=ImagePromptResponse,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
@@ -1216,11 +1216,11 @@ async def search_documents(
inferred_queries = await extract_questions(
query=defiltered_query,
user=user,
personality_context=personality_context,
chat_history=chat_history,
location_data=location_data,
query_images=query_images,
query_files=query_files,
query_images=query_images,
personality_context=personality_context,
location_data=location_data,
chat_history=chat_history,
tracer=tracer,
)
@@ -1266,11 +1266,11 @@ async def search_documents(
async def extract_questions(
query: str,
user: KhojUser,
personality_context: str = "",
chat_history: List[ChatMessageModel] = [],
location_data: LocationData = None,
query_images: Optional[List[str]] = None,
query_files: str = None,
query_images: Optional[List[str]] = None,
personality_context: str = "",
location_data: LocationData = None,
chat_history: List[ChatMessageModel] = [],
max_queries: int = 5,
tracer: dict = {},
):
@@ -1317,10 +1317,10 @@ async def extract_questions(
)
raw_response = await send_message_to_model_wrapper(
system_message=system_prompt,
query=prompt,
query_images=query_images,
query_files=query_files,
query_images=query_images,
system_message=system_prompt,
response_type="json_object",
response_schema=DocumentQueries,
fast_model=False,
@@ -1436,19 +1436,23 @@ async def execute_search(
async def send_message_to_model_wrapper(
# Context
query: str,
query_files: str = None,
query_images: List[str] = None,
context: str = "",
chat_history: list[ChatMessageModel] = [],
system_message: str = "",
# Model Config
response_type: str = "text",
response_schema: BaseModel = None,
tools: List[ToolDefinition] = None,
deepthought: bool = False,
fast_model: Optional[bool] = None,
user: KhojUser = None,
query_images: List[str] = None,
context: str = "",
query_files: str = None,
chat_history: list[ChatMessageModel] = [],
agent_chat_model: ChatModel = None,
# User
user: KhojUser = None,
# Tracer
tracer: dict = {},
):
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model)
@@ -1472,16 +1476,16 @@ async def send_message_to_model_wrapper(
truncated_messages = generate_chatml_messages_with_context(
user_message=query,
query_files=query_files,
query_images=query_images,
context_message=context,
system_message=system_message,
chat_history=chat_history,
system_message=system_message,
model_name=chat_model_name,
model_type=model_type,
tokenizer_name=tokenizer,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
query_images=query_images,
model_type=model_type,
query_files=query_files,
)
if model_type == ChatModel.ModelType.OPENAI:
@@ -1549,14 +1553,14 @@ def send_message_to_model_wrapper_sync(
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
system_message=system_message,
query_files=query_files,
query_images=query_images,
chat_history=chat_history,
system_message=system_message,
model_name=chat_model_name,
model_type=model_type,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=model_type,
query_images=query_images,
query_files=query_files,
)
if model_type == ChatModel.ModelType.OPENAI:

View File

@@ -159,15 +159,15 @@ async def apick_next_tool(
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
query="",
query_files=query_files,
query_images=query_images,
system_message=function_planning_prompt,
chat_history=chat_and_research_history,
tools=tools,
deepthought=True,
fast_model=False,
user=user,
query_images=query_images,
query_files=query_files,
agent_chat_model=agent_chat_model,
user=user,
tracer=tracer,
)
except Exception as e: