mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Pass args to context builder funcs grouped consistently
Put context params together, followed by model params. Use consistent ordering to improve readability.
This commit is contained in:
@@ -556,19 +556,21 @@ def gather_raw_query_files(
|
|||||||
|
|
||||||
|
|
||||||
def generate_chatml_messages_with_context(
|
def generate_chatml_messages_with_context(
|
||||||
|
# Context
|
||||||
user_message: str,
|
user_message: str,
|
||||||
system_message: str = None,
|
|
||||||
chat_history: list[ChatMessageModel] = [],
|
|
||||||
model_name="gpt-4o-mini",
|
|
||||||
max_prompt_size=None,
|
|
||||||
tokenizer_name=None,
|
|
||||||
query_images=None,
|
|
||||||
vision_enabled=False,
|
|
||||||
model_type="",
|
|
||||||
context_message="",
|
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
|
query_images=None,
|
||||||
|
context_message="",
|
||||||
generated_asset_results: Dict[str, Dict] = {},
|
generated_asset_results: Dict[str, Dict] = {},
|
||||||
program_execution_context: List[str] = [],
|
program_execution_context: List[str] = [],
|
||||||
|
chat_history: list[ChatMessageModel] = [],
|
||||||
|
system_message: str = None,
|
||||||
|
# Model Config
|
||||||
|
model_name="gpt-4o-mini",
|
||||||
|
model_type="",
|
||||||
|
max_prompt_size=None,
|
||||||
|
tokenizer_name=None,
|
||||||
|
vision_enabled=False,
|
||||||
):
|
):
|
||||||
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
|
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
|
||||||
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
||||||
|
|||||||
@@ -156,11 +156,11 @@ async def generate_python_code(
|
|||||||
|
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
code_generation_prompt,
|
code_generation_prompt,
|
||||||
query_images=query_images,
|
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
user=user,
|
query_images=query_images,
|
||||||
agent_chat_model=agent_chat_model,
|
|
||||||
fast_model=False,
|
fast_model=False,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -286,7 +286,7 @@ async def acreate_title_from_history(
|
|||||||
title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
|
title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
|
||||||
|
|
||||||
with timer("Chat actor: Generate title from conversation history", logger):
|
with timer("Chat actor: Generate title from conversation history", logger):
|
||||||
response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True)
|
response = await send_message_to_model_wrapper(title_generation_prompt, fast_model=True, user=user)
|
||||||
|
|
||||||
return response.text.strip()
|
return response.text.strip()
|
||||||
|
|
||||||
@@ -298,7 +298,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
|||||||
title_generation_prompt = prompts.subject_generation.format(query=query)
|
title_generation_prompt = prompts.subject_generation.format(query=query)
|
||||||
|
|
||||||
with timer("Chat actor: Generate title from query", logger):
|
with timer("Chat actor: Generate title from query", logger):
|
||||||
response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True)
|
response = await send_message_to_model_wrapper(title_generation_prompt, fast_model=True, user=user)
|
||||||
|
|
||||||
return response.text.strip()
|
return response.text.strip()
|
||||||
|
|
||||||
@@ -321,7 +321,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax:
|
|||||||
|
|
||||||
with timer("Chat actor: Check if safe prompt", logger):
|
with timer("Chat actor: Check if safe prompt", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck, fast_model=True
|
safe_prompt_check, response_type="json_object", response_schema=SafetyCheck, fast_model=True, user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
response = response.text.strip()
|
response = response.text.strip()
|
||||||
@@ -405,12 +405,12 @@ async def aget_data_sources_and_output_format(
|
|||||||
with timer("Chat actor: Infer information sources to refer", logger):
|
with timer("Chat actor: Infer information sources to refer", logger):
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
relevant_tools_prompt,
|
relevant_tools_prompt,
|
||||||
|
query_files=query_files,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
response_schema=PickTools,
|
response_schema=PickTools,
|
||||||
user=user,
|
|
||||||
query_files=query_files,
|
|
||||||
agent_chat_model=agent_chat_model,
|
|
||||||
fast_model=False,
|
fast_model=False,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -494,13 +494,13 @@ async def infer_webpage_urls(
|
|||||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt,
|
online_queries_prompt,
|
||||||
|
query_files=query_files,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
response_schema=WebpageUrls,
|
response_schema=WebpageUrls,
|
||||||
user=user,
|
|
||||||
query_files=query_files,
|
|
||||||
agent_chat_model=agent_chat_model,
|
|
||||||
fast_model=False,
|
fast_model=False,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -560,13 +560,13 @@ async def generate_online_subqueries(
|
|||||||
with timer("Chat actor: Generate online search subqueries", logger):
|
with timer("Chat actor: Generate online search subqueries", logger):
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt,
|
online_queries_prompt,
|
||||||
|
query_files=query_files,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
response_schema=OnlineQueries,
|
response_schema=OnlineQueries,
|
||||||
user=user,
|
|
||||||
query_files=query_files,
|
|
||||||
agent_chat_model=agent_chat_model,
|
|
||||||
fast_model=False,
|
fast_model=False,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -671,10 +671,10 @@ async def extract_relevant_info(
|
|||||||
|
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_information,
|
system_message=prompts.system_prompt_extract_relevant_information,
|
||||||
user=user,
|
|
||||||
agent_chat_model=agent_chat_model,
|
|
||||||
fast_model=True,
|
fast_model=True,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return response.text.strip()
|
return response.text.strip()
|
||||||
@@ -714,11 +714,11 @@ async def extract_relevant_summary(
|
|||||||
with timer("Chat actor: Extract relevant information from data", logger):
|
with timer("Chat actor: Extract relevant information from data", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_summary,
|
|
||||||
user=user,
|
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
agent_chat_model=agent_chat_model,
|
system_message=prompts.system_prompt_extract_relevant_summary,
|
||||||
fast_model=True,
|
fast_model=True,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return response.text.strip()
|
return response.text.strip()
|
||||||
@@ -887,10 +887,10 @@ async def generate_better_diagram_description(
|
|||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
improve_diagram_description_prompt,
|
improve_diagram_description_prompt,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
user=user,
|
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
agent_chat_model=agent_chat_model,
|
|
||||||
fast_model=False,
|
fast_model=False,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
response = response.text.strip()
|
response = response.text.strip()
|
||||||
@@ -920,9 +920,9 @@ async def generate_excalidraw_diagram_from_description(
|
|||||||
with timer("Chat actor: Generate excalidraw diagram", logger):
|
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
query=excalidraw_diagram_generation,
|
query=excalidraw_diagram_generation,
|
||||||
user=user,
|
|
||||||
agent_chat_model=agent_chat_model,
|
|
||||||
fast_model=False,
|
fast_model=False,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
raw_response_text = clean_json(raw_response.text)
|
raw_response_text = clean_json(raw_response.text)
|
||||||
@@ -1042,11 +1042,11 @@ async def generate_better_mermaidjs_diagram_description(
|
|||||||
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
|
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
improve_diagram_description_prompt,
|
improve_diagram_description_prompt,
|
||||||
query_images=query_images,
|
|
||||||
user=user,
|
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
agent_chat_model=agent_chat_model,
|
query_images=query_images,
|
||||||
fast_model=False,
|
fast_model=False,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
response_text = response.text.strip()
|
response_text = response.text.strip()
|
||||||
@@ -1076,9 +1076,9 @@ async def generate_mermaidjs_diagram_from_description(
|
|||||||
with timer("Chat actor: Generate Mermaid.js diagram", logger):
|
with timer("Chat actor: Generate Mermaid.js diagram", logger):
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
query=mermaidjs_diagram_generation,
|
query=mermaidjs_diagram_generation,
|
||||||
user=user,
|
|
||||||
agent_chat_model=agent_chat_model,
|
|
||||||
fast_model=False,
|
fast_model=False,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return clean_mermaidjs(raw_response.text.strip())
|
return clean_mermaidjs(raw_response.text.strip())
|
||||||
@@ -1135,15 +1135,15 @@ async def generate_better_image_prompt(
|
|||||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
q,
|
q,
|
||||||
system_message=enhance_image_system_message,
|
|
||||||
query_images=query_images,
|
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
query_images=query_images,
|
||||||
chat_history=conversation_history,
|
chat_history=conversation_history,
|
||||||
agent_chat_model=agent_chat_model,
|
system_message=enhance_image_system_message,
|
||||||
fast_model=False,
|
|
||||||
user=user,
|
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
response_schema=ImagePromptResponse,
|
response_schema=ImagePromptResponse,
|
||||||
|
fast_model=False,
|
||||||
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1216,11 +1216,11 @@ async def search_documents(
|
|||||||
inferred_queries = await extract_questions(
|
inferred_queries = await extract_questions(
|
||||||
query=defiltered_query,
|
query=defiltered_query,
|
||||||
user=user,
|
user=user,
|
||||||
personality_context=personality_context,
|
|
||||||
chat_history=chat_history,
|
|
||||||
location_data=location_data,
|
|
||||||
query_images=query_images,
|
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
query_images=query_images,
|
||||||
|
personality_context=personality_context,
|
||||||
|
location_data=location_data,
|
||||||
|
chat_history=chat_history,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1266,11 +1266,11 @@ async def search_documents(
|
|||||||
async def extract_questions(
|
async def extract_questions(
|
||||||
query: str,
|
query: str,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
personality_context: str = "",
|
|
||||||
chat_history: List[ChatMessageModel] = [],
|
|
||||||
location_data: LocationData = None,
|
|
||||||
query_images: Optional[List[str]] = None,
|
|
||||||
query_files: str = None,
|
query_files: str = None,
|
||||||
|
query_images: Optional[List[str]] = None,
|
||||||
|
personality_context: str = "",
|
||||||
|
location_data: LocationData = None,
|
||||||
|
chat_history: List[ChatMessageModel] = [],
|
||||||
max_queries: int = 5,
|
max_queries: int = 5,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
@@ -1317,10 +1317,10 @@ async def extract_questions(
|
|||||||
)
|
)
|
||||||
|
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
system_message=system_prompt,
|
|
||||||
query=prompt,
|
query=prompt,
|
||||||
query_images=query_images,
|
|
||||||
query_files=query_files,
|
query_files=query_files,
|
||||||
|
query_images=query_images,
|
||||||
|
system_message=system_prompt,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
response_schema=DocumentQueries,
|
response_schema=DocumentQueries,
|
||||||
fast_model=False,
|
fast_model=False,
|
||||||
@@ -1436,19 +1436,23 @@ async def execute_search(
|
|||||||
|
|
||||||
|
|
||||||
async def send_message_to_model_wrapper(
|
async def send_message_to_model_wrapper(
|
||||||
|
# Context
|
||||||
query: str,
|
query: str,
|
||||||
|
query_files: str = None,
|
||||||
|
query_images: List[str] = None,
|
||||||
|
context: str = "",
|
||||||
|
chat_history: list[ChatMessageModel] = [],
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
|
# Model Config
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
response_schema: BaseModel = None,
|
response_schema: BaseModel = None,
|
||||||
tools: List[ToolDefinition] = None,
|
tools: List[ToolDefinition] = None,
|
||||||
deepthought: bool = False,
|
deepthought: bool = False,
|
||||||
fast_model: Optional[bool] = None,
|
fast_model: Optional[bool] = None,
|
||||||
user: KhojUser = None,
|
|
||||||
query_images: List[str] = None,
|
|
||||||
context: str = "",
|
|
||||||
query_files: str = None,
|
|
||||||
chat_history: list[ChatMessageModel] = [],
|
|
||||||
agent_chat_model: ChatModel = None,
|
agent_chat_model: ChatModel = None,
|
||||||
|
# User
|
||||||
|
user: KhojUser = None,
|
||||||
|
# Tracer
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model)
|
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model)
|
||||||
@@ -1472,16 +1476,16 @@ async def send_message_to_model_wrapper(
|
|||||||
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=query,
|
user_message=query,
|
||||||
|
query_files=query_files,
|
||||||
|
query_images=query_images,
|
||||||
context_message=context,
|
context_message=context,
|
||||||
system_message=system_message,
|
|
||||||
chat_history=chat_history,
|
chat_history=chat_history,
|
||||||
|
system_message=system_message,
|
||||||
model_name=chat_model_name,
|
model_name=chat_model_name,
|
||||||
|
model_type=model_type,
|
||||||
tokenizer_name=tokenizer,
|
tokenizer_name=tokenizer,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
query_images=query_images,
|
|
||||||
model_type=model_type,
|
|
||||||
query_files=query_files,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_type == ChatModel.ModelType.OPENAI:
|
if model_type == ChatModel.ModelType.OPENAI:
|
||||||
@@ -1549,14 +1553,14 @@ def send_message_to_model_wrapper_sync(
|
|||||||
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
user_message=message,
|
user_message=message,
|
||||||
system_message=system_message,
|
query_files=query_files,
|
||||||
|
query_images=query_images,
|
||||||
chat_history=chat_history,
|
chat_history=chat_history,
|
||||||
|
system_message=system_message,
|
||||||
model_name=chat_model_name,
|
model_name=chat_model_name,
|
||||||
|
model_type=model_type,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=model_type,
|
|
||||||
query_images=query_images,
|
|
||||||
query_files=query_files,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_type == ChatModel.ModelType.OPENAI:
|
if model_type == ChatModel.ModelType.OPENAI:
|
||||||
|
|||||||
@@ -159,15 +159,15 @@ async def apick_next_tool(
|
|||||||
with timer("Chat actor: Infer information sources to refer", logger):
|
with timer("Chat actor: Infer information sources to refer", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
query="",
|
query="",
|
||||||
|
query_files=query_files,
|
||||||
|
query_images=query_images,
|
||||||
system_message=function_planning_prompt,
|
system_message=function_planning_prompt,
|
||||||
chat_history=chat_and_research_history,
|
chat_history=chat_and_research_history,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
deepthought=True,
|
deepthought=True,
|
||||||
fast_model=False,
|
fast_model=False,
|
||||||
user=user,
|
|
||||||
query_images=query_images,
|
|
||||||
query_files=query_files,
|
|
||||||
agent_chat_model=agent_chat_model,
|
agent_chat_model=agent_chat_model,
|
||||||
|
user=user,
|
||||||
tracer=tracer,
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user