Pass previous trajectory to operator agents for context

This commit is contained in:
Debanjum
2025-05-29 18:31:01 -07:00
parent 864e0ac8b5
commit de35d91e1d
5 changed files with 51 additions and 10 deletions

View File

@@ -6,7 +6,7 @@ from typing import Callable, List, Optional
 from khoj.database.adapters import AgentAdapters, ConversationAdapters
 from khoj.database.models import Agent, ChatModel, KhojUser
-from khoj.processor.conversation.utils import OperatorRun
+from khoj.processor.conversation.utils import OperatorRun, construct_chat_history
 from khoj.processor.operator.operator_actions import *
 from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
 from khoj.processor.operator.operator_agent_base import OperatorAgent
@@ -32,6 +32,7 @@ async def operate_environment(
     user: KhojUser,
     conversation_log: dict,
     location_data: LocationData,
+    previous_trajectory: Optional[OperatorRun] = None,
     environment_type: EnvironmentType = EnvironmentType.COMPUTER,
     send_status_func: Optional[Callable] = None,
     query_images: Optional[List[str]] = None,  # TODO: Handle query images
@@ -50,13 +51,22 @@ async def operate_environment(
     if not reasoning_model:
         raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
+    chat_history = construct_chat_history(conversation_log)
+    query_with_history = (
+        f"## Chat History\n{chat_history}\n\n## User Query\n{query}" if chat_history else query
+    )  # Append chat history to query if available

     # Initialize Agent
     max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100))
     operator_agent: OperatorAgent
     if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
-        operator_agent = OpenAIOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
+        operator_agent = OpenAIOperatorAgent(
+            query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer
+        )
     elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
-        operator_agent = AnthropicOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
+        operator_agent = AnthropicOperatorAgent(
+            query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer
+        )
     else:
         grounding_model_name = "ui-tars-1.5"
         grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
@@ -67,7 +77,13 @@ async def operate_environment(
         ):
             raise ValueError("No supported visual grounding model for binary operator agent found.")
         operator_agent = BinaryOperatorAgent(
-            query, reasoning_model, grounding_model, environment_type, max_iterations, tracer
+            query_with_history,
+            reasoning_model,
+            grounding_model,
+            environment_type,
+            max_iterations,
+            previous_trajectory,
+            tracer,
         )
# Initialize Environment

View File

@@ -5,7 +5,11 @@ from typing import List, Literal, Optional, Union
 from pydantic import BaseModel

 from khoj.database.models import ChatModel
-from khoj.processor.conversation.utils import AgentMessage, commit_conversation_trace
+from khoj.processor.conversation.utils import (
+    AgentMessage,
+    OperatorRun,
+    commit_conversation_trace,
+)
 from khoj.processor.operator.operator_actions import OperatorAction
 from khoj.processor.operator.operator_environment_base import (
     EnvironmentType,
@@ -25,7 +29,13 @@ class AgentActResult(BaseModel):
 class OperatorAgent(ABC):
     def __init__(
-        self, query: str, vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, tracer: dict
+        self,
+        query: str,
+        vision_model: ChatModel,
+        environment_type: EnvironmentType,
+        max_iterations: int,
+        previous_trajectory: Optional[OperatorRun] = None,
+        tracer: dict = {},
     ):
         self.query = query
         self.vision_model = vision_model
@@ -35,6 +45,9 @@ class OperatorAgent(ABC):
         self.messages: List[AgentMessage] = []
         self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
+        if previous_trajectory:
+            self.messages = previous_trajectory.trajectory
+
         # Context compression parameters
         self.context_compress_trigger = 2e3  # heuristic to determine compression trigger
         # turns after which compression triggered. scales with model max context size. Minimum 5 turns.

View File

@@ -2,10 +2,14 @@ import json
 import logging
 from datetime import datetime
 from textwrap import dedent
-from typing import List
+from typing import List, Optional

 from khoj.database.models import ChatModel
-from khoj.processor.conversation.utils import AgentMessage, construct_structured_message
+from khoj.processor.conversation.utils import (
+    AgentMessage,
+    OperatorRun,
+    construct_structured_message,
+)
 from khoj.processor.operator.grounding_agent import GroundingAgent
 from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars
 from khoj.processor.operator.operator_actions import *
@@ -36,10 +40,16 @@ class BinaryOperatorAgent(OperatorAgent):
         grounding_model: ChatModel,
         environment_type: EnvironmentType,
         max_iterations: int,
-        tracer: dict,
+        previous_trajectory: Optional[OperatorRun] = None,
+        tracer: dict = {},
     ):
         super().__init__(
-            query, reasoning_model, environment_type, max_iterations, tracer
+            query,
+            reasoning_model,
+            environment_type,
+            max_iterations,
+            previous_trajectory,
+            tracer,
         )  # Use reasoning model for primary tracking
         self.reasoning_model = reasoning_model
         self.grounding_model = grounding_model

View File

@@ -1294,6 +1294,7 @@ async def chat(
                 user,
                 meta_log,
                 location,
+                operator_results[-1] if operator_results else None,
                 query_images=uploaded_images,
                 query_files=attached_file_context,
                 send_status_func=partial(send_event, ChatEvent.STATUS),

View File

@@ -423,6 +423,7 @@ async def research(
             user,
             construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
             location,
+            previous_iterations[-1].operatorContext if previous_iterations else None,
             send_status_func=send_status_func,
             query_images=query_images,
             agent=agent,