Continue interrupted operator run with new query and previous context

Track research and operator results at each nested iteration step
using python object references + async events bubbled up from nested
iterators.

Instantiates operator with interrupted operator messages from research
or normal mode.

Reflects the actual interaction trajectory as closely as possible to the agent —
including conversation history, partial operator trajectory and the new
query — for fine-grained, corrigible steerability.

Research mode continues directly with the operator tool if the previous
iteration was an interrupted operator run.
This commit is contained in:
Debanjum
2025-05-29 20:49:55 -07:00
parent de35d91e1d
commit 21bf7f1d6d
6 changed files with 94 additions and 30 deletions

View File

@@ -261,6 +261,27 @@ def construct_question_history(
return history_parts
def construct_chat_history_for_operator(conversation_history: dict, n: int = 6) -> list[AgentMessage]:
    """
    Construct chat history for the operator agent from the conversation log.

    Only include the last n completed turns (i.e. turns with both a user and
    a khoj message). A user message includes any attached query files as
    additional text content blocks.

    Args:
        conversation_history: Conversation log holding a "chat" list of messages,
            each with a "by" author ("you" or "khoj") and a "message" body.
        n: Maximum number of completed turns to include.

    Returns:
        Flattened, oldest-first list of alternating user/assistant AgentMessages.
    """
    # Collect completed (user, assistant) turns first so we can keep the LAST n,
    # as the docstring promises, instead of the first few messages.
    turns: list = []
    pending_user_message: Optional[AgentMessage] = None
    for chat in conversation_history.get("chat", []):
        if chat["by"] == "you" and chat.get("message"):
            # User query text plus any attached files as separate content blocks
            content = [{"type": "text", "text": chat["message"]}]
            for file in chat.get("queryFiles", []):
                content += [{"type": "text", "text": f'## File: {file["name"]}\n\n{file["content"]}'}]
            pending_user_message = AgentMessage(role="user", content=content)
        elif chat["by"] == "khoj" and chat.get("message"):
            # Pair only with a fresh preceding user message. This guards against
            # appending None (khoj message with no user message) and against
            # re-pairing a stale user message with multiple khoj responses.
            if pending_user_message is not None:
                turns.append((pending_user_message, AgentMessage(role="assistant", content=chat["message"])))
                pending_user_message = None
    # Keep only the last n completed turns, flattened into a message list
    chat_history: list[AgentMessage] = []
    for user_message, assistant_message in turns[-n:]:
        chat_history += [user_message, assistant_message]
    return chat_history
def construct_tool_chat_history(
previous_iterations: List[ResearchIteration], tool: ConversationCommand = None
) -> Dict[str, list]:
@@ -285,9 +306,6 @@ def construct_tool_chat_history(
ConversationCommand.Code: (
lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
),
ConversationCommand.Operator: (
lambda iteration: list(iteration.operatorContext.query) if iteration.operatorContext else []
),
}
for iteration in previous_iterations:
# If a tool is provided use the inferred query extractor for that tool if available

View File

@@ -6,7 +6,11 @@ from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters
from khoj.database.models import Agent, ChatModel, KhojUser
from khoj.processor.conversation.utils import OperatorRun, construct_chat_history
from khoj.processor.conversation.utils import (
OperatorRun,
construct_chat_history,
construct_chat_history_for_operator,
)
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
from khoj.processor.operator.operator_agent_base import OperatorAgent
@@ -43,6 +47,10 @@ async def operate_environment(
):
response, user_input_message = None, None
# Only use partial previous trajectories to continue existing task
if previous_trajectory and previous_trajectory.response:
previous_trajectory = None
# Get the agent chat model
agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None
reasoning_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
@@ -51,21 +59,19 @@ async def operate_environment(
if not reasoning_model:
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
chat_history = construct_chat_history(conversation_log)
query_with_history = (
f"## Chat History\n{chat_history}\n\n## User Query\n{query}" if chat_history else query
) # Append chat history to query if available
# Create conversation history from conversation log
chat_history = construct_chat_history_for_operator(conversation_log)
# Initialize Agent
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100))
operator_agent: OperatorAgent
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
operator_agent = OpenAIOperatorAgent(
query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer
query, reasoning_model, environment_type, max_iterations, chat_history, previous_trajectory, tracer
)
elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
operator_agent = AnthropicOperatorAgent(
query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer
query, reasoning_model, environment_type, max_iterations, chat_history, previous_trajectory, tracer
)
else:
grounding_model_name = "ui-tars-1.5"
@@ -77,11 +83,12 @@ async def operate_environment(
):
raise ValueError("No supported visual grounding model for binary operator agent found.")
operator_agent = BinaryOperatorAgent(
query_with_history,
query,
reasoning_model,
grounding_model,
environment_type,
max_iterations,
chat_history,
previous_trajectory,
tracer,
)
@@ -100,6 +107,8 @@ async def operate_environment(
try:
task_completed = False
iterations = 0
operator_run = OperatorRun(query=query, trajectory=operator_agent.messages, response=response)
yield operator_run
with timer(
f"Operating {environment_type.value} with {reasoning_model.model_type} {reasoning_model.name}", logger
@@ -159,30 +168,27 @@ async def operate_environment(
# 4. Update agent on the results of its action on the environment
operator_agent.add_action_results(env_steps, agent_result)
operator_run.trajectory = operator_agent.messages
# Determine final response message
if user_input_message:
response = user_input_message
operator_run.response = user_input_message
elif task_completed:
response = summary_message
operator_run.response = summary_message
elif cancellation_event and cancellation_event.is_set():
operator_run.response = None
else: # Hit iteration limit
response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
operator_run.response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
finally:
if environment and not user_input_message: # Don't close environment if user input required
await environment.close()
if operator_agent:
operator_agent.reset()
webpages = []
if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"):
webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls]
operator_run.webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls]
yield OperatorRun(
query=query,
trajectory=operator_agent.messages,
response=response,
webpages=webpages,
)
yield operator_run
def is_operator_model(model: str) -> ChatModel.ModelType | None:

View File

@@ -34,6 +34,7 @@ class OperatorAgent(ABC):
vision_model: ChatModel,
environment_type: EnvironmentType,
max_iterations: int,
chat_history: List[AgentMessage] = [],
previous_trajectory: Optional[OperatorRun] = None,
tracer: dict = {},
):
@@ -42,11 +43,15 @@ class OperatorAgent(ABC):
self.environment_type = environment_type
self.max_iterations = max_iterations
self.tracer = tracer
self.messages: List[AgentMessage] = []
self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
self.messages: List[AgentMessage] = chat_history
if previous_trajectory:
self.messages = previous_trajectory.trajectory
# Remove tool call from previous trajectory as tool call w/o result not supported
if previous_trajectory.trajectory and previous_trajectory.trajectory[-1].role == "assistant":
previous_trajectory.trajectory.pop()
self.messages += previous_trajectory.trajectory
self.messages += [AgentMessage(role="user", content=query)]
# Context compression parameters
self.context_compress_trigger = 2e3 # heuristic to determine compression trigger

