Pass args to context builder funcs grouped consistently

Put context params together, followed by model params. Use consistent ordering to improve readability.
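For illustration, a minimal sketch of the grouping convention this commit applies (the function and parameter names below are hypothetical, not code from this commit): context parameters first, then model configuration, then user and tracer.

# Hypothetical sketch of the argument-grouping convention; names are
# illustrative, not from the khoj codebase.
def build_prompt(
    # Context: what the model should respond to
    user_message: str,
    query_files: str = None,
    chat_history: list = [],
    system_message: str = None,
    # Model Config: how the model should run
    model_name: str = "gpt-4o-mini",
    max_prompt_size: int = None,
    vision_enabled: bool = False,
):
    ...

# Call sites pass arguments in the same order the signature declares them:
# context first, then model config.
messages = build_prompt(
    user_message="What is a qubit?",
    chat_history=[],
    model_name="gpt-4o-mini",
)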
@@ -556,19 +556,21 @@ def gather_raw_query_files(
 
 
 def generate_chatml_messages_with_context(
+    # Context
     user_message: str,
-    system_message: str = None,
-    chat_history: list[ChatMessageModel] = [],
-    model_name="gpt-4o-mini",
-    max_prompt_size=None,
-    tokenizer_name=None,
-    query_images=None,
-    vision_enabled=False,
-    model_type="",
-    context_message="",
    query_files: str = None,
+    query_images=None,
+    context_message="",
     generated_asset_results: Dict[str, Dict] = {},
     program_execution_context: List[str] = [],
+    chat_history: list[ChatMessageModel] = [],
+    system_message: str = None,
+    # Model Config
+    model_name="gpt-4o-mini",
+    model_type="",
+    max_prompt_size=None,
+    tokenizer_name=None,
+    vision_enabled=False,
 ):
     """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -156,11 +156,11 @@ async def generate_python_code(
 
     response = await send_message_to_model_wrapper(
         code_generation_prompt,
-        query_images=query_images,
         query_files=query_files,
-        user=user,
-        agent_chat_model=agent_chat_model,
+        query_images=query_images,
         fast_model=False,
+        agent_chat_model=agent_chat_model,
+        user=user,
         tracer=tracer,
     )
 
@@ -286,7 +286,7 @@ async def acreate_title_from_history(
     title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
 
     with timer("Chat actor: Generate title from conversation history", logger):
-        response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True)
+        response = await send_message_to_model_wrapper(title_generation_prompt, fast_model=True, user=user)
 
     return response.text.strip()
@@ -298,7 +298,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
     title_generation_prompt = prompts.subject_generation.format(query=query)
 
     with timer("Chat actor: Generate title from query", logger):
-        response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True)
+        response = await send_message_to_model_wrapper(title_generation_prompt, fast_model=True, user=user)
 
     return response.text.strip()
@@ -321,7 +321,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
 
     with timer("Chat actor: Check if safe prompt", logger):
         response = await send_message_to_model_wrapper(
-            safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck, fast_model=True
+            safe_prompt_check, response_type="json_object", response_schema=SafetyCheck, fast_model=True, user=user
         )
 
     response = response.text.strip()
@@ -405,12 +405,12 @@ async def aget_data_sources_and_output_format(
     with timer("Chat actor: Infer information sources to refer", logger):
         raw_response = await send_message_to_model_wrapper(
             relevant_tools_prompt,
+            query_files=query_files,
             response_type="json_object",
             response_schema=PickTools,
-            user=user,
-            query_files=query_files,
-            agent_chat_model=agent_chat_model,
             fast_model=False,
+            agent_chat_model=agent_chat_model,
+            user=user,
             tracer=tracer,
         )
 
@@ -494,13 +494,13 @@ async def infer_webpage_urls(
     with timer("Chat actor: Infer webpage urls to read", logger):
         raw_response = await send_message_to_model_wrapper(
             online_queries_prompt,
+            query_files=query_files,
             query_images=query_images,
             response_type="json_object",
             response_schema=WebpageUrls,
-            user=user,
-            query_files=query_files,
-            agent_chat_model=agent_chat_model,
             fast_model=False,
+            agent_chat_model=agent_chat_model,
+            user=user,
             tracer=tracer,
         )
 
@@ -560,13 +560,13 @@ async def generate_online_subqueries(
     with timer("Chat actor: Generate online search subqueries", logger):
         raw_response = await send_message_to_model_wrapper(
             online_queries_prompt,
+            query_files=query_files,
             query_images=query_images,
             response_type="json_object",
             response_schema=OnlineQueries,
-            user=user,
-            query_files=query_files,
-            agent_chat_model=agent_chat_model,
             fast_model=False,
+            agent_chat_model=agent_chat_model,
+            user=user,
             tracer=tracer,
         )
 
@@ -671,10 +671,10 @@ async def extract_relevant_info(
 
     response = await send_message_to_model_wrapper(
         extract_relevant_information,
-        prompts.system_prompt_extract_relevant_information,
-        user=user,
-        agent_chat_model=agent_chat_model,
+        system_message=prompts.system_prompt_extract_relevant_information,
        fast_model=True,
+        agent_chat_model=agent_chat_model,
+        user=user,
         tracer=tracer,
     )
     return response.text.strip()
@@ -714,11 +714,11 @@ async def extract_relevant_summary(
     with timer("Chat actor: Extract relevant information from data", logger):
         response = await send_message_to_model_wrapper(
             extract_relevant_information,
-            prompts.system_prompt_extract_relevant_summary,
-            user=user,
             query_images=query_images,
-            agent_chat_model=agent_chat_model,
+            system_message=prompts.system_prompt_extract_relevant_summary,
             fast_model=True,
+            agent_chat_model=agent_chat_model,
+            user=user,
             tracer=tracer,
         )
     return response.text.strip()
@@ -887,10 +887,10 @@ async def generate_better_diagram_description(
         response = await send_message_to_model_wrapper(
             improve_diagram_description_prompt,
             query_images=query_images,
-            user=user,
             query_files=query_files,
-            agent_chat_model=agent_chat_model,
             fast_model=False,
+            agent_chat_model=agent_chat_model,
+            user=user,
             tracer=tracer,
         )
     response = response.text.strip()
@@ -920,9 +920,9 @@ async def generate_excalidraw_diagram_from_description(
     with timer("Chat actor: Generate excalidraw diagram", logger):
         raw_response = await send_message_to_model_wrapper(
             query=excalidraw_diagram_generation,
-            user=user,
-            agent_chat_model=agent_chat_model,
             fast_model=False,
+            agent_chat_model=agent_chat_model,
+            user=user,
             tracer=tracer,
         )
         raw_response_text = clean_json(raw_response.text)
@@ -1042,11 +1042,11 @@ async def generate_better_mermaidjs_diagram_description(
     with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
         response = await send_message_to_model_wrapper(
             improve_diagram_description_prompt,
-            query_images=query_images,
-            user=user,
             query_files=query_files,
-            agent_chat_model=agent_chat_model,
+            query_images=query_images,
             fast_model=False,
+            agent_chat_model=agent_chat_model,
+            user=user,
             tracer=tracer,
         )
     response_text = response.text.strip()
@@ -1076,9 +1076,9 @@ async def generate_mermaidjs_diagram_from_description(
     with timer("Chat actor: Generate Mermaid.js diagram", logger):
         raw_response = await send_message_to_model_wrapper(
             query=mermaidjs_diagram_generation,
-            user=user,
-            agent_chat_model=agent_chat_model,
             fast_model=False,
+            agent_chat_model=agent_chat_model,
+            user=user,
             tracer=tracer,
         )
         return clean_mermaidjs(raw_response.text.strip())
@@ -1135,15 +1135,15 @@ async def generate_better_image_prompt(
     with timer("Chat actor: Generate contextual image prompt", logger):
         raw_response = await send_message_to_model_wrapper(
             q,
-            system_message=enhance_image_system_message,
-            query_images=query_images,
             query_files=query_files,
+            query_images=query_images,
             chat_history=conversation_history,
-            agent_chat_model=agent_chat_model,
-            fast_model=False,
-            user=user,
+            system_message=enhance_image_system_message,
             response_type="json_object",
             response_schema=ImagePromptResponse,
+            fast_model=False,
+            agent_chat_model=agent_chat_model,
+            user=user,
             tracer=tracer,
         )
 
@@ -1216,11 +1216,11 @@ async def search_documents(
         inferred_queries = await extract_questions(
             query=defiltered_query,
             user=user,
-            personality_context=personality_context,
-            chat_history=chat_history,
-            location_data=location_data,
-            query_images=query_images,
             query_files=query_files,
+            query_images=query_images,
+            personality_context=personality_context,
+            location_data=location_data,
+            chat_history=chat_history,
             tracer=tracer,
         )
 
@@ -1266,11 +1266,11 @@ async def search_documents(
 async def extract_questions(
     query: str,
     user: KhojUser,
-    personality_context: str = "",
-    chat_history: List[ChatMessageModel] = [],
-    location_data: LocationData = None,
-    query_images: Optional[List[str]] = None,
     query_files: str = None,
+    query_images: Optional[List[str]] = None,
+    personality_context: str = "",
+    location_data: LocationData = None,
+    chat_history: List[ChatMessageModel] = [],
     max_queries: int = 5,
     tracer: dict = {},
 ):
@@ -1317,10 +1317,10 @@ async def extract_questions(
     )
 
     raw_response = await send_message_to_model_wrapper(
-        system_message=system_prompt,
         query=prompt,
-        query_images=query_images,
         query_files=query_files,
+        query_images=query_images,
+        system_message=system_prompt,
         response_type="json_object",
         response_schema=DocumentQueries,
         fast_model=False,
@@ -1436,19 +1436,23 @@ async def execute_search(
 
 
 async def send_message_to_model_wrapper(
+    # Context
     query: str,
+    query_files: str = None,
+    query_images: List[str] = None,
+    context: str = "",
+    chat_history: list[ChatMessageModel] = [],
     system_message: str = "",
+    # Model Config
     response_type: str = "text",
     response_schema: BaseModel = None,
     tools: List[ToolDefinition] = None,
     deepthought: bool = False,
     fast_model: Optional[bool] = None,
-    user: KhojUser = None,
-    query_images: List[str] = None,
-    context: str = "",
-    query_files: str = None,
-    chat_history: list[ChatMessageModel] = [],
     agent_chat_model: ChatModel = None,
+    # User
+    user: KhojUser = None,
+    # Tracer
     tracer: dict = {},
 ):
     chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model)
@@ -1472,16 +1476,16 @@ async def send_message_to_model_wrapper(
 
     truncated_messages = generate_chatml_messages_with_context(
         user_message=query,
+        query_files=query_files,
+        query_images=query_images,
         context_message=context,
-        system_message=system_message,
         chat_history=chat_history,
+        system_message=system_message,
         model_name=chat_model_name,
+        model_type=model_type,
         tokenizer_name=tokenizer,
         max_prompt_size=max_tokens,
         vision_enabled=vision_available,
-        query_images=query_images,
-        model_type=model_type,
-        query_files=query_files,
     )
 
     if model_type == ChatModel.ModelType.OPENAI:
@@ -1549,14 +1553,14 @@ def send_message_to_model_wrapper_sync(
 
     truncated_messages = generate_chatml_messages_with_context(
         user_message=message,
-        system_message=system_message,
+        query_files=query_files,
+        query_images=query_images,
         chat_history=chat_history,
+        system_message=system_message,
         model_name=chat_model_name,
+        model_type=model_type,
         max_prompt_size=max_tokens,
         vision_enabled=vision_available,
-        model_type=model_type,
-        query_images=query_images,
-        query_files=query_files,
     )
 
     if model_type == ChatModel.ModelType.OPENAI:
@@ -159,15 +159,15 @@ async def apick_next_tool(
         with timer("Chat actor: Infer information sources to refer", logger):
             response = await send_message_to_model_wrapper(
                 query="",
+                query_files=query_files,
+                query_images=query_images,
                 system_message=function_planning_prompt,
                 chat_history=chat_and_research_history,
                 tools=tools,
                 deepthought=True,
                 fast_model=False,
-                user=user,
-                query_images=query_images,
-                query_files=query_files,
                 agent_chat_model=agent_chat_model,
+                user=user,
                 tracer=tracer,
             )
     except Exception as e: