Add trajectory compression to Anthropic operator agent

- Add compression parameters to base operator agent for reuse
- Increase default operator iterations from 40 to 100
This commit is contained in:
Debanjum
2025-05-28 00:28:34 -07:00
parent cb451fa67c
commit d54bfc19e5
3 changed files with 31 additions and 2 deletions

View File

@@ -50,7 +50,7 @@ async def operate_environment(
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.") raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
# Initialize Agent # Initialize Agent
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 40)) max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100))
operator_agent: OperatorAgent operator_agent: OperatorAgent
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI: if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
operator_agent = OpenAIOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer) operator_agent = OpenAIOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)

View File

@@ -6,7 +6,7 @@ from datetime import datetime
from textwrap import dedent from textwrap import dedent
from typing import List, Literal, Optional, cast from typing import List, Literal, Optional, cast
from anthropic.types.beta import BetaContentBlock from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlock
from khoj.processor.conversation.anthropic.utils import is_reasoning_model from khoj.processor.conversation.anthropic.utils import is_reasoning_model
from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_actions import *
@@ -47,6 +47,24 @@ class AnthropicOperatorAgent(OperatorAgent):
if is_reasoning_model(self.vision_model.name): if is_reasoning_model(self.vision_model.name):
thinking = {"type": "enabled", "budget_tokens": 1024} thinking = {"type": "enabled", "budget_tokens": 1024}
# Trigger trajectory compression if exceed size limit
if len(self.messages) > self.message_limit:
# 1. Prepare messages for compression
original_messages = self.messages
self.messages = self.messages[: self.compress_length]
# ensure last message isn't a tool call request
if self.messages[-1].role == "assistant" and any(
isinstance(block, BetaToolUseBlock) for block in self.messages[-1].content
):
self.messages.pop()
# 2. Get summary of operation trajectory
await self.summarize(current_state)
# 3. Rebuild condensed trajectory
primary_task = [original_messages.pop(0)]
condensed_trajectory = self.messages[-2:] # extract summary request, response
recent_trajectory = original_messages[self.compress_length :]
self.messages = primary_task + condensed_trajectory + recent_trajectory
messages_for_api = self._format_message_for_api(self.messages) messages_for_api = self._format_message_for_api(self.messages)
response = await client.beta.messages.create( response = await client.beta.messages.create(
messages=messages_for_api, messages=messages_for_api,

View File

@@ -40,6 +40,17 @@ class OperatorAgent(ABC):
self.messages: List[AgentMessage] = [] self.messages: List[AgentMessage] = []
self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}." self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
# Context compression parameters
self.context_compress_trigger = 2e3 # heuristic to determine compression trigger
# turns after which compression is triggered. Scales with model max context size. Minimum 5 turns.
self.message_limit = 2 * max(
5, int(self.vision_model.subscribed_max_prompt_size / self.context_compress_trigger)
)
# compression ratio determines how many messages to compress down to one
# e.g. with a limit of 5 messages, a compress ratio of 4/5 means compress the first 4 messages into 1 summary + keep the last 1 uncompressed
self.message_compress_ratio = 4 / 5
self.compress_length = int(self.message_limit * self.message_compress_ratio)
@abstractmethod @abstractmethod
async def act(self, current_state: EnvState) -> AgentActResult: async def act(self, current_state: EnvState) -> AgentActResult:
pass pass