View File

@@ -40,6 +40,7 @@ class BinaryOperatorAgent(OperatorAgent):
grounding_model: ChatModel,
environment_type: EnvironmentType,
max_iterations: int,
chat_history: List[AgentMessage] = [],
previous_trajectory: Optional[OperatorRun] = None,
tracer: dict = {},
):
@@ -48,6 +49,7 @@ class BinaryOperatorAgent(OperatorAgent):
reasoning_model,
environment_type,
max_iterations,
chat_history,
previous_trajectory,
tracer,
) # Use reasoning model for primary tracking

View File

@@ -1011,7 +1011,7 @@ async def chat(
query=defiltered_query,
conversation_id=conversation_id,
conversation_history=meta_log,
previous_iterations=research_results,
previous_iterations=list(research_results),
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
@@ -1030,13 +1030,26 @@ async def chat(
code_results.update(research_result.codeContext)
if research_result.context:
compiled_references.extend(research_result.context)
if research_result.operatorContext:
operator_results.append(research_result.operatorContext)
if not research_results or research_results[-1] is not research_result:
research_results.append(research_result)
else:
yield research_result
# Track operator results across research and operator iterations
# This relies on two conditions:
# 1. Check to append new (partial) operator results
# Relies on triggering this check on every status updates.
# Status updates cascade up from operator to research to chat api on every step.
# 2. Keep operator results in sync with each research operator step
# Relies on python object references to ensure operator results
# are implicitly kept in sync after the initial append
if (
research_results
and research_results[-1].operatorContext
and (not operator_results or operator_results[-1] is not research_results[-1].operatorContext)
):
operator_results.append(research_results[-1].operatorContext)
# researched_results = await extract_relevant_info(q, researched_results, agent)
if state.verbose > 1:
logger.debug(f'Researched Results: {"".join(r.summarizedResult for r in research_results)}')
@@ -1294,7 +1307,7 @@ async def chat(
user,
meta_log,
location,
operator_results[-1] if operator_results else None,
list(operator_results)[-1] if operator_results else None,
query_images=uploaded_images,
query_files=attached_file_context,
send_status_func=partial(send_event, ChatEvent.STATUS),
@@ -1305,7 +1318,8 @@ async def chat(
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif isinstance(result, OperatorRun):
operator_results.append(result)
if not operator_results or operator_results[-1] is not result:
operator_results.append(result)
# Add webpages visited while operating browser to references
if result.webpages:
if not online_results.get(defiltered_query):

View File

@@ -96,6 +96,24 @@ async def apick_next_tool(
):
"""Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
# Continue with previous iteration if a multi-step tool use is in progress
if (
previous_iterations
and previous_iterations[-1].tool == ConversationCommand.Operator
and not previous_iterations[-1].summarizedResult
):
previous_iteration = previous_iterations[-1]
yield ResearchIteration(
tool=previous_iteration.tool,
query=query,
context=previous_iteration.context,
onlineContext=previous_iteration.onlineContext,
codeContext=previous_iteration.codeContext,
operatorContext=previous_iteration.operatorContext,
warning=previous_iteration.warning,
)
return
# Construct tool options for the agent to choose from
tool_options = dict()
tool_options_str = ""
@@ -274,6 +292,7 @@ async def research(
yield result[ChatEvent.STATUS]
elif isinstance(result, ResearchIteration):
this_iteration = result
yield this_iteration
# Skip running iteration if warning present in iteration
if this_iteration.warning: