Pass previous trajectory to operator agents for context

This commit is contained in:
Debanjum
2025-05-29 18:31:01 -07:00
parent 864e0ac8b5
commit de35d91e1d
5 changed files with 51 additions and 10 deletions

View File

@@ -6,7 +6,7 @@ from typing import Callable, List, Optional

 from khoj.database.adapters import AgentAdapters, ConversationAdapters
 from khoj.database.models import Agent, ChatModel, KhojUser
-from khoj.processor.conversation.utils import OperatorRun
+from khoj.processor.conversation.utils import OperatorRun, construct_chat_history
 from khoj.processor.operator.operator_actions import *
 from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAgent
 from khoj.processor.operator.operator_agent_base import OperatorAgent
@@ -32,6 +32,7 @@ async def operate_environment(
     user: KhojUser,
     conversation_log: dict,
     location_data: LocationData,
+    previous_trajectory: Optional[OperatorRun] = None,
     environment_type: EnvironmentType = EnvironmentType.COMPUTER,
     send_status_func: Optional[Callable] = None,
     query_images: Optional[List[str]] = None,  # TODO: Handle query images
@@ -50,13 +51,22 @@ async def operate_environment(
     if not reasoning_model:
         raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")

+    chat_history = construct_chat_history(conversation_log)
+    query_with_history = (
+        f"## Chat History\n{chat_history}\n\n## User Query\n{query}" if chat_history else query
+    )  # Append chat history to query if available
+
     # Initialize Agent
     max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100))
     operator_agent: OperatorAgent
     if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
-        operator_agent = OpenAIOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
+        operator_agent = OpenAIOperatorAgent(
+            query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer
+        )
     elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
-        operator_agent = AnthropicOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
+        operator_agent = AnthropicOperatorAgent(
+            query_with_history, reasoning_model, environment_type, max_iterations, previous_trajectory, tracer
+        )
     else:
         grounding_model_name = "ui-tars-1.5"
         grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
@@ -67,7 +77,13 @@ async def operate_environment(
         ):
             raise ValueError("No supported visual grounding model for binary operator agent found.")
         operator_agent = BinaryOperatorAgent(
-            query, reasoning_model, grounding_model, environment_type, max_iterations, tracer
+            query_with_history,
+            reasoning_model,
+            grounding_model,
+            environment_type,
+            max_iterations,
+            previous_trajectory,
+            tracer,
         )
# Initialize Environment # Initialize Environment

View File

@@ -5,7 +5,11 @@ from typing import List, Literal, Optional, Union

 from pydantic import BaseModel

 from khoj.database.models import ChatModel
-from khoj.processor.conversation.utils import AgentMessage, commit_conversation_trace
+from khoj.processor.conversation.utils import (
+    AgentMessage,
+    OperatorRun,
+    commit_conversation_trace,
+)
 from khoj.processor.operator.operator_actions import OperatorAction
 from khoj.processor.operator.operator_environment_base import (
     EnvironmentType,
@@ -25,7 +29,13 @@ class AgentActResult(BaseModel):

 class OperatorAgent(ABC):
     def __init__(
-        self, query: str, vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, tracer: dict
+        self,
+        query: str,
+        vision_model: ChatModel,
+        environment_type: EnvironmentType,
+        max_iterations: int,
+        previous_trajectory: Optional[OperatorRun] = None,
+        tracer: dict = {},
     ):
         self.query = query
         self.vision_model = vision_model
@@ -35,6 +45,9 @@ class OperatorAgent(ABC):
         self.messages: List[AgentMessage] = []
         self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
+        if previous_trajectory:
+            self.messages = previous_trajectory.trajectory
+
         # Context compression parameters
         self.context_compress_trigger = 2e3  # heuristic to determine compression trigger
         # turns after which compression triggered. scales with model max context size. Minimum 5 turns.

View File

@@ -2,10 +2,14 @@ import json
 import logging
 from datetime import datetime
 from textwrap import dedent
-from typing import List
+from typing import List, Optional

 from khoj.database.models import ChatModel
-from khoj.processor.conversation.utils import AgentMessage, construct_structured_message
+from khoj.processor.conversation.utils import (
+    AgentMessage,
+    OperatorRun,
+    construct_structured_message,
+)
 from khoj.processor.operator.grounding_agent import GroundingAgent
 from khoj.processor.operator.grounding_agent_uitars import GroundingAgentUitars
 from khoj.processor.operator.operator_actions import *
@@ -36,10 +40,16 @@ class BinaryOperatorAgent(OperatorAgent):
         grounding_model: ChatModel,
         environment_type: EnvironmentType,
         max_iterations: int,
-        tracer: dict,
+        previous_trajectory: Optional[OperatorRun] = None,
+        tracer: dict = {},
     ):
         super().__init__(
-            query, reasoning_model, environment_type, max_iterations, tracer
+            query,
+            reasoning_model,
+            environment_type,
+            max_iterations,
+            previous_trajectory,
+            tracer,
         )  # Use reasoning model for primary tracking
         self.reasoning_model = reasoning_model
         self.grounding_model = grounding_model

View File

@@ -1294,6 +1294,7 @@ async def chat(
                 user,
                 meta_log,
                 location,
+                operator_results[-1] if operator_results else None,
                 query_images=uploaded_images,
                 query_files=attached_file_context,
                 send_status_func=partial(send_event, ChatEvent.STATUS),

View File

@@ -423,6 +423,7 @@ async def research(
             user,
             construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
             location,
+            previous_iterations[-1].operatorContext if previous_iterations else None,
             send_status_func=send_status_func,
             query_images=query_images,
             agent=agent,