Allow LLMs to make parallel tool call requests

Why
--
- The models are now smart enough to usually understand which tools to
  call in parallel and when.

- The LLM can request more work for each call to it, which is usually
  the slowest step. This speeds up work by the research agent, even though
  each tool is still executed in sequence (for now).
This commit is contained in:
Debanjum
2025-12-13 21:37:49 -08:00
parent f4c519a9d0
commit 054ed79fdf
3 changed files with 383 additions and 328 deletions

View File

@@ -85,8 +85,7 @@ def anthropic_completion_with_backoff(
# Cache tool definitions # Cache tool definitions
last_tool = model_kwargs["tools"][-1] last_tool = model_kwargs["tools"][-1]
last_tool["cache_control"] = {"type": "ephemeral"} last_tool["cache_control"] = {"type": "ephemeral"}
# Disable parallel tool call until we add support for it model_kwargs["tool_choice"] = {"type": "auto"}
model_kwargs["tool_choice"] = {"type": "auto", "disable_parallel_tool_use": True}
elif response_schema: elif response_schema:
tool = create_tool_definition(response_schema) tool = create_tool_definition(response_schema)
model_kwargs["tools"] = [ model_kwargs["tools"] = [

View File

@@ -245,8 +245,37 @@ def construct_iteration_history(
if query_message_content: if query_message_content:
iteration_history.append(ChatMessageModel(by="you", message=query_message_content)) iteration_history.append(ChatMessageModel(by="you", message=query_message_content))
# Group iterations: parallel tool calls share the same raw_response (only first has it)
# We need to group them so one assistant message has all tool_use blocks and
# one user message has all tool_results
current_group_raw_response = None
current_group_tool_results = []
def flush_group():
"""Output the current group as assistant message + user message with tool results"""
nonlocal current_group_raw_response, current_group_tool_results
if current_group_raw_response and current_group_tool_results:
iteration_history.append(
ChatMessageModel(
by="khoj",
message=current_group_raw_response,
intent=Intent(type="tool_call", query=query),
)
)
iteration_history.append(
ChatMessageModel(
by="you",
intent=Intent(type="tool_result"),
message=current_group_tool_results,
)
)
current_group_raw_response = None
current_group_tool_results = []
for iteration in previous_iterations: for iteration in previous_iterations:
if not iteration.query or isinstance(iteration.query, str): if not iteration.query or isinstance(iteration.query, str):
# Flush any pending group before adding non-tool message
flush_group()
iteration_history.append( iteration_history.append(
ChatMessageModel( ChatMessageModel(
by="you", by="you",
@@ -256,25 +285,36 @@ def construct_iteration_history(
) )
) )
continue continue
iteration_history += [
ChatMessageModel( # If this iteration has raw_response, it starts a new group of parallel tool calls
by="khoj", if iteration.raw_response:
message=iteration.raw_response or [iteration.query.__dict__], # Flush previous group if exists
intent=Intent(type="tool_call", query=query), flush_group()
), current_group_raw_response = iteration.raw_response
ChatMessageModel(
by="you", # If no raw_response and no current group, create a fallback single-tool response
intent=Intent(type="tool_result"), elif not current_group_raw_response:
message=[ current_group_raw_response = [
{ {
"type": "tool_result", "type": "tool_use",
"id": iteration.query.id, "id": iteration.query.id,
"name": iteration.query.name, "name": iteration.query.name,
"content": iteration.summarizedResult, "input": iteration.query.args,
} }
], ]
),
] # Add tool result to current group
current_group_tool_results.append(
{
"type": "tool_result",
"id": iteration.query.id,
"name": iteration.query.name,
"content": iteration.summarizedResult,
}
)
# Flush any remaining group
flush_group()
return iteration_history return iteration_history

View File

@@ -208,32 +208,41 @@ async def apick_next_tool(
return return
try: try:
# Try parse the response as function call response to infer next tool to use. # Try parse the response as function call response to infer next tools to use.
# TODO: Handle multiple tool calls.
response_text = response.text response_text = response.text
parsed_response = [ToolCall(**item) for item in load_complex_json(response_text)][0] parsed_responses = [ToolCall(**item) for item in load_complex_json(response_text)]
except Exception: except Exception:
# Otherwise assume the model has decided to end the research run and respond to the user. # Otherwise assume the model has decided to end the research run and respond to the user.
parsed_response = ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None) parsed_responses = [ToolCall(name=ConversationCommand.Text, args={"response": response_text}, id=None)]
# If we have a valid response, extract the tool and query. # Detect selection of previously used query, tool combinations.
warning = None
logger.info(f"Response for determining relevant tools: {parsed_response.name}({parsed_response.args})")
# Detect selection of previously used query, tool combination.
previous_tool_query_combinations = { previous_tool_query_combinations = {
(i.query.name, dict_to_tuple(i.query.args)) (i.query.name, dict_to_tuple(i.query.args))
for i in previous_iterations for i in previous_iterations
if i.warning is None and isinstance(i.query, ToolCall) if i.warning is None and isinstance(i.query, ToolCall)
} }
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. You've already called {parsed_response.name} with args: {parsed_response.args}. Try something different." # Send status update with model's thoughts if available
# Only send client status updates if we'll execute this iteration and model has thoughts to share. if send_status_func and not is_none_or_empty(response.thought):
elif send_status_func and not is_none_or_empty(response.thought):
async for event in send_status_func(response.thought): async for event in send_status_func(response.thought):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
yield ResearchIteration(query=parsed_response, warning=warning, raw_response=response.raw_content) # Yield a ResearchIteration for each tool call to enable parallel execution
for idx, parsed_response in enumerate(parsed_responses):
warning = None
logger.info(
f"Response for determining relevant tools ({idx + 1}/{len(parsed_responses)}): {parsed_response.name}({parsed_response.args})"
)
if (parsed_response.name, dict_to_tuple(parsed_response.args)) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. You've already called {parsed_response.name} with args: {parsed_response.args}. Try something different."
# Include raw_response only for the first tool call to avoid duplication in history
yield ResearchIteration(
query=parsed_response,
warning=warning,
raw_response=response.raw_content if idx == 0 else None,
)
async def research( async def research(
@@ -296,13 +305,8 @@ async def research(
async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"): async for result in send_status_func(f"**Incorporate New Instruction**: {interrupt_query}"):
yield result yield result
online_results: Dict = dict() # Collect all tool calls from apick_next_tool
code_results: Dict = dict() iterations_to_process: List[ResearchIteration] = []
document_results: List[Dict[str, str]] = []
operator_results: OperatorRun = None
mcp_results: List = []
this_iteration = ResearchIteration(query=query)
async for result in apick_next_tool( async for result in apick_next_tool(
query, query,
research_conversation_history, research_conversation_history,
@@ -324,307 +328,319 @@ async def research(
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
elif isinstance(result, ResearchIteration): elif isinstance(result, ResearchIteration):
this_iteration = result iterations_to_process.append(result)
yield this_iteration yield result
# Skip running iteration if warning present in iteration # Process all tool calls from this planning step
if this_iteration.warning: for this_iteration in iterations_to_process:
logger.warning(f"Research mode: {this_iteration.warning}.") online_results: Dict = dict()
code_results: Dict = dict()
document_results: List[Dict[str, str]] = []
operator_results: OperatorRun = None
mcp_results: List = []
# Terminate research if selected text tool or query, tool not set for next iteration # Skip running iteration if warning present in iteration
elif ( if this_iteration.warning:
not this_iteration.query logger.warning(f"Research mode: {this_iteration.warning}.")
or isinstance(this_iteration.query, str)
or this_iteration.query.name == ConversationCommand.Text
):
current_iteration = MAX_ITERATIONS
elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles: # Terminate research if selected text tool or query, tool not set for next iteration
this_iteration.context = [] elif (
document_results = [] not this_iteration.query
previous_inferred_queries = { or isinstance(this_iteration.query, str)
c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context or this_iteration.query.name == ConversationCommand.Text
}
async for result in search_documents(
**this_iteration.query.args,
n=max_document_searches,
d=None,
user=user,
chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles),
conversation_id=conversation_id,
conversation_commands=[ConversationCommand.Notes],
location_data=location,
send_status_func=send_status_func,
query_images=query_images,
query_files=query_files,
previous_inferred_queries=previous_inferred_queries,
agent=agent,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: current_iteration = MAX_ITERATIONS
yield result[ChatEvent.STATUS]
elif isinstance(result, tuple):
document_results = result[0]
this_iteration.context += document_results
if not is_none_or_empty(document_results): elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles:
this_iteration.context = []
document_results = []
previous_inferred_queries = {
c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
}
async for result in search_documents(
**this_iteration.query.args,
n=max_document_searches,
d=None,
user=user,
chat_history=construct_tool_chat_history(
previous_iterations, ConversationCommand.SemanticSearchFiles
),
conversation_id=conversation_id,
conversation_commands=[ConversationCommand.Notes],
location_data=location,
send_status_func=send_status_func,
query_images=query_images,
query_files=query_files,
previous_inferred_queries=previous_inferred_queries,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif isinstance(result, tuple):
document_results = result[0]
this_iteration.context += document_results
if not is_none_or_empty(document_results):
try:
distinct_files = {d["file"] for d in document_results}
distinct_headings = set(
[d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]
)
# Strip only leading # from headings
headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
async for result in send_status_func(
f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
):
yield result
except Exception as e:
this_iteration.warning = f"Error extracting document references: {e}"
logger.error(this_iteration.warning, exc_info=True)
else:
this_iteration.warning = "No matching document references found"
elif this_iteration.query.name == ConversationCommand.SearchWeb:
previous_subqueries = {
subquery
for iteration in previous_iterations
if iteration.onlineContext
for subquery in iteration.onlineContext.keys()
}
try: try:
distinct_files = {d["file"] for d in document_results} async for result in search_online(
distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]) **this_iteration.query.args,
# Strip only leading # from headings conversation_history=construct_tool_chat_history(
headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "") previous_iterations, ConversationCommand.SearchWeb
async for result in send_status_func( ),
f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}" location=location,
user=user,
send_status_func=send_status_func,
custom_filters=[],
max_online_searches=max_online_searches,
max_webpages_to_read=0,
query_images=query_images,
previous_subqueries=previous_subqueries,
agent=agent,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif is_none_or_empty(result):
this_iteration.warning = "Detected previously run online search queries. Skipping iteration. Try something different."
else:
online_results: Dict[str, Dict] = result # type: ignore
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error searching online: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ReadWebpage:
try:
async for result in read_webpages_content(
**this_iteration.query.args,
user=user,
send_status_func=send_status_func,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
direct_web_pages: Dict[str, Dict] = result # type: ignore
webpages = []
for web_query in direct_web_pages:
if online_results.get(web_query):
online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
else:
online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
for webpage in direct_web_pages[web_query]["webpages"]:
webpages.append(webpage["link"])
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error reading webpages: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.PythonCoder:
try:
async for result in run_code(
**this_iteration.query.args,
conversation_history=construct_tool_chat_history(
previous_iterations, ConversationCommand.PythonCoder
),
context="",
location_data=location,
user=user,
send_status_func=send_status_func,
query_images=query_images,
query_files=query_files,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
code_results: Dict[str, Dict] = result # type: ignore
this_iteration.codeContext = code_results
async for result in send_status_func(f"**Ran code snippets**: {len(this_iteration.codeContext)}"):
yield result
except (ValueError, TypeError) as e:
this_iteration.warning = f"Error running code: {e}"
logger.warning(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.OperateComputer:
try:
async for result in operate_environment(
**this_iteration.query.args,
user=user,
conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
location_data=location,
previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
send_status_func=send_status_func,
query_images=query_images,
agent=agent,
query_files=query_files,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif isinstance(result, OperatorRun):
operator_results = result
this_iteration.operatorContext = operator_results
# Add webpages visited while operating browser to references
if result.webpages:
if not online_results.get(this_iteration.query):
online_results[this_iteration.query] = {"webpages": result.webpages}
elif not online_results[this_iteration.query].get("webpages"):
online_results[this_iteration.query]["webpages"] = result.webpages
else:
online_results[this_iteration.query]["webpages"] += result.webpages
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error operating browser: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ViewFile:
try:
async for result in view_file_content(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = result # type: ignore
this_iteration.context += document_results
async for result in send_status_func(f"**Viewed file**: {this_iteration.query.args['path']}"):
yield result yield result
except Exception as e: except Exception as e:
this_iteration.warning = f"Error extracting document references: {e}" this_iteration.warning = f"Error viewing file: {e}"
logger.error(this_iteration.warning, exc_info=True) logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ListFiles:
try:
async for result in list_files(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = [result] # type: ignore
this_iteration.context += document_results
async for result in send_status_func(result["query"]):
yield result
except Exception as e:
this_iteration.warning = f"Error listing files: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.RegexSearchFiles:
try:
async for result in grep_files(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = [result] # type: ignore
this_iteration.context += document_results
async for result in send_status_func(result["query"]):
yield result
except Exception as e:
this_iteration.warning = f"Error searching with regex: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif "/" in this_iteration.query.name:
try:
# Identify MCP client to use
server_name, tool_name = this_iteration.query.name.split("/", 1)
mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
if not mcp_client:
raise ValueError(f"Could not find MCP server with name {server_name}")
# Invoke tool on the identified MCP server
mcp_results = await mcp_client.run_tool(tool_name, this_iteration.query.args)
# Record tool result in context
if this_iteration.context is None:
this_iteration.context = []
this_iteration.context += mcp_results
async for result in send_status_func(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"):
yield result
except Exception as e:
this_iteration.warning = f"Error using MCP tool: {e}"
logger.error(this_iteration.warning, exc_info=True)
else: else:
this_iteration.warning = "No matching document references found" # No valid tools. This is our exit condition.
current_iteration = MAX_ITERATIONS
elif this_iteration.query.name == ConversationCommand.SearchWeb: current_iteration += 1
previous_subqueries = {
subquery
for iteration in previous_iterations
if iteration.onlineContext
for subquery in iteration.onlineContext.keys()
}
try:
async for result in search_online(
**this_iteration.query.args,
conversation_history=construct_tool_chat_history(
previous_iterations, ConversationCommand.SearchWeb
),
location=location,
user=user,
send_status_func=send_status_func,
custom_filters=[],
max_online_searches=max_online_searches,
max_webpages_to_read=0,
query_images=query_images,
previous_subqueries=previous_subqueries,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif is_none_or_empty(result):
this_iteration.warning = "Detected previously run online search queries. Skipping iteration. Try something different."
else:
online_results: Dict[str, Dict] = result # type: ignore
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error searching online: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ReadWebpage: if (
try: document_results
async for result in read_webpages_content( or online_results
**this_iteration.query.args, or code_results
user=user, or operator_results
send_status_func=send_status_func, or mcp_results
agent=agent, or this_iteration.warning
tracer=tracer, ):
): results_data = f"\n<iteration_{current_iteration}_results>"
if isinstance(result, dict) and ChatEvent.STATUS in result: if document_results:
yield result[ChatEvent.STATUS] results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
else: if online_results:
direct_web_pages: Dict[str, Dict] = result # type: ignore results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
if code_results:
results_data += f"\n<code_results>\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
if operator_results:
results_data += (
f"\n<browser_operator_results>\n{operator_results.response}\n</browser_operator_results>"
)
if mcp_results:
results_data += f"\n<mcp_tool_results>\n{yaml.dump(mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</mcp_tool_results>"
if this_iteration.warning:
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
results_data += f"\n</iteration_{current_iteration}_results>"
webpages = [] # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
for web_query in direct_web_pages: this_iteration.summarizedResult = results_data
if online_results.get(web_query):
online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
else:
online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
for webpage in direct_web_pages[web_query]["webpages"]: this_iteration.summarizedResult = (
webpages.append(webpage["link"]) this_iteration.summarizedResult
this_iteration.onlineContext = online_results or f"<iteration_{current_iteration}_results>Failed to get results.</iteration_{current_iteration}_results>"
except Exception as e: )
this_iteration.warning = f"Error reading webpages: {e}" previous_iterations.append(this_iteration)
logger.error(this_iteration.warning, exc_info=True) yield this_iteration
elif this_iteration.query.name == ConversationCommand.PythonCoder:
try:
async for result in run_code(
**this_iteration.query.args,
conversation_history=construct_tool_chat_history(
previous_iterations, ConversationCommand.PythonCoder
),
context="",
location_data=location,
user=user,
send_status_func=send_status_func,
query_images=query_images,
query_files=query_files,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
code_results: Dict[str, Dict] = result # type: ignore
this_iteration.codeContext = code_results
async for result in send_status_func(f"**Ran code snippets**: {len(this_iteration.codeContext)}"):
yield result
except (ValueError, TypeError) as e:
this_iteration.warning = f"Error running code: {e}"
logger.warning(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.OperateComputer:
try:
async for result in operate_environment(
**this_iteration.query.args,
user=user,
conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
location_data=location,
previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
send_status_func=send_status_func,
query_images=query_images,
agent=agent,
query_files=query_files,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif isinstance(result, OperatorRun):
operator_results = result
this_iteration.operatorContext = operator_results
# Add webpages visited while operating browser to references
if result.webpages:
if not online_results.get(this_iteration.query):
online_results[this_iteration.query] = {"webpages": result.webpages}
elif not online_results[this_iteration.query].get("webpages"):
online_results[this_iteration.query]["webpages"] = result.webpages
else:
online_results[this_iteration.query]["webpages"] += result.webpages
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error operating browser: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ViewFile:
try:
async for result in view_file_content(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = result # type: ignore
this_iteration.context += document_results
async for result in send_status_func(f"**Viewed file**: {this_iteration.query.args['path']}"):
yield result
except Exception as e:
this_iteration.warning = f"Error viewing file: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ListFiles:
try:
async for result in list_files(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = [result] # type: ignore
this_iteration.context += document_results
async for result in send_status_func(result["query"]):
yield result
except Exception as e:
this_iteration.warning = f"Error listing files: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.RegexSearchFiles:
try:
async for result in grep_files(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = [result] # type: ignore
this_iteration.context += document_results
async for result in send_status_func(result["query"]):
yield result
except Exception as e:
this_iteration.warning = f"Error searching with regex: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif "/" in this_iteration.query.name:
try:
# Identify MCP client to use
server_name, tool_name = this_iteration.query.name.split("/", 1)
mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
if not mcp_client:
raise ValueError(f"Could not find MCP server with name {server_name}")
# Invoke tool on the identified MCP server
mcp_results = await mcp_client.run_tool(tool_name, this_iteration.query.args)
# Record tool result in context
if this_iteration.context is None:
this_iteration.context = []
this_iteration.context += mcp_results
async for result in send_status_func(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"):
yield result
except Exception as e:
this_iteration.warning = f"Error using MCP tool: {e}"
logger.error(this_iteration.warning, exc_info=True)
else:
# No valid tools. This is our exit condition.
current_iteration = MAX_ITERATIONS
current_iteration += 1
if (
document_results
or online_results
or code_results
or operator_results
or mcp_results
or this_iteration.warning
):
results_data = f"\n<iteration_{current_iteration}_results>"
if document_results:
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
if online_results:
results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
if code_results:
results_data += f"\n<code_results>\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
if operator_results:
results_data += (
f"\n<browser_operator_results>\n{operator_results.response}\n</browser_operator_results>"
)
if mcp_results:
results_data += f"\n<mcp_tool_results>\n{yaml.dump(mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</mcp_tool_results>"
if this_iteration.warning:
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
results_data += f"\n</iteration_{current_iteration}_results>"
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data
this_iteration.summarizedResult = (
this_iteration.summarizedResult
or f"<iteration_{current_iteration}_results>Failed to get results.</iteration_{current_iteration}_results>"
)
previous_iterations.append(this_iteration)
yield this_iteration
# Close MCP client connections # Close MCP client connections
for mcp_client in mcp_clients: for mcp_client in mcp_clients: