From 675fc0ad05fab5c087f6027cf5301d8620a5ac7a Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 28 May 2025 16:06:24 -0700 Subject: [PATCH] Decouple trajectory compression from `act'. Reuse func to call llm api --- .../operator/operator_agent_anthropic.py | 163 ++++++++++++------ 1 file changed, 109 insertions(+), 54 deletions(-) diff --git a/src/khoj/processor/operator/operator_agent_anthropic.py b/src/khoj/processor/operator/operator_agent_anthropic.py index 2e7cfe97..567411f6 100644 --- a/src/khoj/processor/operator/operator_agent_anthropic.py +++ b/src/khoj/processor/operator/operator_agent_anthropic.py @@ -8,6 +8,7 @@ from typing import List, Literal, Optional, cast from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlock +from khoj.database.models import ChatModel from khoj.processor.conversation.anthropic.utils import is_reasoning_model from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_agent_base import ( @@ -28,11 +29,6 @@ logger = logging.getLogger(__name__) # --- Anthropic Operator Agent --- class AnthropicOperatorAgent(OperatorAgent): async def act(self, current_state: EnvState) -> AgentActResult: - client = get_anthropic_async_client( - self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url - ) - betas = self.model_default_headers() - temperature = 1.0 actions: List[OperatorAction] = [] action_results: List[dict] = [] self._commit_trace() # Commit trace before next action @@ -43,53 +39,23 @@ class AnthropicOperatorAgent(OperatorAgent): if is_none_or_empty(self.messages): self.messages = [AgentMessage(role="user", content=self.query)] - thinking: dict[str, str | int] = {"type": "disabled"} - if is_reasoning_model(self.vision_model.name): - thinking = {"type": "enabled", "budget_tokens": 1024} - # Trigger trajectory compression if exceed size limit if len(self.messages) > self.message_limit: - # 1. Prepare messages for compression - original_messages = self.messages - self.messages = self.messages[: self.compress_length] - # ensure last message isn't a tool call request - if self.messages[-1].role == "assistant" and any( - isinstance(block, BetaToolUseBlock) for block in self.messages[-1].content - ): - self.messages.pop() - # 2. Get summary of operation trajectory - await self.summarize(current_state) - # 3. Rebuild condensed trajectory - primary_task = [original_messages.pop(0)] - condensed_trajectory = self.messages[-2:] # extract summary request, response - recent_trajectory = original_messages[self.compress_length :] - self.messages = primary_task + condensed_trajectory + recent_trajectory + logger.debug("Compacting operator trajectory.") + await self._compress() - messages_for_api = self._format_message_for_api(self.messages) - try: - response = await client.beta.messages.create( - messages=messages_for_api, - model=self.vision_model.name, - system=system_prompt, - tools=tools, - betas=betas, - thinking=thinking, - max_tokens=4096, # TODO: Make configurable? - temperature=temperature, - ) - response_content = response.content - except Exception as e: - # create a response block with error message - logger.error(f"Error during Anthropic API call: {e}") - error_str = e.message if hasattr(e, "message") else str(e) - response = None - response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")] - else: - logger.debug(f"Anthropic response: {response.model_dump_json()}") + response_content = await self._call_model( + messages=self.messages, + model=self.vision_model, + system_prompt=system_prompt, + tools=tools, + headers=self.model_default_headers(), + ) self.messages.append(AgentMessage(role="assistant", content=response_content)) rendered_response = self._render_response(response_content, current_state.screenshot) + # Parse actions from response for block in response_content: if block.type == "tool_use": content = None @@ -193,15 +159,6 @@ class AnthropicOperatorAgent(OperatorAgent): } ) - if response: - self._update_usage( - response.usage.input_tokens, - response.usage.output_tokens, - response.usage.cache_read_input_tokens, - response.usage.cache_creation_input_tokens, - ) - self.tracer["temperature"] = temperature - return AgentActResult( actions=actions, action_results=action_results, @@ -360,6 +317,104 @@ class AnthropicOperatorAgent(OperatorAgent): return render_payload + async def _call_model( + self, + messages: list[AgentMessage], + model: ChatModel, + system_prompt: str, + tools: list[dict] = [], + headers: list[str] = [], + temperature: float = 1.0, + max_tokens: int = 4096, + ) -> list[BetaContentBlock]: + client = get_anthropic_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url) + kwargs = {} + thinking: dict[str, str | int] = {"type": "disabled"} + if is_reasoning_model(model.name): + thinking = {"type": "enabled", "budget_tokens": 1024} + if headers: + kwargs["betas"] = headers + if tools: + kwargs["tools"] = tools + + messages_for_api = self._format_message_for_api(messages) + try: + response = await client.beta.messages.create( + messages=messages_for_api, + model=model.name, + system=system_prompt, + thinking=thinking, + max_tokens=max_tokens, + temperature=temperature, + **kwargs, + ) + response_content = response.content + except Exception as e: + # create a response block with error message + logger.error(f"Error during Anthropic API call: {e}") + error_str = e.message if hasattr(e, "message") else str(e) + response = None + response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")] + + if response: + logger.debug(f"Anthropic response: {response.model_dump_json()}") + self._update_usage( + response.usage.input_tokens, + response.usage.output_tokens, + response.usage.cache_read_input_tokens, + response.usage.cache_creation_input_tokens, + ) + self.tracer["temperature"] = temperature + return response_content + + async def _compress(self): + # 1. Prepare messages for compression + original_messages = list(self.messages) + messages_to_summarize = self.messages[: self.compress_length] + # ensure last message isn't a tool call request + if messages_to_summarize[-1].role == "assistant" and any( + isinstance(block, BetaToolUseBlock) for block in messages_to_summarize[-1].content + ): + messages_to_summarize.pop() + + summarize_prompt = f"Summarize your research and computer use till now to help answer my query:\n{self.query}" + summarize_message = AgentMessage(role="user", content=summarize_prompt) + system_prompt = dedent( + """ + You are a computer operator with meticulous communication skills. You can condense your partial computer use traces and research into an appropriately detailed summary. + When requested summarize your key actions, results and findings until now to achieve the user specified task. + Your summary should help you remember the key information required to both complete the task and later generate a final report. + """ + ) + + # 2. Get summary of operation trajectory + try: + response_content = await self._call_model( + messages=messages_to_summarize + [summarize_message], + model=self.vision_model, + system_prompt=system_prompt, + max_tokens=8192, + ) + except Exception as e: + # create a response block with error message + logger.error(f"Error during Anthropic API call: {e}") + error_str = e.message if hasattr(e, "message") else str(e) + response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")] + + summary_message = AgentMessage(role="assistant", content=response_content) + + # 3. Rebuild message history with condensed trajectory + primary_task = [original_messages.pop(0)] + condensed_trajectory = [summarize_message, summary_message] + recent_trajectory = original_messages[self.compress_length - 1 :] # -1 since we popped the first message + # ensure first message isn't a tool result + if recent_trajectory[0].role == "environment" and any( + block["type"] == "tool_result" for block in recent_trajectory[0].content + ): + recent_trajectory.pop(0) + + self.messages = primary_task + condensed_trajectory + recent_trajectory + def get_coordinates(self, tool_input: dict, key: str = "coordinate") -> Optional[list | tuple]: """Get coordinates from tool input.""" raw_coord = tool_input.get(key)