Fix streaming thoughts from multi-turn tools after parallel tool calling

Single-turn tools are still executed in parallel. Multi-turn tools
like operator are executed sequentially.
This commit is contained in:
Debanjum
2025-12-16 18:24:42 -08:00
parent 446a23524c
commit f65f6ae848
2 changed files with 87 additions and 54 deletions

View File

@@ -3057,7 +3057,7 @@ async def view_file_content(
start_line: Optional[int] = None,
end_line: Optional[int] = None,
user: KhojUser = None,
):
) -> AsyncGenerator[List[Dict[str, str]], None]:
"""
View the contents of a file from the user's document database with optional line range specification.
"""

View File

@@ -207,41 +207,12 @@ async def execute_tool(
async for _ in status_collector(f"**Ran code snippets**: {len(iteration.codeContext)}"):
pass
elif iteration.query.name == ConversationCommand.OperateComputer:
async for res in operate_environment(
**iteration.query.args,
user=user,
conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
location_data=location,
previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
send_status_func=status_collector,
query_images=query_images,
agent=agent,
query_files=query_files,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
tracer=tracer,
):
# Status messages are collected by status_collector, skip ChatEvent.STATUS here
if isinstance(res, OperatorRun):
result.operator_results = res
iteration.operatorContext = result.operator_results
if res.webpages:
if not result.online_results.get(iteration.query):
result.online_results[iteration.query] = {"webpages": res.webpages}
elif not result.online_results[iteration.query].get("webpages"):
result.online_results[iteration.query]["webpages"] = res.webpages
else:
result.online_results[iteration.query]["webpages"] += res.webpages
iteration.onlineContext = result.online_results
elif iteration.query.name == ConversationCommand.ViewFile:
async for res in view_file_content(
**iteration.query.args,
user=user,
):
# Status messages are collected by status_collector, skip ChatEvent.STATUS here
if not (isinstance(res, dict) and ChatEvent.STATUS in res):
if res and isinstance(res, list):
if iteration.context is None:
iteration.context = []
result.document_results = res
@@ -586,33 +557,95 @@ async def research(
iterations_to_process.append(result)
yield result
# Multi-turn tools that stream their execution
streaming_tools = {ConversationCommand.OperateComputer}
if iterations_to_process:
# Create tasks for parallel execution
tasks = [
execute_tool(
iteration=iteration,
user=user,
conversation_id=conversation_id,
previous_iterations=previous_iterations,
location=location,
query_images=query_images,
query_files=query_files,
max_document_searches=max_document_searches,
max_online_searches=max_online_searches,
mcp_clients=mcp_clients,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
agent=agent,
tracer=tracer,
)
for iteration in iterations_to_process
]
# Separate streaming tools that need real-time status updates
# from parallelizable tools that can batch their status messages
streaming_iterations: list[ResearchIteration] = []
parallel_iterations: list[ResearchIteration] = []
for iteration in iterations_to_process:
if isinstance(iteration.query, ToolCall) and iteration.query.name in streaming_tools:
streaming_iterations.append(iteration)
else:
parallel_iterations.append(iteration)
# Execute all tools in parallel
tool_results = await asyncio.gather(*tasks, return_exceptions=True)
# Execute streaming tools sequentially for real-time status updates
streaming_results: list[tuple[ResearchIteration, ToolExecutionResult]] = []
for iteration in streaming_iterations:
result = ToolExecutionResult()
if (
isinstance(iteration.query, ToolCall)
and iteration.query.name == ConversationCommand.OperateComputer
):
try:
# Execute OperateComputer
async for res in operate_environment(
**iteration.query.args,
user=user,
conversation_log=construct_tool_chat_history(
previous_iterations, ConversationCommand.Operator
),
location_data=location,
previous_trajectory=previous_iterations[-1].operatorContext
if previous_iterations
else None,
send_status_func=send_status_func,
query_images=query_images,
agent=agent,
query_files=query_files,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
tracer=tracer,
):
if isinstance(res, dict) and ChatEvent.STATUS in res:
yield res[ChatEvent.STATUS]
elif isinstance(res, OperatorRun):
result.operator_results = res
iteration.operatorContext = result.operator_results
if res.webpages:
if not result.online_results.get(iteration.query):
result.online_results[iteration.query] = {"webpages": res.webpages}
elif not result.online_results[iteration.query].get("webpages"):
result.online_results[iteration.query]["webpages"] = res.webpages
else:
result.online_results[iteration.query]["webpages"] += res.webpages
iteration.onlineContext = result.online_results
except Exception as e:
iteration.warning = f"Error operating browser: {e}"
logger.error(iteration.warning, exc_info=True)
streaming_results.append((iteration, result))
# Execute parallelizable tools in parallel
parallel_results = []
if parallel_iterations:
tasks = [
execute_tool(
iteration=iteration,
user=user,
conversation_id=conversation_id,
previous_iterations=previous_iterations,
location=location,
query_images=query_images,
query_files=query_files,
max_document_searches=max_document_searches,
max_online_searches=max_online_searches,
mcp_clients=mcp_clients,
cancellation_event=cancellation_event,
interrupt_queue=interrupt_queue,
agent=agent,
tracer=tracer,
)
for iteration in parallel_iterations
]
tool_results = await asyncio.gather(*tasks, return_exceptions=True)
parallel_results = list(zip(parallel_iterations, tool_results))
# Combine results (streaming first, then parallel)
all_results = streaming_results + parallel_results
# Process results and yield status messages
for this_iteration, tool_result in zip(iterations_to_process, tool_results):
for this_iteration, tool_result in all_results:
# Handle exceptions from asyncio.gather
if isinstance(tool_result, Exception):
this_iteration.warning = f"Error executing tool: {tool_result}"