diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 6ec2376b..41426294 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -762,29 +762,32 @@ Assuming you can search the user's notes and the internet. - User Name: {username} # Available Tool AIs -Which of the tool AIs listed below would you use to answer the user's question? You **only** have access to the following tool AIs: +You decide which of the tool AIs listed below would you use to answer the user's question. You **only** have access to the following tool AIs: {tools} -# Previous Iterations -{previous_iterations} - -# Chat History: -{chat_history} - -Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else. +Your response should always be a valid JSON object. Do not say anything else. Response format: {{"scratchpad": "", "tool": "", "query": ""}} """.strip() ) +plan_function_execution_next_tool = PromptTemplate.from_template( + """ +Given the results of your previous iterations, which tool AI will you use next to answer the target query? + +# Target Query: +{query} +""".strip() +) + previous_iteration = PromptTemplate.from_template( """ -## Iteration {index}: +# Iteration {index}: - tool: {tool} - query: {query} - result: {result} -""" +""".strip() ) pick_relevant_tools = PromptTemplate.from_template( diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 7901f29c..6e4b62ab 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -105,9 +105,9 @@ class InformationCollectionIteration: def construct_iteration_history( - previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str -) -> str: - previous_iterations_history = "" + query: str, previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str +) -> list[dict]: + previous_iterations_history = [] for idx, iteration in enumerate(previous_iterations): iteration_data = previous_iteration_prompt.format( tool=iteration.tool, @@ -116,8 +116,23 @@ def construct_iteration_history( index=idx + 1, ) - previous_iterations_history += iteration_data - return previous_iterations_history + previous_iterations_history.append(iteration_data) + + return ( + [ + { + "by": "you", + "message": query, + }, + { + "by": "khoj", + "intent": {"type": "remember", "query": query}, + "message": previous_iterations_history, + }, + ] + if previous_iterations_history + else [] + ) def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: @@ -316,7 +331,11 @@ Khoj: "{chat_response}" def construct_structured_message( - message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str = None + message: list[str] | str, + images: list[str], + model_type: str, + vision_enabled: bool, + attached_file_context: str = None, ): """ Format messages into appropriate multimedia format for supported chat model types @@ -326,10 +345,11 @@ def construct_structured_message( ChatModel.ModelType.GOOGLE, ChatModel.ModelType.ANTHROPIC, ]: - if not attached_file_context and not (vision_enabled and images): - return message + message = [message] if isinstance(message, str) else message - constructed_messages: List[Any] = [{"type": "text", "text": message}] + constructed_messages: List[dict[str, Any]] = [ + {"type": "text", "text": message_part} for message_part in message + ] if not is_none_or_empty(attached_file_context): constructed_messages.append({"type": "text", "text": attached_file_context}) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 86f63b2d..62f24282 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -108,13 +108,6 @@ async def apick_next_tool( # Create planning reponse model with dynamically populated tool enum class planning_response_model = PlanningResponse.create_model_with_enum(tool_options) - # Construct chat history with user and iteration history with researcher agent for context - chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj") - previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration) - - if query_images: - query = f"[placeholder for user attached images]\n{query}" - today = datetime.today() location_data = f"{location}" if location else "Unknown" agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None @@ -124,21 +117,30 @@ async def apick_next_tool( function_planning_prompt = prompts.plan_function_execution.format( tools=tool_options_str, - chat_history=chat_history, personality_context=personality_context, current_date=today.strftime("%Y-%m-%d"), day_of_week=today.strftime("%A"), username=user_name or "Unknown", location=location_data, - previous_iterations=previous_iterations_history, max_iterations=max_iterations, ) + if query_images: + query = f"[placeholder for user attached images]\n{query}" + + # Construct chat history with user and iteration history with researcher agent for context + previous_iterations_history = construct_iteration_history(query, previous_iterations, prompts.previous_iteration) + iteration_chat_log = {"chat": conversation_history.get("chat", []) + previous_iterations_history} + + # Plan function execution for the next tool + query = prompts.plan_function_execution_next_tool.format(query=query) if previous_iterations_history else query + try: with timer("Chat actor: Infer information sources to refer", logger): response = await send_message_to_model_wrapper( query=query, - context=function_planning_prompt, + system_message=function_planning_prompt, + conversation_log=iteration_chat_log, response_type="json_object", response_schema=planning_response_model, deepthought=True,