diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py
index d9ed7901..cef6f49d 100644
--- a/src/khoj/routers/research.py
+++ b/src/khoj/routers/research.py
@@ -48,6 +48,261 @@ from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
+class ToolExecutionResult:
+ """Result of executing a single tool call"""
+
+ def __init__(self):
+        self.status_messages: List[str] = []
+ self.document_results: List[Dict[str, str]] = []
+ self.online_results: Dict = {}
+ self.code_results: Dict = {}
+        self.operator_results: Optional[OperatorRun] = None
+ self.mcp_results: List = []
+ self.should_terminate: bool = False
+
+
+async def execute_tool(
+ iteration: ResearchIteration,
+ user: KhojUser,
+ conversation_id: str,
+ previous_iterations: List[ResearchIteration],
+ location: LocationData,
+ query_images: List[str],
+ query_files: str,
+ max_document_searches: int,
+ max_online_searches: int,
+ mcp_clients: List[MCPClient],
+ cancellation_event: Optional[asyncio.Event],
+ interrupt_queue: Optional[asyncio.Queue],
+ agent: Agent,
+ tracer: dict,
+) -> ToolExecutionResult:
+ """Execute a single tool call and return results. Designed for parallel execution."""
+ result = ToolExecutionResult()
+
+ # Skip if warning present
+ if iteration.warning:
+ logger.warning(f"Research mode: {iteration.warning}.")
+ return result
+
+ # Check for termination conditions
+ if not iteration.query or isinstance(iteration.query, str) or iteration.query.name == ConversationCommand.Text:
+ result.should_terminate = True
+ return result
+
+ # Create a status collector that captures messages for later batch yield
+ async def status_collector(message: str):
+ """Async generator that collects status messages instead of streaming them."""
+ # Just collect the message - we'll process it later in the main loop
+ result.status_messages.append(message)
+ yield message # Yield to satisfy async generator protocol expected by tool functions
+
+ try:
+ if iteration.query.name == ConversationCommand.SemanticSearchFiles:
+ iteration.context = []
+ previous_inferred_queries = {
+ c["query"] for iter in previous_iterations if iter.context for c in iter.context
+ }
+ async for res in search_documents(
+ **iteration.query.args,
+ n=max_document_searches,
+ d=None,
+ user=user,
+ chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles),
+ conversation_id=conversation_id,
+ conversation_commands=[ConversationCommand.Notes],
+ location_data=location,
+ send_status_func=status_collector,
+ query_images=query_images,
+ query_files=query_files,
+ previous_inferred_queries=previous_inferred_queries,
+ agent=agent,
+ tracer=tracer,
+ ):
+ # Status messages are collected by status_collector, skip ChatEvent.STATUS here
+ if isinstance(res, tuple):
+ result.document_results = res[0]
+ iteration.context += result.document_results
+
+ if not is_none_or_empty(result.document_results):
+ try:
+ distinct_files = {d["file"] for d in result.document_results}
+ distinct_headings = set(
+ [d["compiled"].split("\n")[0] for d in result.document_results if "compiled" in d]
+ )
+ headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
+ async for _ in status_collector(
+ f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
+ ):
+ pass
+ except Exception as e:
+ iteration.warning = f"Error extracting document references: {e}"
+ logger.error(iteration.warning, exc_info=True)
+ else:
+ iteration.warning = "No matching document references found"
+
+ elif iteration.query.name == ConversationCommand.SearchWeb:
+ previous_subqueries = {
+            subquery for prev_iter in previous_iterations if prev_iter.onlineContext for subquery in prev_iter.onlineContext.keys()
+ }
+ async for res in search_online(
+ **iteration.query.args,
+ conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SearchWeb),
+ location=location,
+ user=user,
+ send_status_func=status_collector,
+ custom_filters=[],
+ max_online_searches=max_online_searches,
+ max_webpages_to_read=0,
+ query_images=query_images,
+ previous_subqueries=previous_subqueries,
+ agent=agent,
+ tracer=tracer,
+ ):
+ # Status messages are collected by status_collector, skip ChatEvent.STATUS here
+ if is_none_or_empty(res):
+ iteration.warning = (
+ "Detected previously run online search queries. Skipping iteration. Try something different."
+ )
+ elif not (isinstance(res, dict) and ChatEvent.STATUS in res):
+ result.online_results = res
+ iteration.onlineContext = result.online_results
+
+ elif iteration.query.name == ConversationCommand.ReadWebpage:
+ async for res in read_webpages_content(
+ **iteration.query.args,
+ user=user,
+ send_status_func=status_collector,
+ agent=agent,
+ tracer=tracer,
+ ):
+ # Status messages are collected by status_collector, skip ChatEvent.STATUS here
+ if not (isinstance(res, dict) and ChatEvent.STATUS in res):
+ direct_web_pages: Dict[str, Dict] = res
+ for web_query in direct_web_pages:
+ if result.online_results.get(web_query):
+ result.online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
+ else:
+ result.online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
+ iteration.onlineContext = result.online_results
+
+ elif iteration.query.name == ConversationCommand.PythonCoder:
+ async for res in run_code(
+ **iteration.query.args,
+ conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.PythonCoder),
+ context="",
+ location_data=location,
+ user=user,
+ send_status_func=status_collector,
+ query_images=query_images,
+ query_files=query_files,
+ agent=agent,
+ tracer=tracer,
+ ):
+ # Status messages are collected by status_collector, skip ChatEvent.STATUS here
+ if not (isinstance(res, dict) and ChatEvent.STATUS in res):
+ result.code_results = res
+ iteration.codeContext = result.code_results
+ if iteration.codeContext:
+ async for _ in status_collector(f"**Ran code snippets**: {len(iteration.codeContext)}"):
+ pass
+
+ elif iteration.query.name == ConversationCommand.OperateComputer:
+ async for res in operate_environment(
+ **iteration.query.args,
+ user=user,
+ conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
+ location_data=location,
+ previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
+ send_status_func=status_collector,
+ query_images=query_images,
+ agent=agent,
+ query_files=query_files,
+ cancellation_event=cancellation_event,
+ interrupt_queue=interrupt_queue,
+ tracer=tracer,
+ ):
+ # Status messages are collected by status_collector, skip ChatEvent.STATUS here
+ if isinstance(res, OperatorRun):
+ result.operator_results = res
+ iteration.operatorContext = result.operator_results
+ if res.webpages:
+ if not result.online_results.get(iteration.query):
+ result.online_results[iteration.query] = {"webpages": res.webpages}
+ elif not result.online_results[iteration.query].get("webpages"):
+ result.online_results[iteration.query]["webpages"] = res.webpages
+ else:
+ result.online_results[iteration.query]["webpages"] += res.webpages
+ iteration.onlineContext = result.online_results
+
+ elif iteration.query.name == ConversationCommand.ViewFile:
+ async for res in view_file_content(
+ **iteration.query.args,
+ user=user,
+ ):
+ # Status messages are collected by status_collector, skip ChatEvent.STATUS here
+ if not (isinstance(res, dict) and ChatEvent.STATUS in res):
+ if iteration.context is None:
+ iteration.context = []
+ result.document_results = res
+ iteration.context += result.document_results
+ async for _ in status_collector(f"**Viewed file**: {iteration.query.args['path']}"):
+ pass
+
+ elif iteration.query.name == ConversationCommand.ListFiles:
+ async for res in list_files(
+ **iteration.query.args,
+ user=user,
+ ):
+ # Status messages are collected by status_collector, skip ChatEvent.STATUS here
+ if not (isinstance(res, dict) and ChatEvent.STATUS in res):
+ if iteration.context is None:
+ iteration.context = []
+ result.document_results = [res]
+ iteration.context += result.document_results
+ if result.document_results:
+ async for _ in status_collector(result.document_results[-1].get("query", "Listed files")):
+ pass
+
+ elif iteration.query.name == ConversationCommand.RegexSearchFiles:
+ async for res in grep_files(
+ **iteration.query.args,
+ user=user,
+ ):
+ # Status messages are collected by status_collector, skip ChatEvent.STATUS here
+ if not (isinstance(res, dict) and ChatEvent.STATUS in res):
+ if iteration.context is None:
+ iteration.context = []
+ result.document_results = [res]
+ iteration.context += result.document_results
+ if result.document_results:
+ async for _ in status_collector(result.document_results[-1].get("query", "Searched files")):
+ pass
+
+ elif "/" in iteration.query.name:
+ server_name, tool_name = iteration.query.name.split("/", 1)
+ mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
+ if not mcp_client:
+ raise ValueError(f"Could not find MCP server with name {server_name}")
+
+ result.mcp_results = await mcp_client.run_tool(tool_name, iteration.query.args)
+ if iteration.context is None:
+ iteration.context = []
+ iteration.context += result.mcp_results
+ async for _ in status_collector(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"):
+ pass
+
+ else:
+ result.should_terminate = True
+
+ except Exception as e:
+ tool_name = iteration.query.name if iteration.query else "unknown"
+ iteration.warning = f"Error executing {tool_name}: {e}"
+ logger.error(iteration.warning, exc_info=True)
+
+ return result
+
+
async def apick_next_tool(
query: str,
conversation_history: List[ChatMessageModel],
@@ -331,316 +586,83 @@ async def research(
iterations_to_process.append(result)
yield result
- # Process all tool calls from this planning step
- for this_iteration in iterations_to_process:
- online_results: Dict = dict()
- code_results: Dict = dict()
- document_results: List[Dict[str, str]] = []
- operator_results: OperatorRun = None
- mcp_results: List = []
-
- # Skip running iteration if warning present in iteration
- if this_iteration.warning:
- logger.warning(f"Research mode: {this_iteration.warning}.")
-
- # Terminate research if selected text tool or query, tool not set for next iteration
- elif (
- not this_iteration.query
- or isinstance(this_iteration.query, str)
- or this_iteration.query.name == ConversationCommand.Text
- ):
- current_iteration = MAX_ITERATIONS
-
- elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles:
- this_iteration.context = []
- document_results = []
- previous_inferred_queries = {
- c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
- }
- async for result in search_documents(
- **this_iteration.query.args,
- n=max_document_searches,
- d=None,
+ if iterations_to_process:
+ # Create tasks for parallel execution
+ tasks = [
+ execute_tool(
+ iteration=iteration,
user=user,
- chat_history=construct_tool_chat_history(
- previous_iterations, ConversationCommand.SemanticSearchFiles
- ),
conversation_id=conversation_id,
- conversation_commands=[ConversationCommand.Notes],
- location_data=location,
- send_status_func=send_status_func,
+ previous_iterations=previous_iterations,
+ location=location,
query_images=query_images,
query_files=query_files,
- previous_inferred_queries=previous_inferred_queries,
+ max_document_searches=max_document_searches,
+ max_online_searches=max_online_searches,
+ mcp_clients=mcp_clients,
+ cancellation_event=cancellation_event,
+ interrupt_queue=interrupt_queue,
agent=agent,
tracer=tracer,
+ )
+ for iteration in iterations_to_process
+ ]
+
+ # Execute all tools in parallel
+ tool_results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ # Process results and yield status messages
+ for this_iteration, tool_result in zip(iterations_to_process, tool_results):
+ # Handle exceptions from asyncio.gather
+ if isinstance(tool_result, Exception):
+ this_iteration.warning = f"Error executing tool: {tool_result}"
+                        logger.error(this_iteration.warning, exc_info=tool_result)
+ tool_result = ToolExecutionResult()
+
+ # Check for termination
+ if tool_result.should_terminate:
+ current_iteration = MAX_ITERATIONS
+
+ # Yield all collected status messages through the real send_status_func
+ for status_msg in tool_result.status_messages:
+ if send_status_func:
+ async for status_event in send_status_func(status_msg):
+ yield status_event
+
+ current_iteration += 1
+
+ # Build summarized results
+ if (
+ tool_result.document_results
+ or tool_result.online_results
+ or tool_result.code_results
+ or tool_result.operator_results
+ or tool_result.mcp_results
+ or this_iteration.warning
):
- if isinstance(result, dict) and ChatEvent.STATUS in result:
- yield result[ChatEvent.STATUS]
- elif isinstance(result, tuple):
- document_results = result[0]
- this_iteration.context += document_results
+ results_data = f"\n"
+ if tool_result.document_results:
+ results_data += f"\n\n{yaml.dump(tool_result.document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+ if tool_result.online_results:
+ results_data += f"\n\n{yaml.dump(tool_result.online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+ if tool_result.code_results:
+ results_data += f"\n\n{yaml.dump(truncate_code_context(tool_result.code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+ if tool_result.operator_results:
+ results_data += f"\n\n{tool_result.operator_results.response}\n"
+ if tool_result.mcp_results:
+ results_data += f"\n\n{yaml.dump(tool_result.mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+ if this_iteration.warning:
+ results_data += f"\n\n{this_iteration.warning}\n"
+ results_data += f"\n"
- if not is_none_or_empty(document_results):
- try:
- distinct_files = {d["file"] for d in document_results}
- distinct_headings = set(
- [d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]
- )
- # Strip only leading # from headings
- headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
- async for result in send_status_func(
- f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
- ):
- yield result
- except Exception as e:
- this_iteration.warning = f"Error extracting document references: {e}"
- logger.error(this_iteration.warning, exc_info=True)
- else:
- this_iteration.warning = "No matching document references found"
+ this_iteration.summarizedResult = results_data
- elif this_iteration.query.name == ConversationCommand.SearchWeb:
- previous_subqueries = {
- subquery
- for iteration in previous_iterations
- if iteration.onlineContext
- for subquery in iteration.onlineContext.keys()
- }
- try:
- async for result in search_online(
- **this_iteration.query.args,
- conversation_history=construct_tool_chat_history(
- previous_iterations, ConversationCommand.SearchWeb
- ),
- location=location,
- user=user,
- send_status_func=send_status_func,
- custom_filters=[],
- max_online_searches=max_online_searches,
- max_webpages_to_read=0,
- query_images=query_images,
- previous_subqueries=previous_subqueries,
- agent=agent,
- tracer=tracer,
- ):
- if isinstance(result, dict) and ChatEvent.STATUS in result:
- yield result[ChatEvent.STATUS]
- elif is_none_or_empty(result):
- this_iteration.warning = "Detected previously run online search queries. Skipping iteration. Try something different."
- else:
- online_results: Dict[str, Dict] = result # type: ignore
- this_iteration.onlineContext = online_results
- except Exception as e:
- this_iteration.warning = f"Error searching online: {e}"
- logger.error(this_iteration.warning, exc_info=True)
-
- elif this_iteration.query.name == ConversationCommand.ReadWebpage:
- try:
- async for result in read_webpages_content(
- **this_iteration.query.args,
- user=user,
- send_status_func=send_status_func,
- agent=agent,
- tracer=tracer,
- ):
- if isinstance(result, dict) and ChatEvent.STATUS in result:
- yield result[ChatEvent.STATUS]
- else:
- direct_web_pages: Dict[str, Dict] = result # type: ignore
-
- webpages = []
- for web_query in direct_web_pages:
- if online_results.get(web_query):
- online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
- else:
- online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
-
- for webpage in direct_web_pages[web_query]["webpages"]:
- webpages.append(webpage["link"])
- this_iteration.onlineContext = online_results
- except Exception as e:
- this_iteration.warning = f"Error reading webpages: {e}"
- logger.error(this_iteration.warning, exc_info=True)
-
- elif this_iteration.query.name == ConversationCommand.PythonCoder:
- try:
- async for result in run_code(
- **this_iteration.query.args,
- conversation_history=construct_tool_chat_history(
- previous_iterations, ConversationCommand.PythonCoder
- ),
- context="",
- location_data=location,
- user=user,
- send_status_func=send_status_func,
- query_images=query_images,
- query_files=query_files,
- agent=agent,
- tracer=tracer,
- ):
- if isinstance(result, dict) and ChatEvent.STATUS in result:
- yield result[ChatEvent.STATUS]
- else:
- code_results: Dict[str, Dict] = result # type: ignore
- this_iteration.codeContext = code_results
- async for result in send_status_func(f"**Ran code snippets**: {len(this_iteration.codeContext)}"):
- yield result
- except (ValueError, TypeError) as e:
- this_iteration.warning = f"Error running code: {e}"
- logger.warning(this_iteration.warning, exc_info=True)
-
- elif this_iteration.query.name == ConversationCommand.OperateComputer:
- try:
- async for result in operate_environment(
- **this_iteration.query.args,
- user=user,
- conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
- location_data=location,
- previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
- send_status_func=send_status_func,
- query_images=query_images,
- agent=agent,
- query_files=query_files,
- cancellation_event=cancellation_event,
- interrupt_queue=interrupt_queue,
- tracer=tracer,
- ):
- if isinstance(result, dict) and ChatEvent.STATUS in result:
- yield result[ChatEvent.STATUS]
- elif isinstance(result, OperatorRun):
- operator_results = result
- this_iteration.operatorContext = operator_results
- # Add webpages visited while operating browser to references
- if result.webpages:
- if not online_results.get(this_iteration.query):
- online_results[this_iteration.query] = {"webpages": result.webpages}
- elif not online_results[this_iteration.query].get("webpages"):
- online_results[this_iteration.query]["webpages"] = result.webpages
- else:
- online_results[this_iteration.query]["webpages"] += result.webpages
- this_iteration.onlineContext = online_results
- except Exception as e:
- this_iteration.warning = f"Error operating browser: {e}"
- logger.error(this_iteration.warning, exc_info=True)
-
- elif this_iteration.query.name == ConversationCommand.ViewFile:
- try:
- async for result in view_file_content(
- **this_iteration.query.args,
- user=user,
- ):
- if isinstance(result, dict) and ChatEvent.STATUS in result:
- yield result[ChatEvent.STATUS]
- else:
- if this_iteration.context is None:
- this_iteration.context = []
- document_results: List[Dict[str, str]] = result # type: ignore
- this_iteration.context += document_results
- async for result in send_status_func(f"**Viewed file**: {this_iteration.query.args['path']}"):
- yield result
- except Exception as e:
- this_iteration.warning = f"Error viewing file: {e}"
- logger.error(this_iteration.warning, exc_info=True)
-
- elif this_iteration.query.name == ConversationCommand.ListFiles:
- try:
- async for result in list_files(
- **this_iteration.query.args,
- user=user,
- ):
- if isinstance(result, dict) and ChatEvent.STATUS in result:
- yield result[ChatEvent.STATUS]
- else:
- if this_iteration.context is None:
- this_iteration.context = []
- document_results: List[Dict[str, str]] = [result] # type: ignore
- this_iteration.context += document_results
- async for result in send_status_func(result["query"]):
- yield result
- except Exception as e:
- this_iteration.warning = f"Error listing files: {e}"
- logger.error(this_iteration.warning, exc_info=True)
-
- elif this_iteration.query.name == ConversationCommand.RegexSearchFiles:
- try:
- async for result in grep_files(
- **this_iteration.query.args,
- user=user,
- ):
- if isinstance(result, dict) and ChatEvent.STATUS in result:
- yield result[ChatEvent.STATUS]
- else:
- if this_iteration.context is None:
- this_iteration.context = []
- document_results: List[Dict[str, str]] = [result] # type: ignore
- this_iteration.context += document_results
- async for result in send_status_func(result["query"]):
- yield result
- except Exception as e:
- this_iteration.warning = f"Error searching with regex: {e}"
- logger.error(this_iteration.warning, exc_info=True)
-
- elif "/" in this_iteration.query.name:
- try:
- # Identify MCP client to use
- server_name, tool_name = this_iteration.query.name.split("/", 1)
- mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
- if not mcp_client:
- raise ValueError(f"Could not find MCP server with name {server_name}")
-
- # Invoke tool on the identified MCP server
- mcp_results = await mcp_client.run_tool(tool_name, this_iteration.query.args)
-
- # Record tool result in context
- if this_iteration.context is None:
- this_iteration.context = []
- this_iteration.context += mcp_results
- async for result in send_status_func(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"):
- yield result
- except Exception as e:
- this_iteration.warning = f"Error using MCP tool: {e}"
- logger.error(this_iteration.warning, exc_info=True)
-
- else:
- # No valid tools. This is our exit condition.
- current_iteration = MAX_ITERATIONS
-
- current_iteration += 1
-
- if (
- document_results
- or online_results
- or code_results
- or operator_results
- or mcp_results
- or this_iteration.warning
- ):
- results_data = f"\n"
- if document_results:
- results_data += f"\n\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
- if online_results:
- results_data += f"\n\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
- if code_results:
- results_data += f"\n\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
- if operator_results:
- results_data += (
- f"\n\n{operator_results.response}\n"
- )
- if mcp_results:
- results_data += f"\n\n{yaml.dump(mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
- if this_iteration.warning:
- results_data += f"\n\n{this_iteration.warning}\n"
- results_data += f"\n"
-
- # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
- this_iteration.summarizedResult = results_data
-
- this_iteration.summarizedResult = (
- this_iteration.summarizedResult
- or f"Failed to get results."
- )
- previous_iterations.append(this_iteration)
- yield this_iteration
+ this_iteration.summarizedResult = (
+ this_iteration.summarizedResult
+ or f"Failed to get results."
+ )
+ previous_iterations.append(this_iteration)
+ yield this_iteration
# Close MCP client connections
for mcp_client in mcp_clients: