From d54bfc19e5b2e91dba0a248df95de1967c92933e Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 28 May 2025 00:28:34 -0700 Subject: [PATCH] Add trajectory compression to anthropic operator agent - Add compression parameters to base operator agent for reuse - Increase default operator iterations --- src/khoj/processor/operator/__init__.py | 2 +- .../operator/operator_agent_anthropic.py | 20 ++++++++++++++++++- .../processor/operator/operator_agent_base.py | 11 ++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index ec531675..b2ea846f 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -50,7 +50,7 @@ async def operate_environment( raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.") # Initialize Agent - max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 40)) + max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100)) operator_agent: OperatorAgent if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI: operator_agent = OpenAIOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer) diff --git a/src/khoj/processor/operator/operator_agent_anthropic.py b/src/khoj/processor/operator/operator_agent_anthropic.py index f8837c2f..ad99bb34 100644 --- a/src/khoj/processor/operator/operator_agent_anthropic.py +++ b/src/khoj/processor/operator/operator_agent_anthropic.py @@ -6,7 +6,7 @@ from datetime import datetime from textwrap import dedent from typing import List, Literal, Optional, cast -from anthropic.types.beta import BetaContentBlock +from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlock from khoj.processor.conversation.anthropic.utils import is_reasoning_model from khoj.processor.operator.operator_actions import * @@ -47,6 +47,24 @@ class AnthropicOperatorAgent(OperatorAgent): if is_reasoning_model(self.vision_model.name): thinking = {"type": "enabled", "budget_tokens": 1024} + # Trigger trajectory compression if exceed size limit + if len(self.messages) > self.message_limit: + # 1. Prepare messages for compression + original_messages = self.messages + self.messages = self.messages[: self.compress_length] + # ensure last message isn't a tool call request + if self.messages[-1].role == "assistant" and any( + isinstance(block, BetaToolUseBlock) for block in self.messages[-1].content + ): + self.messages.pop() + # 2. Get summary of operation trajectory + await self.summarize(current_state) + # 3. Rebuild condensed trajectory + primary_task = [original_messages.pop(0)] + condensed_trajectory = self.messages[-2:] # extract summary request, response + recent_trajectory = original_messages[self.compress_length :] + self.messages = primary_task + condensed_trajectory + recent_trajectory + messages_for_api = self._format_message_for_api(self.messages) response = await client.beta.messages.create( messages=messages_for_api, diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py index eb9bd544..8c273b5e 100644 --- a/src/khoj/processor/operator/operator_agent_base.py +++ b/src/khoj/processor/operator/operator_agent_base.py @@ -40,6 +40,17 @@ class OperatorAgent(ABC): self.messages: List[AgentMessage] = [] self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}." + # Context compression parameters + self.context_compress_trigger = 2e3 # heuristic to determine compression trigger + # turns after which compression triggered. scales with model max context size. Minimum 5 turns. + self.message_limit = 2 * max( + 5, int(self.vision_model.subscribed_max_prompt_size / self.context_compress_trigger) + ) + # compression ratio determines how many messages to compress down to one + # e.g. if 5 messages, a compress ratio of 4/5 means compress 5 messages into 1 + keep 1 uncompressed + self.message_compress_ratio = 4 / 5 + self.compress_length = int(self.message_limit * self.message_compress_ratio) + @abstractmethod async def act(self, current_state: EnvState) -> AgentActResult: pass