Execute tool calls in parallel to make research iterations faster

This commit is contained in:
Debanjum
2025-12-13 22:55:54 -08:00
parent 054ed79fdf
commit cdcbdf8459

View File

@@ -48,6 +48,261 @@ from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
class ToolExecutionResult:
    """Result of executing a single tool call.

    Aggregates everything one research-tool invocation produces so that
    multiple tool calls can run concurrently (e.g. via asyncio.gather) and the
    caller can merge their outputs afterwards.
    """

    def __init__(self):
        # Status messages buffered instead of streamed; replayed by the caller.
        self.status_messages: List = []
        # Document/notes references retrieved by a file or search tool.
        self.document_results: List[Dict[str, str]] = []
        # Online search results, keyed by search query.
        self.online_results: Dict = {}
        # Code-execution results from the Python coder tool.
        self.code_results: Dict = {}
        # Browser-operator trajectory, if the operator tool ran.
        # Fix: default is None, so the annotation must be Optional.
        self.operator_results: Optional[OperatorRun] = None
        # Raw results returned by an MCP server tool call.
        self.mcp_results: List = []
        # Signals the research loop to stop iterating.
        self.should_terminate: bool = False
async def execute_tool(
    iteration: ResearchIteration,
    user: KhojUser,
    conversation_id: str,
    previous_iterations: List[ResearchIteration],
    location: LocationData,
    query_images: List[str],
    query_files: str,
    max_document_searches: int,
    max_online_searches: int,
    mcp_clients: List[MCPClient],
    cancellation_event: Optional[asyncio.Event],
    interrupt_queue: Optional[asyncio.Queue],
    agent: Agent,
    tracer: dict,
) -> ToolExecutionResult:
    """Execute a single tool call and return results. Designed for parallel execution.

    Dispatches on ``iteration.query.name`` to the matching tool: document
    semantic search, web search, webpage read, Python code runner, browser
    operator, file view/list/grep, or an MCP server tool (names containing
    "/" are split into "server/tool"). Status messages are buffered in the
    returned ToolExecutionResult via ``status_collector`` rather than streamed,
    so several calls can safely run concurrently under asyncio.gather.

    Mutates ``iteration`` in place (context, onlineContext, codeContext,
    operatorContext, warning). Never propagates tool errors: any exception is
    recorded on ``iteration.warning`` and logged, and the partial result is
    returned. ``result.should_terminate`` is set when no runnable tool was
    selected, signalling the research loop to stop.
    """
    result = ToolExecutionResult()

    # Skip if warning present (e.g. set by the planner before execution)
    if iteration.warning:
        logger.warning(f"Research mode: {iteration.warning}.")
        return result

    # Check for termination conditions: missing query, a bare string query, or
    # an explicit text response all end the research loop.
    if not iteration.query or isinstance(iteration.query, str) or iteration.query.name == ConversationCommand.Text:
        result.should_terminate = True
        return result

    # Create a status collector that captures messages for later batch yield
    async def status_collector(message: str):
        """Async generator that collects status messages instead of streaming them."""
        # Just collect the message - we'll process it later in the main loop
        result.status_messages.append(message)
        yield message  # Yield to satisfy async generator protocol expected by tool functions

    try:
        if iteration.query.name == ConversationCommand.SemanticSearchFiles:
            iteration.context = []
            # Queries already inferred in earlier iterations, to avoid re-running them
            previous_inferred_queries = {
                c["query"] for iter in previous_iterations if iter.context for c in iter.context
            }
            async for res in search_documents(
                **iteration.query.args,
                n=max_document_searches,
                d=None,
                user=user,
                chat_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SemanticSearchFiles),
                conversation_id=conversation_id,
                conversation_commands=[ConversationCommand.Notes],
                location_data=location,
                send_status_func=status_collector,
                query_images=query_images,
                query_files=query_files,
                previous_inferred_queries=previous_inferred_queries,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if isinstance(res, tuple):
                    result.document_results = res[0]
                    iteration.context += result.document_results
            if not is_none_or_empty(result.document_results):
                try:
                    # Summarize distinct files and note headings for the status line
                    distinct_files = {d["file"] for d in result.document_results}
                    distinct_headings = set(
                        [d["compiled"].split("\n")[0] for d in result.document_results if "compiled" in d]
                    )
                    headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
                    async for _ in status_collector(
                        f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
                    ):
                        pass
                except Exception as e:
                    iteration.warning = f"Error extracting document references: {e}"
                    logger.error(iteration.warning, exc_info=True)
            else:
                iteration.warning = "No matching document references found"
        elif iteration.query.name == ConversationCommand.SearchWeb:
            # Subqueries already searched online in earlier iterations
            previous_subqueries = {
                subquery for iter in previous_iterations if iter.onlineContext for subquery in iter.onlineContext.keys()
            }
            async for res in search_online(
                **iteration.query.args,
                conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.SearchWeb),
                location=location,
                user=user,
                send_status_func=status_collector,
                custom_filters=[],
                max_online_searches=max_online_searches,
                max_webpages_to_read=0,
                query_images=query_images,
                previous_subqueries=previous_subqueries,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if is_none_or_empty(res):
                    iteration.warning = (
                        "Detected previously run online search queries. Skipping iteration. Try something different."
                    )
                elif not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    result.online_results = res
                    iteration.onlineContext = result.online_results
        elif iteration.query.name == ConversationCommand.ReadWebpage:
            async for res in read_webpages_content(
                **iteration.query.args,
                user=user,
                send_status_func=status_collector,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    direct_web_pages: Dict[str, Dict] = res
                    # Merge fetched webpages into any existing online results per query
                    for web_query in direct_web_pages:
                        if result.online_results.get(web_query):
                            result.online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
                        else:
                            result.online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
                    iteration.onlineContext = result.online_results
        elif iteration.query.name == ConversationCommand.PythonCoder:
            async for res in run_code(
                **iteration.query.args,
                conversation_history=construct_tool_chat_history(previous_iterations, ConversationCommand.PythonCoder),
                context="",
                location_data=location,
                user=user,
                send_status_func=status_collector,
                query_images=query_images,
                query_files=query_files,
                agent=agent,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    result.code_results = res
                    iteration.codeContext = result.code_results
            if iteration.codeContext:
                async for _ in status_collector(f"**Ran code snippets**: {len(iteration.codeContext)}"):
                    pass
        elif iteration.query.name == ConversationCommand.OperateComputer:
            async for res in operate_environment(
                **iteration.query.args,
                user=user,
                conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
                location_data=location,
                previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
                send_status_func=status_collector,
                query_images=query_images,
                agent=agent,
                query_files=query_files,
                cancellation_event=cancellation_event,
                interrupt_queue=interrupt_queue,
                tracer=tracer,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if isinstance(res, OperatorRun):
                    result.operator_results = res
                    iteration.operatorContext = result.operator_results
                    # Add webpages visited while operating the browser to references
                    if res.webpages:
                        if not result.online_results.get(iteration.query):
                            result.online_results[iteration.query] = {"webpages": res.webpages}
                        elif not result.online_results[iteration.query].get("webpages"):
                            result.online_results[iteration.query]["webpages"] = res.webpages
                        else:
                            result.online_results[iteration.query]["webpages"] += res.webpages
                        iteration.onlineContext = result.online_results
        elif iteration.query.name == ConversationCommand.ViewFile:
            async for res in view_file_content(
                **iteration.query.args,
                user=user,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    if iteration.context is None:
                        iteration.context = []
                    result.document_results = res
                    iteration.context += result.document_results
                    async for _ in status_collector(f"**Viewed file**: {iteration.query.args['path']}"):
                        pass
        elif iteration.query.name == ConversationCommand.ListFiles:
            async for res in list_files(
                **iteration.query.args,
                user=user,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    if iteration.context is None:
                        iteration.context = []
                    result.document_results = [res]
                    iteration.context += result.document_results
                    if result.document_results:
                        async for _ in status_collector(result.document_results[-1].get("query", "Listed files")):
                            pass
        elif iteration.query.name == ConversationCommand.RegexSearchFiles:
            async for res in grep_files(
                **iteration.query.args,
                user=user,
            ):
                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
                    if iteration.context is None:
                        iteration.context = []
                    result.document_results = [res]
                    iteration.context += result.document_results
                    if result.document_results:
                        async for _ in status_collector(result.document_results[-1].get("query", "Searched files")):
                            pass
        elif "/" in iteration.query.name:
            # Tool names of the form "server/tool" target an MCP server
            server_name, tool_name = iteration.query.name.split("/", 1)
            mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
            if not mcp_client:
                raise ValueError(f"Could not find MCP server with name {server_name}")
            result.mcp_results = await mcp_client.run_tool(tool_name, iteration.query.args)
            if iteration.context is None:
                iteration.context = []
            iteration.context += result.mcp_results
            async for _ in status_collector(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"):
                pass
        else:
            # No valid tool selected: this is the research loop's exit condition
            result.should_terminate = True
    except Exception as e:
        # Record the failure on the iteration instead of propagating, so one
        # failed tool does not break sibling parallel executions.
        tool_name = iteration.query.name if iteration.query else "unknown"
        iteration.warning = f"Error executing {tool_name}: {e}"
        logger.error(iteration.warning, exc_info=True)

    return result
async def apick_next_tool(
    query: str,
    conversation_history: List[ChatMessageModel],
@@ -331,316 +586,83 @@ async def research(
iterations_to_process.append(result) iterations_to_process.append(result)
yield result yield result
# Process all tool calls from this planning step if iterations_to_process:
for this_iteration in iterations_to_process: # Create tasks for parallel execution
online_results: Dict = dict() tasks = [
code_results: Dict = dict() execute_tool(
document_results: List[Dict[str, str]] = [] iteration=iteration,
operator_results: OperatorRun = None
mcp_results: List = []
# Skip running iteration if warning present in iteration
if this_iteration.warning:
logger.warning(f"Research mode: {this_iteration.warning}.")
# Terminate research if selected text tool or query, tool not set for next iteration
elif (
not this_iteration.query
or isinstance(this_iteration.query, str)
or this_iteration.query.name == ConversationCommand.Text
):
current_iteration = MAX_ITERATIONS
elif this_iteration.query.name == ConversationCommand.SemanticSearchFiles:
this_iteration.context = []
document_results = []
previous_inferred_queries = {
c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
}
async for result in search_documents(
**this_iteration.query.args,
n=max_document_searches,
d=None,
user=user, user=user,
chat_history=construct_tool_chat_history(
previous_iterations, ConversationCommand.SemanticSearchFiles
),
conversation_id=conversation_id, conversation_id=conversation_id,
conversation_commands=[ConversationCommand.Notes], previous_iterations=previous_iterations,
location_data=location, location=location,
send_status_func=send_status_func,
query_images=query_images, query_images=query_images,
query_files=query_files, query_files=query_files,
previous_inferred_queries=previous_inferred_queries, max_document_searches=max_document_searches,
max_online_searches=max_online_searches,
mcp_clients=mcp_clients,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
)
for iteration in iterations_to_process
]
# Execute all tools in parallel
tool_results = await asyncio.gather(*tasks, return_exceptions=True)
# Process results and yield status messages
for this_iteration, tool_result in zip(iterations_to_process, tool_results):
# Handle exceptions from asyncio.gather
if isinstance(tool_result, Exception):
this_iteration.warning = f"Error executing tool: {tool_result}"
logger.error(this_iteration.warning, exc_info=True)
tool_result = ToolExecutionResult()
# Check for termination
if tool_result.should_terminate:
current_iteration = MAX_ITERATIONS
# Yield all collected status messages through the real send_status_func
for status_msg in tool_result.status_messages:
if send_status_func:
async for status_event in send_status_func(status_msg):
yield status_event
current_iteration += 1
# Build summarized results
if (
tool_result.document_results
or tool_result.online_results
or tool_result.code_results
or tool_result.operator_results
or tool_result.mcp_results
or this_iteration.warning
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: results_data = f"\n<iteration_{current_iteration}_results>"
yield result[ChatEvent.STATUS] if tool_result.document_results:
elif isinstance(result, tuple): results_data += f"\n<document_references>\n{yaml.dump(tool_result.document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
document_results = result[0] if tool_result.online_results:
this_iteration.context += document_results results_data += f"\n<online_results>\n{yaml.dump(tool_result.online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
if tool_result.code_results:
results_data += f"\n<code_results>\n{yaml.dump(truncate_code_context(tool_result.code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
if tool_result.operator_results:
results_data += f"\n<browser_operator_results>\n{tool_result.operator_results.response}\n</browser_operator_results>"
if tool_result.mcp_results:
results_data += f"\n<mcp_tool_results>\n{yaml.dump(tool_result.mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</mcp_tool_results>"
if this_iteration.warning:
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
results_data += f"\n</iteration_{current_iteration}_results>"
if not is_none_or_empty(document_results): this_iteration.summarizedResult = results_data
try:
distinct_files = {d["file"] for d in document_results}
distinct_headings = set(
[d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]
)
# Strip only leading # from headings
headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
async for result in send_status_func(
f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
):
yield result
except Exception as e:
this_iteration.warning = f"Error extracting document references: {e}"
logger.error(this_iteration.warning, exc_info=True)
else:
this_iteration.warning = "No matching document references found"
elif this_iteration.query.name == ConversationCommand.SearchWeb: this_iteration.summarizedResult = (
previous_subqueries = { this_iteration.summarizedResult
subquery or f"<iteration_{current_iteration}_results>Failed to get results.</iteration_{current_iteration}_results>"
for iteration in previous_iterations )
if iteration.onlineContext previous_iterations.append(this_iteration)
for subquery in iteration.onlineContext.keys() yield this_iteration
}
try:
async for result in search_online(
**this_iteration.query.args,
conversation_history=construct_tool_chat_history(
previous_iterations, ConversationCommand.SearchWeb
),
location=location,
user=user,
send_status_func=send_status_func,
custom_filters=[],
max_online_searches=max_online_searches,
max_webpages_to_read=0,
query_images=query_images,
previous_subqueries=previous_subqueries,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif is_none_or_empty(result):
this_iteration.warning = "Detected previously run online search queries. Skipping iteration. Try something different."
else:
online_results: Dict[str, Dict] = result # type: ignore
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error searching online: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ReadWebpage:
try:
async for result in read_webpages_content(
**this_iteration.query.args,
user=user,
send_status_func=send_status_func,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
direct_web_pages: Dict[str, Dict] = result # type: ignore
webpages = []
for web_query in direct_web_pages:
if online_results.get(web_query):
online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
else:
online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
for webpage in direct_web_pages[web_query]["webpages"]:
webpages.append(webpage["link"])
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error reading webpages: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.PythonCoder:
try:
async for result in run_code(
**this_iteration.query.args,
conversation_history=construct_tool_chat_history(
previous_iterations, ConversationCommand.PythonCoder
),
context="",
location_data=location,
user=user,
send_status_func=send_status_func,
query_images=query_images,
query_files=query_files,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
code_results: Dict[str, Dict] = result # type: ignore
this_iteration.codeContext = code_results
async for result in send_status_func(f"**Ran code snippets**: {len(this_iteration.codeContext)}"):
yield result
except (ValueError, TypeError) as e:
this_iteration.warning = f"Error running code: {e}"
logger.warning(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.OperateComputer:
try:
async for result in operate_environment(
**this_iteration.query.args,
user=user,
conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
location_data=location,
previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
send_status_func=send_status_func,
query_images=query_images,
agent=agent,
query_files=query_files,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif isinstance(result, OperatorRun):
operator_results = result
this_iteration.operatorContext = operator_results
# Add webpages visited while operating browser to references
if result.webpages:
if not online_results.get(this_iteration.query):
online_results[this_iteration.query] = {"webpages": result.webpages}
elif not online_results[this_iteration.query].get("webpages"):
online_results[this_iteration.query]["webpages"] = result.webpages
else:
online_results[this_iteration.query]["webpages"] += result.webpages
this_iteration.onlineContext = online_results
except Exception as e:
this_iteration.warning = f"Error operating browser: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ViewFile:
try:
async for result in view_file_content(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = result # type: ignore
this_iteration.context += document_results
async for result in send_status_func(f"**Viewed file**: {this_iteration.query.args['path']}"):
yield result
except Exception as e:
this_iteration.warning = f"Error viewing file: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.ListFiles:
try:
async for result in list_files(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = [result] # type: ignore
this_iteration.context += document_results
async for result in send_status_func(result["query"]):
yield result
except Exception as e:
this_iteration.warning = f"Error listing files: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif this_iteration.query.name == ConversationCommand.RegexSearchFiles:
try:
async for result in grep_files(
**this_iteration.query.args,
user=user,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
if this_iteration.context is None:
this_iteration.context = []
document_results: List[Dict[str, str]] = [result] # type: ignore
this_iteration.context += document_results
async for result in send_status_func(result["query"]):
yield result
except Exception as e:
this_iteration.warning = f"Error searching with regex: {e}"
logger.error(this_iteration.warning, exc_info=True)
elif "/" in this_iteration.query.name:
try:
# Identify MCP client to use
server_name, tool_name = this_iteration.query.name.split("/", 1)
mcp_client = next((client for client in mcp_clients if client.name == server_name), None)
if not mcp_client:
raise ValueError(f"Could not find MCP server with name {server_name}")
# Invoke tool on the identified MCP server
mcp_results = await mcp_client.run_tool(tool_name, this_iteration.query.args)
# Record tool result in context
if this_iteration.context is None:
this_iteration.context = []
this_iteration.context += mcp_results
async for result in send_status_func(f"**Used MCP Tool**: {tool_name} on {mcp_client.name}"):
yield result
except Exception as e:
this_iteration.warning = f"Error using MCP tool: {e}"
logger.error(this_iteration.warning, exc_info=True)
else:
# No valid tools. This is our exit condition.
current_iteration = MAX_ITERATIONS
current_iteration += 1
if (
document_results
or online_results
or code_results
or operator_results
or mcp_results
or this_iteration.warning
):
results_data = f"\n<iteration_{current_iteration}_results>"
if document_results:
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
if online_results:
results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
if code_results:
results_data += f"\n<code_results>\n{yaml.dump(truncate_code_context(code_results), allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
if operator_results:
results_data += (
f"\n<browser_operator_results>\n{operator_results.response}\n</browser_operator_results>"
)
if mcp_results:
results_data += f"\n<mcp_tool_results>\n{yaml.dump(mcp_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</mcp_tool_results>"
if this_iteration.warning:
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
results_data += f"\n</iteration_{current_iteration}_results>"
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data
this_iteration.summarizedResult = (
this_iteration.summarizedResult
or f"<iteration_{current_iteration}_results>Failed to get results.</iteration_{current_iteration}_results>"
)
previous_iterations.append(this_iteration)
yield this_iteration
# Close MCP client connections
for mcp_client in mcp_clients: