Decouple trajectory compression from `act'. Reuse func to call llm api

This commit is contained in:
Debanjum
2025-05-28 16:06:24 -07:00
parent b027024c42
commit 675fc0ad05

View File

@@ -8,6 +8,7 @@ from typing import List, Literal, Optional, cast
from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlock
from khoj.database.models import ChatModel
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import (
@@ -28,11 +29,6 @@ logger = logging.getLogger(__name__)
# --- Anthropic Operator Agent ---
class AnthropicOperatorAgent(OperatorAgent):
async def act(self, current_state: EnvState) -> AgentActResult:
client = get_anthropic_async_client(
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
)
betas = self.model_default_headers()
temperature = 1.0
actions: List[OperatorAction] = []
action_results: List[dict] = []
self._commit_trace() # Commit trace before next action
@@ -43,53 +39,23 @@ class AnthropicOperatorAgent(OperatorAgent):
if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=self.query)]
thinking: dict[str, str | int] = {"type": "disabled"}
if is_reasoning_model(self.vision_model.name):
thinking = {"type": "enabled", "budget_tokens": 1024}
# Trigger trajectory compression if exceed size limit
if len(self.messages) > self.message_limit:
# 1. Prepare messages for compression
original_messages = self.messages
self.messages = self.messages[: self.compress_length]
# ensure last message isn't a tool call request
if self.messages[-1].role == "assistant" and any(
isinstance(block, BetaToolUseBlock) for block in self.messages[-1].content
):
self.messages.pop()
# 2. Get summary of operation trajectory
await self.summarize(current_state)
# 3. Rebuild condensed trajectory
primary_task = [original_messages.pop(0)]
condensed_trajectory = self.messages[-2:] # extract summary request, response
recent_trajectory = original_messages[self.compress_length :]
self.messages = primary_task + condensed_trajectory + recent_trajectory
logger.debug("Compacting operator trajectory.")
await self._compress()
messages_for_api = self._format_message_for_api(self.messages)
try:
response = await client.beta.messages.create(
messages=messages_for_api,
model=self.vision_model.name,
system=system_prompt,
tools=tools,
betas=betas,
thinking=thinking,
max_tokens=4096, # TODO: Make configurable?
temperature=temperature,
)
response_content = response.content
except Exception as e:
# create a response block with error message
logger.error(f"Error during Anthropic API call: {e}")
error_str = e.message if hasattr(e, "message") else str(e)
response = None
response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]
else:
logger.debug(f"Anthropic response: {response.model_dump_json()}")
response_content = await self._call_model(
messages=self.messages,
model=self.vision_model,
system_prompt=system_prompt,
tools=tools,
headers=self.model_default_headers(),
)
self.messages.append(AgentMessage(role="assistant", content=response_content))
rendered_response = self._render_response(response_content, current_state.screenshot)
# Parse actions from response
for block in response_content:
if block.type == "tool_use":
content = None
@@ -193,15 +159,6 @@ class AnthropicOperatorAgent(OperatorAgent):
}
)
if response:
self._update_usage(
response.usage.input_tokens,
response.usage.output_tokens,
response.usage.cache_read_input_tokens,
response.usage.cache_creation_input_tokens,
)
self.tracer["temperature"] = temperature
return AgentActResult(
actions=actions,
action_results=action_results,
@@ -360,6 +317,104 @@ class AnthropicOperatorAgent(OperatorAgent):
return render_payload
async def _call_model(
self,
messages: list[AgentMessage],
model: ChatModel,
system_prompt: str,
tools: list[dict] = [],
headers: list[str] = [],
temperature: float = 1.0,
max_tokens: int = 4096,
) -> list[BetaContentBlock]:
client = get_anthropic_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)
kwargs = {}
thinking: dict[str, str | int] = {"type": "disabled"}
if is_reasoning_model(model.name):
thinking = {"type": "enabled", "budget_tokens": 1024}
if headers:
kwargs["betas"] = headers
if tools:
kwargs["tools"] = tools
messages_for_api = self._format_message_for_api(messages)
try:
response = await client.beta.messages.create(
messages=messages_for_api,
model=model.name,
system=system_prompt,
thinking=thinking,
max_tokens=max_tokens,
temperature=temperature,
**kwargs,
)
response_content = response.content
except Exception as e:
# create a response block with error message
logger.error(f"Error during Anthropic API call: {e}")
error_str = e.message if hasattr(e, "message") else str(e)
response = None
response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]
if response:
logger.debug(f"Anthropic response: {response.model_dump_json()}")
self._update_usage(
response.usage.input_tokens,
response.usage.output_tokens,
response.usage.cache_read_input_tokens,
response.usage.cache_creation_input_tokens,
)
self.tracer["temperature"] = temperature
return response_content
async def _compress(self):
# 1. Prepare messages for compression
original_messages = list(self.messages)
messages_to_summarize = self.messages[: self.compress_length]
# ensure last message isn't a tool call request
if messages_to_summarize[-1].role == "assistant" and any(
isinstance(block, BetaToolUseBlock) for block in messages_to_summarize[-1].content
):
messages_to_summarize.pop()
summarize_prompt = f"Summarize your research and computer use till now to help answer my query:\n{self.query}"
summarize_message = AgentMessage(role="user", content=summarize_prompt)
system_prompt = dedent(
"""
You are a computer operator with meticulous communication skills. You can condense your partial computer use traces and research into an appropriately detailed summary.
When requested summarize your key actions, results and findings until now to achieve the user specified task.
Your summary should help you remember the key information required to both complete the task and later generate a final report.
"""
)
# 2. Get summary of operation trajectory
try:
response_content = await self._call_model(
messages=messages_to_summarize + [summarize_message],
model=self.vision_model,
system_prompt=system_prompt,
max_tokens=8192,
)
except Exception as e:
# create a response block with error message
logger.error(f"Error during Anthropic API call: {e}")
error_str = e.message if hasattr(e, "message") else str(e)
response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]
summary_message = AgentMessage(role="assistant", content=response_content)
# 3. Rebuild message history with condensed trajectory
primary_task = [original_messages.pop(0)]
condensed_trajectory = [summarize_message, summary_message]
recent_trajectory = original_messages[self.compress_length - 1 :] # -1 since we popped the first message
# ensure first message isn't a tool result
if recent_trajectory[0].role == "environment" and any(
block["type"] == "tool_result" for block in recent_trajectory[0].content
):
recent_trajectory.pop(0)
self.messages = primary_task + condensed_trajectory + recent_trajectory
def get_coordinates(self, tool_input: dict, key: str = "coordinate") -> Optional[list | tuple]:
"""Get coordinates from tool input."""
raw_coord = tool_input.get(key)