From 21bf7f1d6df6ad9c2d9850142053c5a43d861a4e Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 29 May 2025 20:49:55 -0700 Subject: [PATCH] Continue interrupted operator run with new query and previous context Track research and operator results at each nested iteration step using python object references + async events bubbled up from nested iterators. Instantiates operator with interrupted operator messages from research or normal mode. Reflects actual interaction trajectory as closely as possible to agent including conversation history, partial operator trajectory and new query for fine grained, corrigible steerability. Research mode continues with operator tool directly if previous iteration was an interrupted operator run. --- src/khoj/processor/conversation/utils.py | 24 ++++++++-- src/khoj/processor/operator/__init__.py | 44 +++++++++++-------- .../processor/operator/operator_agent_base.py | 9 +++- .../operator/operator_agent_binary.py | 2 + src/khoj/routers/api_chat.py | 26 ++++++++--- src/khoj/routers/research.py | 19 ++++++++ 6 files changed, 94 insertions(+), 30 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index c6de9557..52acfbfd 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -261,6 +261,27 @@ def construct_question_history( return history_parts +def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) -> list[AgentMessage]: + """ + Construct chat history for operator agent in conversation log. + Only include last n completed turns (i.e with user and khoj message). + """ + chat_history: list[AgentMessage] = [] + user_message: Optional[AgentMessage] = None + + for chat in conversation_history.get("chat", []): + if len(chat_history) >= n: + break + if chat["by"] == "you" and chat.get("message"): + content = [{"type": "text", "text": chat["message"]}] + for file in chat.get("queryFiles", []): + content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}] + user_message = AgentMessage(role="user", content=content) + elif chat["by"] == "khoj" and chat.get("message"): + chat_history += [user_message, AgentMessage(role="assistant", content=chat["message"])] + return chat_history + + def construct_tool_chat_history( previous_iterations: List[ResearchIteration], tool: ConversationCommand = None ) -> Dict[str, list]: @@ -285,9 +306,6 @@ def construct_tool_chat_history( ConversationCommand.Code: ( lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else [] ), - ConversationCommand.Operator: ( - lambda iteration: list(iteration.operatorContext.query) if iteration.operatorContext else [] - ), } for iteration in previous_iterations: # If a tool is provided use the inferred query extractor for that tool if available diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index cb606c67..34f4ad7d 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -6,7 +6,11 @@ from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters from khoj.database.models import Agent, ChatModel, KhojUser -from khoj.processor.conversation.utils import OperatorRun, construct_chat_history +from khoj.processor.conversation.utils import ( + OperatorRun, + construct_chat_history, + construct_chat_history_for_operator, +) from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent from khoj.processor.operator.operator_agent_base import OperatorAgent @@ -43,6 +47,10 @@ async def operate_environment( ): response, user_input_message = None, None + # Only use partial previous trajectories to continue existing task + if previous_trajectory and previous_trajectory.response: + previous_trajectory = None + # Get the agent chat model agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None reasoning_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model) @@ -51,21 +59,19 @@ async def operate_environment( if not reasoning_model: raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.") - chat_history = construct_chat_history(conversation_log) - query_with_history = ( - f"## Chat History\n{chat_history}\n\n## User Query\n{query}" if chat_history else query - ) # Append chat history to query if available + # Create conversation history from conversation log + chat_history = construct_chat_history_for_operator(conversation_log) # Initialize Agent max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100)) operator_agent: OperatorAgent if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI: operator_agent = OpenAIOperatorAgent( - query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer + query, reasoning_model, environment_type, max_iterations, chat_history, previous_trajectory, tracer ) elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC: operator_agent = AnthropicOperatorAgent( - query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer + query, reasoning_model, environment_type, max_iterations, chat_history, previous_trajectory, tracer ) else: grounding_model_name = "ui-tars-1.5" @@ -77,11 +83,12 @@ async def operate_environment( ): raise ValueError("No supported visual grounding model for binary operator agent found.") operator_agent = BinaryOperatorAgent( - query_with_history, + query, reasoning_model, grounding_model, environment_type, max_iterations, + chat_history, previous_trajectory, tracer, ) @@ -100,6 +107,8 @@ async def operate_environment( try: task_completed = False iterations = 0 + operator_run = OperatorRun(query=query, trajectory=operator_agent.messages, response=response) + yield operator_run with timer( f"Operating {environment_type.value} with {reasoning_model.model_type} {reasoning_model.name}", logger @@ -159,30 +168,27 @@ async def operate_environment( # 4. Update agent on the results of its action on the environment operator_agent.add_action_results(env_steps, agent_result) + operator_run.trajectory = operator_agent.messages # Determine final response message if user_input_message: - response = user_input_message + operator_run.response = user_input_message elif task_completed: - response = summary_message + operator_run.response = summary_message + elif cancellation_event and cancellation_event.is_set(): + operator_run.response = None else: # Hit iteration limit - response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}" + operator_run.response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}" finally: if environment and not user_input_message: # Don't close environment if user input required await environment.close() if operator_agent: operator_agent.reset() - webpages = [] if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"): - webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls] + operator_run.webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls] - yield OperatorRun( - query=query, - trajectory=operator_agent.messages, - response=response, - webpages=webpages, - ) + yield operator_run def is_operator_model(model: str) -> ChatModel.ModelType | None: diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py index 23b99362..8686b708 100644 --- a/src/khoj/processor/operator/operator_agent_base.py +++ b/src/khoj/processor/operator/operator_agent_base.py @@ -34,6 +34,7 @@ class OperatorAgent(ABC): vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, + chat_history: List[AgentMessage] = [], previous_trajectory: Optional[OperatorRun] = None, tracer: dict = {}, ): @@ -42,11 +43,15 @@ class OperatorAgent(ABC): self.environment_type = environment_type self.max_iterations = max_iterations self.tracer = tracer - self.messages: List[AgentMessage] = [] self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}." + self.messages: List[AgentMessage] = chat_history if previous_trajectory: - self.messages = previous_trajectory.trajectory + # Remove tool call from previous trajectory as tool call w/o result not supported + if previous_trajectory.trajectory and previous_trajectory.trajectory[-1].role == "assistant": + previous_trajectory.trajectory.pop() + self.messages += previous_trajectory.trajectory + self.messages += [AgentMessage(role="user", content=query)] # Context compression parameters self.context_compress_trigger = 2e3 # heuristic to determine compression trigger diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index 2a058c42..ade98176 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -40,6 +40,7 @@ class BinaryOperatorAgent(OperatorAgent): grounding_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, + chat_history: List[AgentMessage] = [], previous_trajectory: Optional[OperatorRun] = None, tracer: dict = {}, ): @@ -48,6 +49,7 @@ class BinaryOperatorAgent(OperatorAgent): reasoning_model, environment_type, max_iterations, + chat_history, previous_trajectory, tracer, ) # Use reasoning model for primary tracking diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 23571738..a7bb513f 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1011,7 +1011,7 @@ async def chat( query=defiltered_query, conversation_id=conversation_id, conversation_history=meta_log, - previous_iterations=research_results, + previous_iterations=list(research_results), query_images=uploaded_images, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), @@ -1030,13 +1030,26 @@ async def chat( code_results.update(research_result.codeContext) if research_result.context: compiled_references.extend(research_result.context) - if research_result.operatorContext: - operator_results.append(research_result.operatorContext) + if not research_results or research_results[-1] is not research_result: research_results.append(research_result) - else: yield research_result + # Track operator results across research and operator iterations + # This relies on two conditions: + # 1. Check to append new (partial) operator results + # Relies on triggering this check on every status updates. + # Status updates cascade up from operator to research to chat api on every step. + # 2. Keep operator results in sync with each research operator step + # Relies on python object references to ensure operator results + # are implicitly kept in sync after the initial append + if ( + research_results + and research_results[-1].operatorContext + and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext) + ): + operator_results.append(research_results[-1].operatorContext) + # researched_results = await extract_relevant_info(q, researched_results, agent) if state.verbose > 1: logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}') @@ -1294,7 +1307,7 @@ async def chat( user, meta_log, location, - operator_results[-1] if operator_results else None, + list(operator_results)[-1] if operator_results else None, query_images=uploaded_images, query_files=attached_file_context, send_status_func=partial(send_event, ChatEvent.STATUS), @@ -1305,7 +1318,8 @@ async def chat( if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] elif isinstance(result, OperatorRun): - operator_results.append(result) + if not operator_results or operator_results[-1] is not result: + operator_results.append(result) # Add webpages visited while operating browser to references if result.webpages: if not online_results.get(defiltered_query): diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 902c2e0c..83ec141a 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -96,6 +96,24 @@ async def apick_next_tool( ): """Given a query, determine which of the available tools the agent should use in order to answer appropriately.""" + # Continue with previous iteration if a multi-step tool use is in progress + if ( + previous_iterations + and previous_iterations[-1].tool == ConversationCommand.Operator + and not previous_iterations[-1].summarizedResult + ): + previous_iteration = previous_iterations[-1] + yield ResearchIteration( + tool=previous_iteration.tool, + query=query, + context=previous_iteration.context, + onlineContext=previous_iteration.onlineContext, + codeContext=previous_iteration.codeContext, + operatorContext=previous_iteration.operatorContext, + warning=previous_iteration.warning, + ) + return + # Construct tool options for the agent to choose from tool_options = dict() tool_options_str = "" @@ -274,6 +292,7 @@ async def research( yield result[ChatEvent.STATUS] elif isinstance(result, ResearchIteration): this_iteration = result + yield this_iteration # Skip running iteration if warning present in iteration if this_iteration.warning: