mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Fix streaming thoughts from multi-turn tools after parallel tool calling
Single-turn tools are still executed in parallel. Multi-turn tools, like the operator, are executed serially.
This commit is contained in:
@@ -3057,7 +3057,7 @@ async def view_file_content(
|
||||
start_line: Optional[int] = None,
|
||||
end_line: Optional[int] = None,
|
||||
user: KhojUser = None,
|
||||
):
|
||||
) -> AsyncGenerator[List[Dict[str, str]], None]:
|
||||
"""
|
||||
View the contents of a file from the user's document database with optional line range specification.
|
||||
"""
|
||||
|
||||
@@ -207,41 +207,12 @@ async def execute_tool(
|
||||
async for _ in status_collector(f"**Ran code snippets**: {len(iteration.codeContext)}"):
|
||||
pass
|
||||
|
||||
elif iteration.query.name == ConversationCommand.OperateComputer:
|
||||
async for res in operate_environment(
|
||||
**iteration.query.args,
|
||||
user=user,
|
||||
conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
|
||||
location_data=location,
|
||||
previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
|
||||
send_status_func=status_collector,
|
||||
query_images=query_images,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=interrupt_queue,
|
||||
tracer=tracer,
|
||||
):
|
||||
# Status messages are collected by status_collector, skip ChatEvent.STATUS here
|
||||
if isinstance(res, OperatorRun):
|
||||
result.operator_results = res
|
||||
iteration.operatorContext = result.operator_results
|
||||
if res.webpages:
|
||||
if not result.online_results.get(iteration.query):
|
||||
result.online_results[iteration.query] = {"webpages": res.webpages}
|
||||
elif not result.online_results[iteration.query].get("webpages"):
|
||||
result.online_results[iteration.query]["webpages"] = res.webpages
|
||||
else:
|
||||
result.online_results[iteration.query]["webpages"] += res.webpages
|
||||
iteration.onlineContext = result.online_results
|
||||
|
||||
elif iteration.query.name == ConversationCommand.ViewFile:
|
||||
async for res in view_file_content(
|
||||
**iteration.query.args,
|
||||
user=user,
|
||||
):
|
||||
# Status messages are collected by status_collector, skip ChatEvent.STATUS here
|
||||
if not (isinstance(res, dict) and ChatEvent.STATUS in res):
|
||||
if res and isinstance(res, list):
|
||||
if iteration.context is None:
|
||||
iteration.context = []
|
||||
result.document_results = res
|
||||
@@ -586,33 +557,95 @@ async def research(
|
||||
iterations_to_process.append(result)
|
||||
yield result
|
||||
|
||||
# Multi-turn tools that stream their execution
|
||||
streaming_tools = {ConversationCommand.OperateComputer}
|
||||
if iterations_to_process:
|
||||
# Create tasks for parallel execution
|
||||
tasks = [
|
||||
execute_tool(
|
||||
iteration=iteration,
|
||||
user=user,
|
||||
conversation_id=conversation_id,
|
||||
previous_iterations=previous_iterations,
|
||||
location=location,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
max_document_searches=max_document_searches,
|
||||
max_online_searches=max_online_searches,
|
||||
mcp_clients=mcp_clients,
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=interrupt_queue,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
for iteration in iterations_to_process
|
||||
]
|
||||
# Separate streaming tools that need real-time status updates
|
||||
# from parallelizable tools that can batch their status messages
|
||||
streaming_iterations: list[ResearchIteration] = []
|
||||
parallel_iterations: list[ResearchIteration] = []
|
||||
for iteration in iterations_to_process:
|
||||
if isinstance(iteration.query, ToolCall) and iteration.query.name in streaming_tools:
|
||||
streaming_iterations.append(iteration)
|
||||
else:
|
||||
parallel_iterations.append(iteration)
|
||||
|
||||
# Execute all tools in parallel
|
||||
tool_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
# Execute streaming tools sequentially for real-time status updates
|
||||
streaming_results: list[tuple[ResearchIteration, ToolExecutionResult]] = []
|
||||
for iteration in streaming_iterations:
|
||||
result = ToolExecutionResult()
|
||||
if (
|
||||
isinstance(iteration.query, ToolCall)
|
||||
and iteration.query.name == ConversationCommand.OperateComputer
|
||||
):
|
||||
try:
|
||||
# Execute OperateComputer
|
||||
async for res in operate_environment(
|
||||
**iteration.query.args,
|
||||
user=user,
|
||||
conversation_log=construct_tool_chat_history(
|
||||
previous_iterations, ConversationCommand.Operator
|
||||
),
|
||||
location_data=location,
|
||||
previous_trajectory=previous_iterations[-1].operatorContext
|
||||
if previous_iterations
|
||||
else None,
|
||||
send_status_func=send_status_func,
|
||||
query_images=query_images,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=interrupt_queue,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(res, dict) and ChatEvent.STATUS in res:
|
||||
yield res[ChatEvent.STATUS]
|
||||
elif isinstance(res, OperatorRun):
|
||||
result.operator_results = res
|
||||
iteration.operatorContext = result.operator_results
|
||||
if res.webpages:
|
||||
if not result.online_results.get(iteration.query):
|
||||
result.online_results[iteration.query] = {"webpages": res.webpages}
|
||||
elif not result.online_results[iteration.query].get("webpages"):
|
||||
result.online_results[iteration.query]["webpages"] = res.webpages
|
||||
else:
|
||||
result.online_results[iteration.query]["webpages"] += res.webpages
|
||||
iteration.onlineContext = result.online_results
|
||||
except Exception as e:
|
||||
iteration.warning = f"Error operating browser: {e}"
|
||||
logger.error(iteration.warning, exc_info=True)
|
||||
streaming_results.append((iteration, result))
|
||||
|
||||
# Execute parallelizable tools in parallel
|
||||
parallel_results = []
|
||||
if parallel_iterations:
|
||||
tasks = [
|
||||
execute_tool(
|
||||
iteration=iteration,
|
||||
user=user,
|
||||
conversation_id=conversation_id,
|
||||
previous_iterations=previous_iterations,
|
||||
location=location,
|
||||
query_images=query_images,
|
||||
query_files=query_files,
|
||||
max_document_searches=max_document_searches,
|
||||
max_online_searches=max_online_searches,
|
||||
mcp_clients=mcp_clients,
|
||||
cancellation_event=cancellation_event,
|
||||
interrupt_queue=interrupt_queue,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
for iteration in parallel_iterations
|
||||
]
|
||||
tool_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
parallel_results = list(zip(parallel_iterations, tool_results))
|
||||
|
||||
# Combine results (streaming first, then parallel)
|
||||
all_results = streaming_results + parallel_results
|
||||
|
||||
# Process results and yield status messages
|
||||
for this_iteration, tool_result in zip(iterations_to_process, tool_results):
|
||||
for this_iteration, tool_result in all_results:
|
||||
# Handle exceptions from asyncio.gather
|
||||
if isinstance(tool_result, Exception):
|
||||
this_iteration.warning = f"Error executing tool: {tool_result}"
|
||||
|
||||
Reference in New Issue
Block a user