mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Add trajectory compression to anthropic operator agent
- Add compression parameters to base operator agent for reuse - Increase default operator iterations
This commit is contained in:
@@ -50,7 +50,7 @@ async def operate_environment(
|
||||
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
|
||||
|
||||
# Initialize Agent
|
||||
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 40))
|
||||
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 100))
|
||||
operator_agent: OperatorAgent
|
||||
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
|
||||
operator_agent = OpenAIOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
|
||||
|
||||
@@ -6,7 +6,7 @@ from datetime import datetime
|
||||
from textwrap import dedent
|
||||
from typing import List, Literal, Optional, cast
|
||||
|
||||
from anthropic.types.beta import BetaContentBlock
|
||||
from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlock
|
||||
|
||||
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
@@ -47,6 +47,24 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
if is_reasoning_model(self.vision_model.name):
|
||||
thinking = {"type": "enabled", "budget_tokens": 1024}
|
||||
|
||||
# Trigger trajectory compression if exceed size limit
|
||||
if len(self.messages) > self.message_limit:
|
||||
# 1. Prepare messages for compression
|
||||
original_messages = self.messages
|
||||
self.messages = self.messages[: self.compress_length]
|
||||
# ensure last message isn't a tool call request
|
||||
if self.messages[-1].role == "assistant" and any(
|
||||
isinstance(block, BetaToolUseBlock) for block in self.messages[-1].content
|
||||
):
|
||||
self.messages.pop()
|
||||
# 2. Get summary of operation trajectory
|
||||
await self.summarize(current_state)
|
||||
# 3. Rebuild condensed trajectory
|
||||
primary_task = [original_messages.pop(0)]
|
||||
condensed_trajectory = self.messages[-2:] # extract summary request, response
|
||||
recent_trajectory = original_messages[self.compress_length :]
|
||||
self.messages = primary_task + condensed_trajectory + recent_trajectory
|
||||
|
||||
messages_for_api = self._format_message_for_api(self.messages)
|
||||
response = await client.beta.messages.create(
|
||||
messages=messages_for_api,
|
||||
|
||||
@@ -40,6 +40,17 @@ class OperatorAgent(ABC):
|
||||
self.messages: List[AgentMessage] = []
|
||||
self.summarize_prompt = f"Use the results of our research to provide a comprehensive, self-contained answer for the target query:\n{query}."
|
||||
|
||||
# Context compression parameters
|
||||
self.context_compress_trigger = 2e3 # heuristic to determine compression trigger
|
||||
# turns after which compression triggered. scales with model max context size. Minimum 5 turns.
|
||||
self.message_limit = 2 * max(
|
||||
5, int(self.vision_model.subscribed_max_prompt_size / self.context_compress_trigger)
|
||||
)
|
||||
# compression ratio determines how many messages to compress down to one
|
||||
# e.g. if 5 messages, a compress ratio of 4/5 means compress 5 messages into 1 + keep 1 uncompressed
|
||||
self.message_compress_ratio = 4 / 5
|
||||
self.compress_length = int(self.message_limit * self.message_compress_ratio)
|
||||
|
||||
@abstractmethod
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user