class ToolExecutionResult:
    """Result of executing a single tool call.

    A plain mutable container that ``execute_tool`` fills in while a tool runs.
    Every instance gets its own fresh containers.
    """

    def __init__(self):
        # Status messages captured while the tool ran, in the order they were emitted.
        self.status_messages: List = []
        # Documents gathered by the semantic search, view, list and regex file tools.
        self.document_results: List[Dict[str, str]] = []
        # Web search / webpage results, keyed by query.
        self.online_results: Dict = {}
        # Output of the code execution tool.
        self.code_results: Dict = {}
        # Run object yielded by the computer operator tool; None when it did not run.
        self.operator_results: Optional[OperatorRun] = None
        # Results returned by an MCP tool invocation.
        self.mcp_results: List = []
        # Set when the iteration carries no runnable tool call.
        self.should_terminate: bool = False


async def execute_tool(
    iteration: ResearchIteration,
    user: KhojUser,
    conversation_id: str,
    previous_iterations: List[ResearchIteration],
    location: LocationData,
    query_images: List[str],
    query_files: str,
    max_document_searches: int,
    max_online_searches: int,
    mcp_clients: List[MCPClient],
    cancellation_event: Optional[asyncio.Event],
    interrupt_queue: Optional[asyncio.Queue],
    agent: Agent,
    tracer: dict,
) -> ToolExecutionResult:
    """Execute a single tool call and return results. Designed for parallel execution.

    Status messages are recorded on the returned result instead of being
    streamed, so the caller can replay them once the tool has finished.

    Side effects on ``iteration``: depending on the tool, ``context``,
    ``onlineContext``, ``codeContext`` and ``operatorContext`` are updated, and
    ``warning`` is set when the tool fails or finds nothing.

    Args:
        iteration: The planned tool call to run.
        previous_iterations: Earlier iterations, used to build the tool chat
            history and to collect queries that were already run.
        max_document_searches: Passed to the document search as ``n``.
        max_online_searches: Passed to the online search.
        mcp_clients: MCP clients searched by name for ``"<server>/<tool>"`` calls.
        cancellation_event: Forwarded to the computer operator tool.
        interrupt_queue: Forwarded to the computer operator tool.
        user, conversation_id, location, query_images, query_files, agent, tracer:
            Forwarded to the underlying tool functions.

    Returns:
        A ``ToolExecutionResult``. ``should_terminate`` is True when the
        iteration has no query, a plain string query, the text tool, or an
        unrecognised tool name.

    Raises:
        Nothing from the tools themselves: exceptions raised while a tool runs
        are logged and stored in ``iteration.warning``.
    """
    result = ToolExecutionResult()

    # Skip if warning present
    if iteration.warning:
        logger.warning(f"Research mode: {iteration.warning}.")
        return result

    # Check for termination conditions
    if not iteration.query or isinstance(iteration.query, str) or iteration.query.name == ConversationCommand.Text:
        result.should_terminate = True
        return result

    def record_status(message: str) -> None:
        """Store a status message for the caller to emit later."""
        result.status_messages.append(message)

    async def status_collector(message: str):
        """Async generator that collects status messages instead of streaming them."""
        record_status(message)
        yield message  # Yield to satisfy async generator protocol expected by tool functions

    try:
        if iteration.query.name == ConversationCommand.SemanticSearchFiles:
            iteration.context = []
            # Context entries come from several tools (file listing, regex search, MCP),
            # so only entries that actually carry a "query" key are considered.
            previous_inferred_queries = {
                entry["query"]
                for past_iteration in previous_iterations
                if past_iteration.context
                for entry in past_iteration.context
                if isinstance(entry, dict) and "query" in entry
            }
            async for res in search_documents(
                **iteration.query.args,
                n=max_document_searches,
                d=None,
                user=user,
                chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles),
                conversation_id=conversation_id,
                conversation_commands=[ConversationCommand.Notes],
                location_data=location,
                send_status_func=status_collector,
                query_images=query_images,
                query_files=query_files,
                previous_inferred_queries=previous_inferred_queries,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if isinstance(res, tuple):
                    result.document_results = res[0]
                    iteration.context += result.document_results

            if not is_none_or_empty(result.document_results):
                try:
                    distinct_files = {d["file"] for d in result.document_results}
                    distinct_headings = {
                        d["compiled"].split("\n")[0] for d in result.document_results if "compiled" in d
                    }
                    headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
                    record_status(
                        f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
                    )
                except Exception as e:
                    iteration.warning = f"Error extracting document references: {e}"
                    logger.error(iteration.warning, exc_info=True)
            else:
                iteration.warning = "No matching document references found"

        elif iteration.query.name == ConversationCommand.SearchWeb:
            previous_subqueries = {
                subquery
                for past_iteration in previous_iterations
                if past_iteration.onlineContext
                for subquery in past_iteration.onlineContext
            }
            async for res in search_online(
                **iteration.query.args,
                conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SearchWeb),
                location=location,
                user=user,
                send_status_func=status_collector,
                custom_filters=[],
                max_online_searches=max_online_searches,
                max_webpages_to_read=0,
                query_images=query_images,
                previous_subqueries=previous_subqueries,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if is_none_or_empty(res):
                    iteration.warning = (
                        "Detected previously run online search queries. Skipping iteration. Try something different."
                    )
                elif not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    result.online_results = res
                    iteration.onlineContext = result.online_results

        elif iteration.query.name == ConversationCommand.ReadWebpage:
            async for res in read_webpages_content(
                **iteration.query.args,
                user=user,
                send_status_func=status_collector,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    direct_web_pages: Dict[str, Dict] = res
                    for web_query in direct_web_pages:
                        if result.online_results.get(web_query):
                            result.online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
                        else:
                            result.online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
                    iteration.onlineContext = result.online_results

        elif iteration.query.name == ConversationCommand.PythonCoder:
            async for res in run_code(
                **iteration.query.args,
                conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.PythonCoder),
                context="",
                location_data=location,
                user=user,
                send_status_func=status_collector,
                query_images=query_images,
                query_files=query_files,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    result.code_results = res
                    iteration.codeContext = result.code_results
            if iteration.codeContext:
                record_status(f"**Ran code snippets**: {len(iteration.codeContext)}")

        elif iteration.query.name == ConversationCommand.OperateComputer:
            async for res in operate_environment(
                **iteration.query.args,
                user=user,
                conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
                location_data=location,
                previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
                send_status_func=status_collector,
                query_images=query_images,
                agent=agent,
                query_files=query_files,
                cancellation_event=cancellation_event,
                interrupt_queue=interrupt_queue,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if isinstance(res, OperatorRun):
                    result.operator_results = res
                    iteration.operatorContext = result.operator_results
                    # Add webpages visited while operating the environment to the online references.
                    # NOTE(review): the key here is the tool call object itself, while the web tools
                    # key online results by query string - confirm this is intended.
                    if res.webpages:
                        if not result.online_results.get(iteration.query):
                            result.online_results[iteration.query] = {"webpages": res.webpages}
                        elif not result.online_results[iteration.query].get("webpages"):
                            result.online_results[iteration.query]["webpages"] = res.webpages
                        else:
                            result.online_results[iteration.query]["webpages"] += res.webpages
                        iteration.onlineContext = result.online_results

        elif iteration.query.name == ConversationCommand.ViewFile:
            async for res in view_file_content(
                **iteration.query.args,
                user=user,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    if iteration.context is None:
                        iteration.context = []
                    result.document_results = res
                    iteration.context += result.document_results
            record_status(f"**Viewed file**: {iteration.query.args['path']}")

        elif iteration.query.name == ConversationCommand.ListFiles:
            async for res in list_files(
                **iteration.query.args,
                user=user,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    if iteration.context is None:
                        iteration.context = []
                    result.document_results = [res]
                    iteration.context += result.document_results
            if result.document_results:
                record_status(result.document_results[-1].get("query", "Listed files"))

        elif iteration.query.name == ConversationCommand.RegexSearchFiles:
            async for res in grep_files(
                **iteration.query.args,
                user=user,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    if iteration.context is None:
                        iteration.context = []
                    result.document_results = [res]
                    iteration.context += result.document_results
            if result.document_results:
                record_status(result.document_results[-1].get("query", "Searched files"))

        elif "/" in iteration.query.name:
            # Tool names of the form "<server>/<tool>" are routed to the matching MCP client.
            server_name, tool_name = iteration.query.name.split("/", 1)
            mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
            if not mcp_client:
                raise ValueError(f"Could not find MCP server with name {server_name}")

            result.mcp_results = await mcp_client.run_tool(tool_name, iteration.query.args)
            if iteration.context is None:
                iteration.context = []
            iteration.context += result.mcp_results
            record_status(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}")

        else:
            # No valid tool selected. This is the exit condition.
            result.should_terminate = True

    except Exception as e:
        # Record the failure on the iteration instead of raising, so one failing
        # tool call does not propagate out of this coroutine.
        failed_tool = getattr(iteration.query, "name", None) or "unknown"
        iteration.warning = f"Error executing {failed_tool}: {e}"
        logger.error(iteration.warning, exc_info=True)

    return result
this_iteration.query - or isinstance(this_iteration.query, str) - or this_iteration.query.name == ConversationCommand.Text - ): - current_iteration = MAX_ITERATIONS - - elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles: - this_iteration.context = [] - document_results = [] - previous_inferred_queries = { - c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context - } - async for result in search_documents( - **this_iteration.query.args, - n=max_document_searches, - d=None, + if iterations_to_process: + # Create tasks for parallel execution + tasks = [ + execute_tool( + iteration=iteration, user=user, - chat_history=construct_tool_chat_history( - previous_iterations, ConversationCommand.SemanticSearchFiles - ), conversation_id=conversation_id, - conversation_commands=[ConversationCommand.Notes], - location_data=location, - send_status_func=send_status_func, + previous_iterations=previous_iterations, + location=location, query_images=query_images, query_files=query_files, - previous_inferred_queries=previous_inferred_queries, + max_document_searches=max_document_searches, + max_online_searches=max_online_searches, + mcp_clients=mcp_clients, + cancellation_event=cancellation_event, + interrupt_queue=interrupt_queue, agent=agent, tracer=tracer, + ) + for iteration in iterations_to_process + ] + + # Execute all tools in parallel + tool_results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process results and yield status messages + for this_iteration, tool_result in zip(iterations_to_process, tool_results): + # Handle exceptions from asyncio.gather + if isinstance(tool_result, Exception): + this_iteration.warning = f"Error executing tool: {tool_result}" + logger.error(this_iteration.warning, exc_info=True) + tool_result = ToolExecutionResult() + + # Check for termination + if tool_result.should_terminate: + current_iteration = MAX_ITERATIONS + + # Yield all collected status messages through the 
real send_status_func + for status_msg in tool_result.status_messages: + if send_status_func: + async for status_event in send_status_func(status_msg): + yield status_event + + current_iteration += 1 + + # Build summarized results + if ( + tool_result.document_results + or tool_result.online_results + or tool_result.code_results + or tool_result.operator_results + or tool_result.mcp_results + or this_iteration.warning ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - elif isinstance(result, tuple): - document_results = result[0] - this_iteration.context += document_results + results_data = f"\n" + if tool_result.document_results: + results_data += f"\n\n{yaml.dump(tool_result.document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + if tool_result.online_results: + results_data += f"\n\n{yaml.dump(tool_result.online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + if tool_result.code_results: + results_data += f"\n\n{yaml.dump(truncate_code_context(tool_result.code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + if tool_result.operator_results: + results_data += f"\n\n{tool_result.operator_results.response}\n" + if tool_result.mcp_results: + results_data += f"\n\n{yaml.dump(tool_result.mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + if this_iteration.warning: + results_data += f"\n\n{this_iteration.warning}\n" + results_data += f"\n" - if not is_none_or_empty(document_results): - try: - distinct_files = {d["file"] for d in document_results} - distinct_headings = set( - [d["compiled"].split("\n")[0] for d in document_results if "compiled" in d] - ) - # Strip only leading # from headings - headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "") - async for result in send_status_func( - f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}" - 
): - yield result - except Exception as e: - this_iteration.warning = f"Error extracting document references: {e}" - logger.error(this_iteration.warning, exc_info=True) - else: - this_iteration.warning = "No matching document references found" + this_iteration.summarizedResult = results_data - elif this_iteration.query.name == ConversationCommand.SearchWeb: - previous_subqueries = { - subquery - for iteration in previous_iterations - if iteration.onlineContext - for subquery in iteration.onlineContext.keys() - } - try: - async for result in search_online( - **this_iteration.query.args, - conversation_history=construct_tool_chat_history( - previous_iterations, ConversationCommand.SearchWeb - ), - location=location, - user=user, - send_status_func=send_status_func, - custom_filters=[], - max_online_searches=max_online_searches, - max_webpages_to_read=0, - query_images=query_images, - previous_subqueries=previous_subqueries, - agent=agent, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - elif is_none_or_empty(result): - this_iteration.warning = "Detected previously run online search queries. Skipping iteration. Try something different." 
- else: - online_results: Dict[str, Dict] = result # type: ignore - this_iteration.onlineContext = online_results - except Exception as e: - this_iteration.warning = f"Error searching online: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.ReadWebpage: - try: - async for result in read_webpages_content( - **this_iteration.query.args, - user=user, - send_status_func=send_status_func, - agent=agent, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - direct_web_pages: Dict[str, Dict] = result # type: ignore - - webpages = [] - for web_query in direct_web_pages: - if online_results.get(web_query): - online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"] - else: - online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]} - - for webpage in direct_web_pages[web_query]["webpages"]: - webpages.append(webpage["link"]) - this_iteration.onlineContext = online_results - except Exception as e: - this_iteration.warning = f"Error reading webpages: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.PythonCoder: - try: - async for result in run_code( - **this_iteration.query.args, - conversation_history=construct_tool_chat_history( - previous_iterations, ConversationCommand.PythonCoder - ), - context="", - location_data=location, - user=user, - send_status_func=send_status_func, - query_images=query_images, - query_files=query_files, - agent=agent, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - code_results: Dict[str, Dict] = result # type: ignore - this_iteration.codeContext = code_results - async for result in send_status_func(f"**Ran code snippets**: {len(this_iteration.codeContext)}"): - yield result - except (ValueError, TypeError) as e: - 
this_iteration.warning = f"Error running code: {e}" - logger.warning(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.OperateComputer: - try: - async for result in operate_environment( - **this_iteration.query.args, - user=user, - conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator), - location_data=location, - previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None, - send_status_func=send_status_func, - query_images=query_images, - agent=agent, - query_files=query_files, - cancellation_event=cancellation_event, - interrupt_queue=interrupt_queue, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - elif isinstance(result, OperatorRun): - operator_results = result - this_iteration.operatorContext = operator_results - # Add webpages visited while operating browser to references - if result.webpages: - if not online_results.get(this_iteration.query): - online_results[this_iteration.query] = {"webpages": result.webpages} - elif not online_results[this_iteration.query].get("webpages"): - online_results[this_iteration.query]["webpages"] = result.webpages - else: - online_results[this_iteration.query]["webpages"] += result.webpages - this_iteration.onlineContext = online_results - except Exception as e: - this_iteration.warning = f"Error operating browser: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.ViewFile: - try: - async for result in view_file_content( - **this_iteration.query.args, - user=user, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - if this_iteration.context is None: - this_iteration.context = [] - document_results: List[Dict[str, str]] = result # type: ignore - this_iteration.context += document_results - async for result in 
send_status_func(f"**Viewed file**: {this_iteration.query.args['path']}"): - yield result - except Exception as e: - this_iteration.warning = f"Error viewing file: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.ListFiles: - try: - async for result in list_files( - **this_iteration.query.args, - user=user, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - if this_iteration.context is None: - this_iteration.context = [] - document_results: List[Dict[str, str]] = [result] # type: ignore - this_iteration.context += document_results - async for result in send_status_func(result["query"]): - yield result - except Exception as e: - this_iteration.warning = f"Error listing files: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.RegexSearchFiles: - try: - async for result in grep_files( - **this_iteration.query.args, - user=user, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - if this_iteration.context is None: - this_iteration.context = [] - document_results: List[Dict[str, str]] = [result] # type: ignore - this_iteration.context += document_results - async for result in send_status_func(result["query"]): - yield result - except Exception as e: - this_iteration.warning = f"Error searching with regex: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif "/" in this_iteration.query.name: - try: - # Identify MCP client to use - server_name, tool_name = this_iteration.query.name.split("/", 1) - mcp_client = next((client for client in mcp_clients if client.name == server_name), None) - if not mcp_client: - raise ValueError(f"Could not find MCP server with name {server_name}") - - # Invoke tool on the identified MCP server - mcp_results = await mcp_client.run_tool(tool_name, this_iteration.query.args) - - # 
Record tool result in context - if this_iteration.context is None: - this_iteration.context = [] - this_iteration.context += mcp_results - async for result in send_status_func(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"): - yield result - except Exception as e: - this_iteration.warning = f"Error using MCP tool: {e}" - logger.error(this_iteration.warning, exc_info=True) - - else: - # No valid tools. This is our exit condition. - current_iteration = MAX_ITERATIONS - - current_iteration += 1 - - if ( - document_results - or online_results - or code_results - or operator_results - or mcp_results - or this_iteration.warning - ): - results_data = f"\n" - if document_results: - results_data += f"\n\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" - if online_results: - results_data += f"\n\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" - if code_results: - results_data += f"\n\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" - if operator_results: - results_data += ( - f"\n\n{operator_results.response}\n" - ) - if mcp_results: - results_data += f"\n\n{yaml.dump(mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" - if this_iteration.warning: - results_data += f"\n\n{this_iteration.warning}\n" - results_data += f"\n" - - # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) - this_iteration.summarizedResult = results_data - - this_iteration.summarizedResult = ( - this_iteration.summarizedResult - or f"Failed to get results." - ) - previous_iterations.append(this_iteration) - yield this_iteration + this_iteration.summarizedResult = ( + this_iteration.summarizedResult + or f"Failed to get results." + ) + previous_iterations.append(this_iteration) + yield this_iteration # Close MCP client connections for mcp_client in mcp_clients: