mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Fix streaming thoughts from multi-turn tools after parallel tool calling
Single-turn tools are still executed in parallel. Multi-turn tools like operator are executed serially.
This commit is contained in:
@@ -3057,7 +3057,7 @@ async def view_file_content(
|
|||||||
start_line: Optional[int] = None,
|
start_line: Optional[int] = None,
|
||||||
end_line: Optional[int] = None,
|
end_line: Optional[int] = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
):
|
) -> AsyncGenerator[List[Dict[str, str]], None]:
|
||||||
"""
|
"""
|
||||||
View the contents of a file from the user's document database with optional line range specification.
|
View the contents of a file from the user's document database with optional line range specification.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -207,41 +207,12 @@ async def execute_tool(
|
|||||||
async for _ in status_collector(f"**Ran code snippets**: {len(iteration.codeContext)}"):
|
async for _ in status_collector(f"**Ran code snippets**: {len(iteration.codeContext)}"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif iteration.query.name == ConversationCommand.OperateComputer:
|
|
||||||
async for res in operate_environment(
|
|
||||||
**iteration.query.args,
|
|
||||||
user=user,
|
|
||||||
conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
|
|
||||||
location_data=location,
|
|
||||||
previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
|
|
||||||
send_status_func=status_collector,
|
|
||||||
query_images=query_images,
|
|
||||||
agent=agent,
|
|
||||||
query_files=query_files,
|
|
||||||
cancellation_event=cancellation_event,
|
|
||||||
interrupt_queue=interrupt_queue,
|
|
||||||
tracer=tracer,
|
|
||||||
):
|
|
||||||
# Status messages are collected by status_collector, skip ChatEvent.STATUS here
|
|
||||||
if isinstance(res, OperatorRun):
|
|
||||||
result.operator_results = res
|
|
||||||
iteration.operatorContext = result.operator_results
|
|
||||||
if res.webpages:
|
|
||||||
if not result.online_results.get(iteration.query):
|
|
||||||
result.online_results[iteration.query] = {"webpages": res.webpages}
|
|
||||||
elif not result.online_results[iteration.query].get("webpages"):
|
|
||||||
result.online_results[iteration.query]["webpages"] = res.webpages
|
|
||||||
else:
|
|
||||||
result.online_results[iteration.query]["webpages"] += res.webpages
|
|
||||||
iteration.onlineContext = result.online_results
|
|
||||||
|
|
||||||
elif iteration.query.name == ConversationCommand.ViewFile:
|
elif iteration.query.name == ConversationCommand.ViewFile:
|
||||||
async for res in view_file_content(
|
async for res in view_file_content(
|
||||||
**iteration.query.args,
|
**iteration.query.args,
|
||||||
user=user,
|
user=user,
|
||||||
):
|
):
|
||||||
# Status messages are collected by status_collector, skip ChatEvent.STATUS here
|
if res and isinstance(res, list):
|
||||||
if not (isinstance(res, dict) and ChatEvent.STATUS in res):
|
|
||||||
if iteration.context is None:
|
if iteration.context is None:
|
||||||
iteration.context = []
|
iteration.context = []
|
||||||
result.document_results = res
|
result.document_results = res
|
||||||
@@ -586,33 +557,95 @@ async def research(
|
|||||||
iterations_to_process.append(result)
|
iterations_to_process.append(result)
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
# Multi-turn tools that stream their execution
|
||||||
|
streaming_tools = {ConversationCommand.OperateComputer}
|
||||||
if iterations_to_process:
|
if iterations_to_process:
|
||||||
# Create tasks for parallel execution
|
# Separate streaming tools that need real-time status updates
|
||||||
tasks = [
|
# from parallelizable tools that can batch their status messages
|
||||||
execute_tool(
|
streaming_iterations: list[ResearchIteration] = []
|
||||||
iteration=iteration,
|
parallel_iterations: list[ResearchIteration] = []
|
||||||
user=user,
|
for iteration in iterations_to_process:
|
||||||
conversation_id=conversation_id,
|
if isinstance(iteration.query, ToolCall) and iteration.query.name in streaming_tools:
|
||||||
previous_iterations=previous_iterations,
|
streaming_iterations.append(iteration)
|
||||||
location=location,
|
else:
|
||||||
query_images=query_images,
|
parallel_iterations.append(iteration)
|
||||||
query_files=query_files,
|
|
||||||
max_document_searches=max_document_searches,
|
|
||||||
max_online_searches=max_online_searches,
|
|
||||||
mcp_clients=mcp_clients,
|
|
||||||
cancellation_event=cancellation_event,
|
|
||||||
interrupt_queue=interrupt_queue,
|
|
||||||
agent=agent,
|
|
||||||
tracer=tracer,
|
|
||||||
)
|
|
||||||
for iteration in iterations_to_process
|
|
||||||
]
|
|
||||||
|
|
||||||
# Execute all tools in parallel
|
# Execute streaming tools sequentially for real-time status updates
|
||||||
tool_results = await asyncio.gather(*tasks, return_exceptions=True)
|
streaming_results: list[tuple[ResearchIteration, ToolExecutionResult]] = []
|
||||||
|
for iteration in streaming_iterations:
|
||||||
|
result = ToolExecutionResult()
|
||||||
|
if (
|
||||||
|
isinstance(iteration.query, ToolCall)
|
||||||
|
and iteration.query.name == ConversationCommand.OperateComputer
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
# Execute OperateComputer
|
||||||
|
async for res in operate_environment(
|
||||||
|
**iteration.query.args,
|
||||||
|
user=user,
|
||||||
|
conversation_log=construct_tool_chat_history(
|
||||||
|
previous_iterations, ConversationCommand.Operator
|
||||||
|
),
|
||||||
|
location_data=location,
|
||||||
|
previous_trajectory=previous_iterations[-1].operatorContext
|
||||||
|
if previous_iterations
|
||||||
|
else None,
|
||||||
|
send_status_func=send_status_func,
|
||||||
|
query_images=query_images,
|
||||||
|
agent=agent,
|
||||||
|
query_files=query_files,
|
||||||
|
cancellation_event=cancellation_event,
|
||||||
|
interrupt_queue=interrupt_queue,
|
||||||
|
tracer=tracer,
|
||||||
|
):
|
||||||
|
if isinstance(res, dict) and ChatEvent.STATUS in res:
|
||||||
|
yield res[ChatEvent.STATUS]
|
||||||
|
elif isinstance(res, OperatorRun):
|
||||||
|
result.operator_results = res
|
||||||
|
iteration.operatorContext = result.operator_results
|
||||||
|
if res.webpages:
|
||||||
|
if not result.online_results.get(iteration.query):
|
||||||
|
result.online_results[iteration.query] = {"webpages": res.webpages}
|
||||||
|
elif not result.online_results[iteration.query].get("webpages"):
|
||||||
|
result.online_results[iteration.query]["webpages"] = res.webpages
|
||||||
|
else:
|
||||||
|
result.online_results[iteration.query]["webpages"] += res.webpages
|
||||||
|
iteration.onlineContext = result.online_results
|
||||||
|
except Exception as e:
|
||||||
|
iteration.warning = f"Error operating browser: {e}"
|
||||||
|
logger.error(iteration.warning, exc_info=True)
|
||||||
|
streaming_results.append((iteration, result))
|
||||||
|
|
||||||
|
# Execute parallelizable tools in parallel
|
||||||
|
parallel_results = []
|
||||||
|
if parallel_iterations:
|
||||||
|
tasks = [
|
||||||
|
execute_tool(
|
||||||
|
iteration=iteration,
|
||||||
|
user=user,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
previous_iterations=previous_iterations,
|
||||||
|
location=location,
|
||||||
|
query_images=query_images,
|
||||||
|
query_files=query_files,
|
||||||
|
max_document_searches=max_document_searches,
|
||||||
|
max_online_searches=max_online_searches,
|
||||||
|
mcp_clients=mcp_clients,
|
||||||
|
cancellation_event=cancellation_event,
|
||||||
|
interrupt_queue=interrupt_queue,
|
||||||
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
|
)
|
||||||
|
for iteration in parallel_iterations
|
||||||
|
]
|
||||||
|
tool_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
parallel_results = list(zip(parallel_iterations, tool_results))
|
||||||
|
|
||||||
|
# Combine results (streaming first, then parallel)
|
||||||
|
all_results = streaming_results + parallel_results
|
||||||
|
|
||||||
# Process results and yield status messages
|
# Process results and yield status messages
|
||||||
for this_iteration, tool_result in zip(iterations_to_process, tool_results):
|
for this_iteration, tool_result in all_results:
|
||||||
# Handle exceptions from asyncio.gather
|
# Handle exceptions from asyncio.gather
|
||||||
if isinstance(tool_result, Exception):
|
if isinstance(tool_result, Exception):
|
||||||
this_iteration.warning = f"Error executing tool: {tool_result}"
|
this_iteration.warning = f"Error executing tool: {tool_result}"
|
||||||
|
|||||||
Reference in New Issue
Block a user