diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index f2e9cb3a..46a03b2c 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -556,19 +556,21 @@ def gather_raw_query_files( def generate_chatml_messages_with_context( + # Context user_message: str, - system_message: str = None, - chat_history: list[ChatMessageModel] = [], - model_name="gpt-4o-mini", - max_prompt_size=None, - tokenizer_name=None, - query_images=None, - vision_enabled=False, - model_type="", - context_message="", query_files: str = None, + query_images=None, + context_message="", generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = [], + chat_history: list[ChatMessageModel] = [], + system_message: str = None, + # Model Config + model_name="gpt-4o-mini", + model_type="", + max_prompt_size=None, + tokenizer_name=None, + vision_enabled=False, ): """Generate chat messages with appropriate context from previous conversation to send to the chat model""" # Set max prompt size from user config or based on pre-configured for model and machine specs diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 59702e36..26edad27 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -156,11 +156,11 @@ async def generate_python_code( response = await send_message_to_model_wrapper( code_generation_prompt, - query_images=query_images, query_files=query_files, - user=user, - agent_chat_model=agent_chat_model, + query_images=query_images, fast_model=False, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 13d201d3..80f9fbdd 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -286,7 +286,7 @@ async def acreate_title_from_history( title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history) with timer("Chat actor: Generate title from conversation history", logger): - response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True) + response = await send_message_to_model_wrapper(title_generation_prompt, fast_model=True, user=user) return response.text.strip() @@ -298,7 +298,7 @@ async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: title_generation_prompt = prompts.subject_generation.format(query=query) with timer("Chat actor: Generate title from query", logger): - response = await send_message_to_model_wrapper(title_generation_prompt, user=user, fast_model=True) + response = await send_message_to_model_wrapper(title_generation_prompt, fast_model=True, user=user) return response.text.strip() @@ -321,7 +321,7 @@ async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None, lax: with timer("Chat actor: Check if safe prompt", logger): response = await send_message_to_model_wrapper( - safe_prompt_check, user=user, response_type="json_object", response_schema=SafetyCheck, fast_model=True + safe_prompt_check, response_type="json_object", response_schema=SafetyCheck, fast_model=True, user=user ) response = response.text.strip() @@ -405,12 +405,12 @@ async def aget_data_sources_and_output_format( with timer("Chat actor: Infer information sources to refer", logger): raw_response = await send_message_to_model_wrapper( relevant_tools_prompt, + query_files=query_files, response_type="json_object", response_schema=PickTools, - user=user, - query_files=query_files, - agent_chat_model=agent_chat_model, fast_model=False, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) @@ -494,13 +494,13 @@ async def infer_webpage_urls( with timer("Chat actor: Infer webpage urls to read", logger): raw_response = await send_message_to_model_wrapper( online_queries_prompt, + query_files=query_files, query_images=query_images, response_type="json_object", response_schema=WebpageUrls, - user=user, - query_files=query_files, - agent_chat_model=agent_chat_model, fast_model=False, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) @@ -560,13 +560,13 @@ async def generate_online_subqueries( with timer("Chat actor: Generate online search subqueries", logger): raw_response = await send_message_to_model_wrapper( online_queries_prompt, + query_files=query_files, query_images=query_images, response_type="json_object", response_schema=OnlineQueries, - user=user, - query_files=query_files, - agent_chat_model=agent_chat_model, fast_model=False, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) @@ -671,10 +671,10 @@ async def extract_relevant_info( response = await send_message_to_model_wrapper( extract_relevant_information, - prompts.system_prompt_extract_relevant_information, - user=user, - agent_chat_model=agent_chat_model, + system_message=prompts.system_prompt_extract_relevant_information, fast_model=True, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) return response.text.strip() @@ -714,11 +714,11 @@ async def extract_relevant_summary( with timer("Chat actor: Extract relevant information from data", logger): response = await send_message_to_model_wrapper( extract_relevant_information, - prompts.system_prompt_extract_relevant_summary, - user=user, query_images=query_images, - agent_chat_model=agent_chat_model, + system_message=prompts.system_prompt_extract_relevant_summary, fast_model=True, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) return response.text.strip() @@ -887,10 +887,10 @@ async def generate_better_diagram_description( response = await send_message_to_model_wrapper( improve_diagram_description_prompt, query_images=query_images, - user=user, query_files=query_files, - agent_chat_model=agent_chat_model, fast_model=False, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) response = response.text.strip() @@ -920,9 +920,9 @@ async def generate_excalidraw_diagram_from_description( with timer("Chat actor: Generate excalidraw diagram", logger): raw_response = await send_message_to_model_wrapper( query=excalidraw_diagram_generation, - user=user, - agent_chat_model=agent_chat_model, fast_model=False, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) raw_response_text = clean_json(raw_response.text) @@ -1042,11 +1042,11 @@ async def generate_better_mermaidjs_diagram_description( with timer("Chat actor: Generate better Mermaid.js diagram description", logger): response = await send_message_to_model_wrapper( improve_diagram_description_prompt, - query_images=query_images, - user=user, query_files=query_files, - agent_chat_model=agent_chat_model, + query_images=query_images, fast_model=False, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) response_text = response.text.strip() @@ -1076,9 +1076,9 @@ async def generate_mermaidjs_diagram_from_description( with timer("Chat actor: Generate Mermaid.js diagram", logger): raw_response = await send_message_to_model_wrapper( query=mermaidjs_diagram_generation, - user=user, - agent_chat_model=agent_chat_model, fast_model=False, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) return clean_mermaidjs(raw_response.text.strip()) @@ -1135,15 +1135,15 @@ async def generate_better_image_prompt( with timer("Chat actor: Generate contextual image prompt", logger): raw_response = await send_message_to_model_wrapper( q, - system_message=enhance_image_system_message, - query_images=query_images, query_files=query_files, + query_images=query_images, chat_history=conversation_history, - agent_chat_model=agent_chat_model, - fast_model=False, - user=user, + system_message=enhance_image_system_message, response_type="json_object", response_schema=ImagePromptResponse, + fast_model=False, + agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) @@ -1216,11 +1216,11 @@ async def search_documents( inferred_queries = await extract_questions( query=defiltered_query, user=user, - personality_context=personality_context, - chat_history=chat_history, - location_data=location_data, - query_images=query_images, query_files=query_files, + query_images=query_images, + personality_context=personality_context, + location_data=location_data, + chat_history=chat_history, tracer=tracer, ) @@ -1266,11 +1266,11 @@ async def search_documents( async def extract_questions( query: str, user: KhojUser, - personality_context: str = "", - chat_history: List[ChatMessageModel] = [], - location_data: LocationData = None, - query_images: Optional[List[str]] = None, query_files: str = None, + query_images: Optional[List[str]] = None, + personality_context: str = "", + location_data: LocationData = None, + chat_history: List[ChatMessageModel] = [], max_queries: int = 5, tracer: dict = {}, ): @@ -1317,10 +1317,10 @@ async def extract_questions( ) raw_response = await send_message_to_model_wrapper( - system_message=system_prompt, query=prompt, - query_images=query_images, query_files=query_files, + query_images=query_images, + system_message=system_prompt, response_type="json_object", response_schema=DocumentQueries, fast_model=False, @@ -1436,19 +1436,23 @@ async def execute_search( async def send_message_to_model_wrapper( + # Context query: str, + query_files: str = None, + query_images: List[str] = None, + context: str = "", + chat_history: list[ChatMessageModel] = [], system_message: str = "", + # Model Config response_type: str = "text", response_schema: BaseModel = None, tools: List[ToolDefinition] = None, deepthought: bool = False, fast_model: Optional[bool] = None, - user: KhojUser = None, - query_images: List[str] = None, - context: str = "", - query_files: str = None, - chat_history: list[ChatMessageModel] = [], agent_chat_model: ChatModel = None, + # User + user: KhojUser = None, + # Tracer tracer: dict = {}, ): chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model, fast=fast_model) @@ -1472,16 +1476,16 @@ async def send_message_to_model_wrapper( truncated_messages = generate_chatml_messages_with_context( user_message=query, + query_files=query_files, + query_images=query_images, context_message=context, - system_message=system_message, chat_history=chat_history, + system_message=system_message, model_name=chat_model_name, + model_type=model_type, tokenizer_name=tokenizer, max_prompt_size=max_tokens, vision_enabled=vision_available, - query_images=query_images, - model_type=model_type, - query_files=query_files, ) if model_type == ChatModel.ModelType.OPENAI: @@ -1549,14 +1553,14 @@ def send_message_to_model_wrapper_sync( truncated_messages = generate_chatml_messages_with_context( user_message=message, - system_message=system_message, + query_files=query_files, + query_images=query_images, chat_history=chat_history, + system_message=system_message, model_name=chat_model_name, + model_type=model_type, max_prompt_size=max_tokens, vision_enabled=vision_available, - model_type=model_type, - query_images=query_images, - query_files=query_files, ) if model_type == ChatModel.ModelType.OPENAI: diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 1bd6ba71..86980c08 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -159,15 +159,15 @@ async def apick_next_tool( with timer("Chat actor: Infer information sources to refer", logger): response = await send_message_to_model_wrapper( query="", + query_files=query_files, + query_images=query_images, system_message=function_planning_prompt, chat_history=chat_and_research_history, tools=tools, deepthought=True, fast_model=False, - user=user, - query_images=query_images, - query_files=query_files, agent_chat_model=agent_chat_model, + user=user, tracer=tracer, ) except Exception as e: