diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index e5d2de57..1c9f9b33 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -3057,7 +3057,7 @@ async def view_file_content(
     start_line: Optional[int] = None,
     end_line: Optional[int] = None,
     user: KhojUser = None,
-):
+) -> AsyncGenerator[List[Dict[str, str]], None]:
     """
     View the contents of a file from the user's document database with optional line range specification.
     """
diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py
index cef6f49d..e103b330 100644
--- a/src/khoj/routers/research.py
+++ b/src/khoj/routers/research.py
@@ -207,41 +207,12 @@ async def execute_tool(
             async for _ in status_collector(f"**Ran code snippets**: {len(iteration.codeContext)}"):
                 pass
 
-        elif iteration.query.name == ConversationCommand.OperateComputer:
-            async for res in operate_environment(
-                **iteration.query.args,
-                user=user,
-                conversation_log=construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
-                location_data=location,
-                previous_trajectory=previous_iterations[-1].operatorContext if previous_iterations else None,
-                send_status_func=status_collector,
-                query_images=query_images,
-                agent=agent,
-                query_files=query_files,
-                cancellation_event=cancellation_event,
-                interrupt_queue=interrupt_queue,
-                tracer=tracer,
-            ):
-                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
-                if isinstance(res, OperatorRun):
-                    result.operator_results = res
-                    iteration.operatorContext = result.operator_results
-                    if res.webpages:
-                        if not result.online_results.get(iteration.query):
-                            result.online_results[iteration.query] = {"webpages": res.webpages}
-                        elif not result.online_results[iteration.query].get("webpages"):
-                            result.online_results[iteration.query]["webpages"] = res.webpages
-                        else:
-                            result.online_results[iteration.query]["webpages"] += res.webpages
-                    iteration.onlineContext = result.online_results
-
         elif iteration.query.name == ConversationCommand.ViewFile:
             async for res in view_file_content(
                 **iteration.query.args,
                 user=user,
             ):
-                # Status messages are collected by status_collector, skip ChatEvent.STATUS here
-                if not (isinstance(res, dict) and ChatEvent.STATUS in res):
+                if res and isinstance(res, list):
                     if iteration.context is None:
                         iteration.context = []
                     result.document_results = res
@@ -586,33 +557,95 @@ async def research(
             iterations_to_process.append(result)
             yield result
 
+        # Multi-turn tools that stream their execution
+        streaming_tools = {ConversationCommand.OperateComputer}
         if iterations_to_process:
-            # Create tasks for parallel execution
-            tasks = [
-                execute_tool(
-                    iteration=iteration,
-                    user=user,
-                    conversation_id=conversation_id,
-                    previous_iterations=previous_iterations,
-                    location=location,
-                    query_images=query_images,
-                    query_files=query_files,
-                    max_document_searches=max_document_searches,
-                    max_online_searches=max_online_searches,
-                    mcp_clients=mcp_clients,
-                    cancellation_event=cancellation_event,
-                    interrupt_queue=interrupt_queue,
-                    agent=agent,
-                    tracer=tracer,
-                )
-                for iteration in iterations_to_process
-            ]
+            # Separate streaming tools that need real-time status updates
+            # from parallelizable tools that can batch their status messages
+            streaming_iterations: list[ResearchIteration] = []
+            parallel_iterations: list[ResearchIteration] = []
+            for iteration in iterations_to_process:
+                if isinstance(iteration.query, ToolCall) and iteration.query.name in streaming_tools:
+                    streaming_iterations.append(iteration)
+                else:
+                    parallel_iterations.append(iteration)
 
-            # Execute all tools in parallel
-            tool_results = await asyncio.gather(*tasks, return_exceptions=True)
+            # Execute streaming tools sequentially for real-time status updates
+            streaming_results: list[tuple[ResearchIteration, ToolExecutionResult]] = []
+            for iteration in streaming_iterations:
+                result = ToolExecutionResult()
+                if (
+                    isinstance(iteration.query, ToolCall)
+                    and iteration.query.name == ConversationCommand.OperateComputer
+                ):
+                    try:
+                        # Execute OperateComputer
+                        async for res in operate_environment(
+                            **iteration.query.args,
+                            user=user,
+                            conversation_log=construct_tool_chat_history(
+                                previous_iterations, ConversationCommand.Operator
+                            ),
+                            location_data=location,
+                            previous_trajectory=previous_iterations[-1].operatorContext
+                            if previous_iterations
+                            else None,
+                            send_status_func=send_status_func,
+                            query_images=query_images,
+                            agent=agent,
+                            query_files=query_files,
+                            cancellation_event=cancellation_event,
+                            interrupt_queue=interrupt_queue,
+                            tracer=tracer,
+                        ):
+                            if isinstance(res, dict) and ChatEvent.STATUS in res:
+                                yield res[ChatEvent.STATUS]
+                            elif isinstance(res, OperatorRun):
+                                result.operator_results = res
+                                iteration.operatorContext = result.operator_results
+                                if res.webpages:
+                                    if not result.online_results.get(iteration.query):
+                                        result.online_results[iteration.query] = {"webpages": res.webpages}
+                                    elif not result.online_results[iteration.query].get("webpages"):
+                                        result.online_results[iteration.query]["webpages"] = res.webpages
+                                    else:
+                                        result.online_results[iteration.query]["webpages"] += res.webpages
+                                iteration.onlineContext = result.online_results
+                    except Exception as e:
+                        iteration.warning = f"Error operating browser: {e}"
+                        logger.error(iteration.warning, exc_info=True)
+                streaming_results.append((iteration, result))
+
+            # Execute parallelizable tools in parallel
+            parallel_results = []
+            if parallel_iterations:
+                tasks = [
+                    execute_tool(
+                        iteration=iteration,
+                        user=user,
+                        conversation_id=conversation_id,
+                        previous_iterations=previous_iterations,
+                        location=location,
+                        query_images=query_images,
+                        query_files=query_files,
+                        max_document_searches=max_document_searches,
+                        max_online_searches=max_online_searches,
+                        mcp_clients=mcp_clients,
+                        cancellation_event=cancellation_event,
+                        interrupt_queue=interrupt_queue,
+                        agent=agent,
+                        tracer=tracer,
+                    )
+                    for iteration in parallel_iterations
+                ]
+                tool_results = await asyncio.gather(*tasks, return_exceptions=True)
+                parallel_results = list(zip(parallel_iterations, tool_results))
+
+            # Combine results (streaming first, then parallel)
+            all_results = streaming_results + parallel_results
 
             # Process results and yield status messages
-            for this_iteration, tool_result in zip(iterations_to_process, tool_results):
+            for this_iteration, tool_result in all_results:
                 # Handle exceptions from asyncio.gather
                 if isinstance(tool_result, Exception):
                     this_iteration.warning = f"Error executing tool: {tool_result}"
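
For reference, below is a minimal standalone sketch of the scheduling pattern this diff introduces in `research()`: tool calls are split into streaming tools, which run sequentially so their status events can be yielded as they arrive, and parallelizable tools, which are batched with `asyncio.gather(..., return_exceptions=True)`; the combined results are then processed uniformly, exceptions included. The names `ToolCall`, `run_streaming_tool`, `run_batch_tool`, and `dispatch` are illustrative stand-ins, not the actual Khoj helpers.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, AsyncGenerator


@dataclass
class ToolCall:
    # Illustrative stand-in for a picked tool invocation (hypothetical type).
    name: str
    args: dict


async def run_streaming_tool(call: ToolCall) -> AsyncGenerator[str, None]:
    # A streaming tool emits intermediate status messages before it finishes.
    for step in range(2):
        await asyncio.sleep(0)
        yield f"{call.name}: step {step + 1}"


async def run_batch_tool(call: ToolCall) -> str:
    # A parallelizable tool just returns its result when done.
    await asyncio.sleep(0)
    return f"{call.name} done"


STREAMING_TOOLS = {"operate_computer"}


async def dispatch(calls: list[ToolCall]) -> AsyncGenerator[Any, None]:
    # Split calls into streaming tools (need real-time status) and the rest.
    streaming = [c for c in calls if c.name in STREAMING_TOOLS]
    parallel = [c for c in calls if c.name not in STREAMING_TOOLS]

    results: list[tuple[ToolCall, Any]] = []

    # Streaming tools run sequentially so their status can be yielded live.
    for call in streaming:
        async for status in run_streaming_tool(call):
            yield status
        results.append((call, f"{call.name} finished"))

    # Everything else runs concurrently; exceptions are captured per task
    # instead of cancelling the whole batch.
    batch_results = await asyncio.gather(
        *(run_batch_tool(c) for c in parallel), return_exceptions=True
    )
    results.extend(zip(parallel, batch_results))

    # Combined results (streaming first, then parallel) are processed uniformly.
    for call, outcome in results:
        if isinstance(outcome, Exception):
            yield f"{call.name} failed: {outcome}"
        else:
            yield outcome


async def main():
    calls = [
        ToolCall("operate_computer", {}),
        ToolCall("search_web", {}),
        ToolCall("view_file", {}),
    ]
    async for event in dispatch(calls):
        print(event)


asyncio.run(main())
```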