From 054ed79fdf8cf3a5549678b1f26cc9fccd4d7fd9 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 13 Dec 2025 21:37:49 -0800 Subject: [PATCH] Allow LLMs to make parallel tool call requests Why -- - The models are now smart enough to usually understand which tools to call in parallel and when. - The LLM can request more work for each call to it, which is usually the slowest step. This speeds up work by reearch agent. Even though each tool is still executed in sequence (for now). --- .../processor/conversation/anthropic/utils.py | 3 +- src/khoj/processor/conversation/utils.py | 78 ++- src/khoj/routers/research.py | 630 +++++++++--------- 3 files changed, 383 insertions(+), 328 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index cf0fe0dc..522f7fad 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -85,8 +85,7 @@ def anthropic_completion_with_backoff( # Cache tool definitions last_tool = model_kwargs["tools"][-1] last_tool["cache_control"] = {"type": "ephemeral"} - # Disable parallel tool call until we add support for it - model_kwargs["tool_choice"] = {"type": "auto", "disable_parallel_tool_use": True} + model_kwargs["tool_choice"] = {"type": "auto"} elif response_schema: tool = create_tool_definition(response_schema) model_kwargs["tools"] = [ diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 5672134f..a2069243 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -245,8 +245,37 @@ def construct_iteration_history( if query_message_content: iteration_history.append(ChatMessageModel(by="you", message=query_message_content)) + # Group iterations: parallel tool calls share the same raw_response (only first has it) + # We need to group them so one assistant message has all tool_use blocks and + # one user message has all tool_results + current_group_raw_response = None + current_group_tool_results = [] + + def flush_group(): + """Output the current group as assistant message + user message with tool results""" + nonlocal current_group_raw_response, current_group_tool_results + if current_group_raw_response and current_group_tool_results: + iteration_history.append( + ChatMessageModel( + by="khoj", + message=current_group_raw_response, + intent=Intent(type="tool_call", query=query), + ) + ) + iteration_history.append( + ChatMessageModel( + by="you", + intent=Intent(type="tool_result"), + message=current_group_tool_results, + ) + ) + current_group_raw_response = None + current_group_tool_results = [] + for iteration in previous_iterations: if not iteration.query or isinstance(iteration.query, str): + # Flush any pending group before adding non-tool message + flush_group() iteration_history.append( ChatMessageModel( by="you", @@ -256,25 +285,36 @@ def construct_iteration_history( ) ) continue - iteration_history += [ - ChatMessageModel( - by="khoj", - message=iteration.raw_response or [iteration.query.__dict__], - intent=Intent(type="tool_call", query=query), - ), - ChatMessageModel( - by="you", - intent=Intent(type="tool_result"), - message=[ - { - "type": "tool_result", - "id": iteration.query.id, - "name": iteration.query.name, - "content": iteration.summarizedResult, - } - ], - ), - ] + + # If this iteration has raw_response, it starts a new group of parallel tool calls + if iteration.raw_response: + # Flush previous group if exists + flush_group() + current_group_raw_response = iteration.raw_response + + # If no raw_response and no current group, create a fallback single-tool response + elif not current_group_raw_response: + current_group_raw_response = [ + { + "type": "tool_use", + "id": iteration.query.id, + "name": iteration.query.name, + "input": iteration.query.args, + } + ] + + # Add tool result to current group + current_group_tool_results.append( + { + "type": "tool_result", + "id": iteration.query.id, + "name": iteration.query.name, + "content": iteration.summarizedResult, + } + ) + + # Flush any remaining group + flush_group() return iteration_history diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index fe90f41e..d9ed7901 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -208,32 +208,41 @@ async def apick_next_tool( return try: - # Try parse the response as function call response to infer next tool to use. - # TODO: Handle multiple tool calls. + # Try parse the response as function call response to infer next tools to use. response_text = response.text - parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0] + parsed_responses = [ToolCall(**item) for item in load_complex_json(response_text)] except Exception: # Otherwise assume the model has decided to end the research run and respond to the user. - parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None) + parsed_responses = [ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)] - # If we have a valid response, extract the tool and query. - warning = None - logger.info(f"Response for determining relevant tools: {parsed_response.name}({parsed_response.args})") - - # Detect selection of previously used query, tool combination. + # Detect selection of previously used query, tool combinations. previous_tool_query_combinations = { (i.query.name, dict_to_tuple(i.query.args)) for i in previous_iterations if i.warning is None and isinstance(i.query, ToolCall) } - if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations: - warning = f"Repeated tool, query combination detected. You've already called {parsed_response.name} with args: {parsed_response.args}. Try something different." - # Only send client status updates if we'll execute this iteration and model has thoughts to share. - elif send_status_func and not is_none_or_empty(response.thought): + + # Send status update with model's thoughts if available + if send_status_func and not is_none_or_empty(response.thought): async for event in send_status_func(response.thought): yield {ChatEvent.STATUS: event} - yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content) + # Yield a ResearchIteration for each tool call to enable parallel execution + for idx, parsed_response in enumerate(parsed_responses): + warning = None + logger.info( + f"Response for determining relevant tools ({idx + 1}/{len(parsed_responses)}): {parsed_response.name}({parsed_response.args})" + ) + + if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations: + warning = f"Repeated tool, query combination detected. You've already called {parsed_response.name} with args: {parsed_response.args}. Try something different." + + # Include raw_response only for the first tool call to avoid duplication in history + yield ResearchIteration( + query=parsed_response, + warning=warning, + raw_response=response.raw_content if idx == 0 else None, + ) async def research( @@ -296,13 +305,8 @@ async def research( async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"): yield result - online_results: Dict = dict() - code_results: Dict = dict() - document_results: List[Dict[str, str]] = [] - operator_results: OperatorRun = None - mcp_results: List = [] - this_iteration = ResearchIteration(query=query) - + # Collect all tool calls from apick_next_tool + iterations_to_process: List[ResearchIteration] = [] async for result in apick_next_tool( query, research_conversation_history, @@ -324,307 +328,319 @@ async def research( if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] elif isinstance(result, ResearchIteration): - this_iteration = result - yield this_iteration + iterations_to_process.append(result) + yield result - # Skip running iteration if warning present in iteration - if this_iteration.warning: - logger.warning(f"Research mode: {this_iteration.warning}.") + # Process all tool calls from this planning step + for this_iteration in iterations_to_process: + online_results: Dict = dict() + code_results: Dict = dict() + document_results: List[Dict[str, str]] = [] + operator_results: OperatorRun = None + mcp_results: List = [] - # Terminate research if selected text tool or query, tool not set for next iteration - elif ( - not this_iteration.query - or isinstance(this_iteration.query, str) - or this_iteration.query.name == ConversationCommand.Text - ): - current_iteration = MAX_ITERATIONS + # Skip running iteration if warning present in iteration + if this_iteration.warning: + logger.warning(f"Research mode: {this_iteration.warning}.") - elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles: - this_iteration.context = [] - document_results = [] - previous_inferred_queries = { - c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context - } - async for result in search_documents( - **this_iteration.query.args, - n=max_document_searches, - d=None, - user=user, - chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles), - conversation_id=conversation_id, - conversation_commands=[ConversationCommand.Notes], - location_data=location, - send_status_func=send_status_func, - query_images=query_images, - query_files=query_files, - previous_inferred_queries=previous_inferred_queries, - agent=agent, - tracer=tracer, + # Terminate research if selected text tool or query, tool not set for next iteration + elif ( + not this_iteration.query + or isinstance(this_iteration.query, str) + or this_iteration.query.name == ConversationCommand.Text ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - elif isinstance(result, tuple): - document_results = result[0] - this_iteration.context += document_results + current_iteration = MAX_ITERATIONS - if not is_none_or_empty(document_results): + elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles: + this_iteration.context = [] + document_results = [] + previous_inferred_queries = { + c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context + } + async for result in search_documents( + **this_iteration.query.args, + n=max_document_searches, + d=None, + user=user, + chat_history=construct_tool_chat_history( + previous_iterations, ConversationCommand.SemanticSearchFiles + ), + conversation_id=conversation_id, + conversation_commands=[ConversationCommand.Notes], + location_data=location, + send_status_func=send_status_func, + query_images=query_images, + query_files=query_files, + previous_inferred_queries=previous_inferred_queries, + agent=agent, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + elif isinstance(result, tuple): + document_results = result[0] + this_iteration.context += document_results + + if not is_none_or_empty(document_results): + try: + distinct_files = {d["file"] for d in document_results} + distinct_headings = set( + [d["compiled"].split("\n")[0] for d in document_results if "compiled" in d] + ) + # Strip only leading # from headings + headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "") + async for result in send_status_func( + f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}" + ): + yield result + except Exception as e: + this_iteration.warning = f"Error extracting document references: {e}" + logger.error(this_iteration.warning, exc_info=True) + else: + this_iteration.warning = "No matching document references found" + + elif this_iteration.query.name == ConversationCommand.SearchWeb: + previous_subqueries = { + subquery + for iteration in previous_iterations + if iteration.onlineContext + for subquery in iteration.onlineContext.keys() + } try: - distinct_files = {d["file"] for d in document_results} - distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]) - # Strip only leading # from headings - headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "") - async for result in send_status_func( - f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}" + async for result in search_online( + **this_iteration.query.args, + conversation_history=construct_tool_chat_history( + previous_iterations, ConversationCommand.SearchWeb + ), + location=location, + user=user, + send_status_func=send_status_func, + custom_filters=[], + max_online_searches=max_online_searches, + max_webpages_to_read=0, + query_images=query_images, + previous_subqueries=previous_subqueries, + agent=agent, + tracer=tracer, ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + elif is_none_or_empty(result): + this_iteration.warning = "Detected previously run online search queries. Skipping iteration. Try something different." + else: + online_results: Dict[str, Dict] = result # type: ignore + this_iteration.onlineContext = online_results + except Exception as e: + this_iteration.warning = f"Error searching online: {e}" + logger.error(this_iteration.warning, exc_info=True) + + elif this_iteration.query.name == ConversationCommand.ReadWebpage: + try: + async for result in read_webpages_content( + **this_iteration.query.args, + user=user, + send_status_func=send_status_func, + agent=agent, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + direct_web_pages: Dict[str, Dict] = result # type: ignore + + webpages = [] + for web_query in direct_web_pages: + if online_results.get(web_query): + online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"] + else: + online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]} + + for webpage in direct_web_pages[web_query]["webpages"]: + webpages.append(webpage["link"]) + this_iteration.onlineContext = online_results + except Exception as e: + this_iteration.warning = f"Error reading webpages: {e}" + logger.error(this_iteration.warning, exc_info=True) + + elif this_iteration.query.name == ConversationCommand.PythonCoder: + try: + async for result in run_code( + **this_iteration.query.args, + conversation_history=construct_tool_chat_history( + previous_iterations, ConversationCommand.PythonCoder + ), + context="", + location_data=location, + user=user, + send_status_func=send_status_func, + query_images=query_images, + query_files=query_files, + agent=agent, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + code_results: Dict[str, Dict] = result # type: ignore + this_iteration.codeContext = code_results + async for result in send_status_func(f"**Ran code snippets**: {len(this_iteration.codeContext)}"): + yield result + except (ValueError, TypeError) as e: + this_iteration.warning = f"Error running code: {e}" + logger.warning(this_iteration.warning, exc_info=True) + + elif this_iteration.query.name == ConversationCommand.OperateComputer: + try: + async for result in operate_environment( + **this_iteration.query.args, + user=user, + conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator), + location_data=location, + previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None, + send_status_func=send_status_func, + query_images=query_images, + agent=agent, + query_files=query_files, + cancellation_event=cancellation_event, + interrupt_queue=interrupt_queue, + tracer=tracer, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + elif isinstance(result, OperatorRun): + operator_results = result + this_iteration.operatorContext = operator_results + # Add webpages visited while operating browser to references + if result.webpages: + if not online_results.get(this_iteration.query): + online_results[this_iteration.query] = {"webpages": result.webpages} + elif not online_results[this_iteration.query].get("webpages"): + online_results[this_iteration.query]["webpages"] = result.webpages + else: + online_results[this_iteration.query]["webpages"] += result.webpages + this_iteration.onlineContext = online_results + except Exception as e: + this_iteration.warning = f"Error operating browser: {e}" + logger.error(this_iteration.warning, exc_info=True) + + elif this_iteration.query.name == ConversationCommand.ViewFile: + try: + async for result in view_file_content( + **this_iteration.query.args, + user=user, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + if this_iteration.context is None: + this_iteration.context = [] + document_results: List[Dict[str, str]] = result # type: ignore + this_iteration.context += document_results + async for result in send_status_func(f"**Viewed file**: {this_iteration.query.args['path']}"): yield result except Exception as e: - this_iteration.warning = f"Error extracting document references: {e}" + this_iteration.warning = f"Error viewing file: {e}" logger.error(this_iteration.warning, exc_info=True) + + elif this_iteration.query.name == ConversationCommand.ListFiles: + try: + async for result in list_files( + **this_iteration.query.args, + user=user, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + if this_iteration.context is None: + this_iteration.context = [] + document_results: List[Dict[str, str]] = [result] # type: ignore + this_iteration.context += document_results + async for result in send_status_func(result["query"]): + yield result + except Exception as e: + this_iteration.warning = f"Error listing files: {e}" + logger.error(this_iteration.warning, exc_info=True) + + elif this_iteration.query.name == ConversationCommand.RegexSearchFiles: + try: + async for result in grep_files( + **this_iteration.query.args, + user=user, + ): + if isinstance(result, dict) and ChatEvent.STATUS in result: + yield result[ChatEvent.STATUS] + else: + if this_iteration.context is None: + this_iteration.context = [] + document_results: List[Dict[str, str]] = [result] # type: ignore + this_iteration.context += document_results + async for result in send_status_func(result["query"]): + yield result + except Exception as e: + this_iteration.warning = f"Error searching with regex: {e}" + logger.error(this_iteration.warning, exc_info=True) + + elif "/" in this_iteration.query.name: + try: + # Identify MCP client to use + server_name, tool_name = this_iteration.query.name.split("/", 1) + mcp_client = next((client for client in mcp_clients if client.name == server_name), None) + if not mcp_client: + raise ValueError(f"Could not find MCP server with name {server_name}") + + # Invoke tool on the identified MCP server + mcp_results = await mcp_client.run_tool(tool_name, this_iteration.query.args) + + # Record tool result in context + if this_iteration.context is None: + this_iteration.context = [] + this_iteration.context += mcp_results + async for result in send_status_func(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"): + yield result + except Exception as e: + this_iteration.warning = f"Error using MCP tool: {e}" + logger.error(this_iteration.warning, exc_info=True) + else: - this_iteration.warning = "No matching document references found" + # No valid tools. This is our exit condition. + current_iteration = MAX_ITERATIONS - elif this_iteration.query.name == ConversationCommand.SearchWeb: - previous_subqueries = { - subquery - for iteration in previous_iterations - if iteration.onlineContext - for subquery in iteration.onlineContext.keys() - } - try: - async for result in search_online( - **this_iteration.query.args, - conversation_history=construct_tool_chat_history( - previous_iterations, ConversationCommand.SearchWeb - ), - location=location, - user=user, - send_status_func=send_status_func, - custom_filters=[], - max_online_searches=max_online_searches, - max_webpages_to_read=0, - query_images=query_images, - previous_subqueries=previous_subqueries, - agent=agent, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - elif is_none_or_empty(result): - this_iteration.warning = "Detected previously run online search queries. Skipping iteration. Try something different." - else: - online_results: Dict[str, Dict] = result # type: ignore - this_iteration.onlineContext = online_results - except Exception as e: - this_iteration.warning = f"Error searching online: {e}" - logger.error(this_iteration.warning, exc_info=True) + current_iteration += 1 - elif this_iteration.query.name == ConversationCommand.ReadWebpage: - try: - async for result in read_webpages_content( - **this_iteration.query.args, - user=user, - send_status_func=send_status_func, - agent=agent, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - direct_web_pages: Dict[str, Dict] = result # type: ignore + if ( + document_results + or online_results + or code_results + or operator_results + or mcp_results + or this_iteration.warning + ): + results_data = f"\n" + if document_results: + results_data += f"\n\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + if online_results: + results_data += f"\n\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + if code_results: + results_data += f"\n\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + if operator_results: + results_data += ( + f"\n\n{operator_results.response}\n" + ) + if mcp_results: + results_data += f"\n\n{yaml.dump(mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + if this_iteration.warning: + results_data += f"\n\n{this_iteration.warning}\n" + results_data += f"\n" - webpages = [] - for web_query in direct_web_pages: - if online_results.get(web_query): - online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"] - else: - online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]} + # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) + this_iteration.summarizedResult = results_data - for webpage in direct_web_pages[web_query]["webpages"]: - webpages.append(webpage["link"]) - this_iteration.onlineContext = online_results - except Exception as e: - this_iteration.warning = f"Error reading webpages: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.PythonCoder: - try: - async for result in run_code( - **this_iteration.query.args, - conversation_history=construct_tool_chat_history( - previous_iterations, ConversationCommand.PythonCoder - ), - context="", - location_data=location, - user=user, - send_status_func=send_status_func, - query_images=query_images, - query_files=query_files, - agent=agent, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - code_results: Dict[str, Dict] = result # type: ignore - this_iteration.codeContext = code_results - async for result in send_status_func(f"**Ran code snippets**: {len(this_iteration.codeContext)}"): - yield result - except (ValueError, TypeError) as e: - this_iteration.warning = f"Error running code: {e}" - logger.warning(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.OperateComputer: - try: - async for result in operate_environment( - **this_iteration.query.args, - user=user, - conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator), - location_data=location, - previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None, - send_status_func=send_status_func, - query_images=query_images, - agent=agent, - query_files=query_files, - cancellation_event=cancellation_event, - interrupt_queue=interrupt_queue, - tracer=tracer, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - elif isinstance(result, OperatorRun): - operator_results = result - this_iteration.operatorContext = operator_results - # Add webpages visited while operating browser to references - if result.webpages: - if not online_results.get(this_iteration.query): - online_results[this_iteration.query] = {"webpages": result.webpages} - elif not online_results[this_iteration.query].get("webpages"): - online_results[this_iteration.query]["webpages"] = result.webpages - else: - online_results[this_iteration.query]["webpages"] += result.webpages - this_iteration.onlineContext = online_results - except Exception as e: - this_iteration.warning = f"Error operating browser: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.ViewFile: - try: - async for result in view_file_content( - **this_iteration.query.args, - user=user, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - if this_iteration.context is None: - this_iteration.context = [] - document_results: List[Dict[str, str]] = result # type: ignore - this_iteration.context += document_results - async for result in send_status_func(f"**Viewed file**: {this_iteration.query.args['path']}"): - yield result - except Exception as e: - this_iteration.warning = f"Error viewing file: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.ListFiles: - try: - async for result in list_files( - **this_iteration.query.args, - user=user, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - if this_iteration.context is None: - this_iteration.context = [] - document_results: List[Dict[str, str]] = [result] # type: ignore - this_iteration.context += document_results - async for result in send_status_func(result["query"]): - yield result - except Exception as e: - this_iteration.warning = f"Error listing files: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif this_iteration.query.name == ConversationCommand.RegexSearchFiles: - try: - async for result in grep_files( - **this_iteration.query.args, - user=user, - ): - if isinstance(result, dict) and ChatEvent.STATUS in result: - yield result[ChatEvent.STATUS] - else: - if this_iteration.context is None: - this_iteration.context = [] - document_results: List[Dict[str, str]] = [result] # type: ignore - this_iteration.context += document_results - async for result in send_status_func(result["query"]): - yield result - except Exception as e: - this_iteration.warning = f"Error searching with regex: {e}" - logger.error(this_iteration.warning, exc_info=True) - - elif "/" in this_iteration.query.name: - try: - # Identify MCP client to use - server_name, tool_name = this_iteration.query.name.split("/", 1) - mcp_client = next((client for client in mcp_clients if client.name == server_name), None) - if not mcp_client: - raise ValueError(f"Could not find MCP server with name {server_name}") - - # Invoke tool on the identified MCP server - mcp_results = await mcp_client.run_tool(tool_name, this_iteration.query.args) - - # Record tool result in context - if this_iteration.context is None: - this_iteration.context = [] - this_iteration.context += mcp_results - async for result in send_status_func(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"): - yield result - except Exception as e: - this_iteration.warning = f"Error using MCP tool: {e}" - logger.error(this_iteration.warning, exc_info=True) - - else: - # No valid tools. This is our exit condition. - current_iteration = MAX_ITERATIONS - - current_iteration += 1 - - if ( - document_results - or online_results - or code_results - or operator_results - or mcp_results - or this_iteration.warning - ): - results_data = f"\n" - if document_results: - results_data += f"\n\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" - if online_results: - results_data += f"\n\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" - if code_results: - results_data += f"\n\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" - if operator_results: - results_data += ( - f"\n\n{operator_results.response}\n" - ) - if mcp_results: - results_data += f"\n\n{yaml.dump(mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" - if this_iteration.warning: - results_data += f"\n\n{this_iteration.warning}\n" - results_data += f"\n" - - # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) - this_iteration.summarizedResult = results_data - - this_iteration.summarizedResult = ( - this_iteration.summarizedResult - or f"Failed to get results." - ) - previous_iterations.append(this_iteration) - yield this_iteration + this_iteration.summarizedResult = ( + this_iteration.summarizedResult + or f"Failed to get results." + ) + previous_iterations.append(this_iteration) + yield this_iteration # Close MCP client connections for mcp_client in mcp_clients: