mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Decouple trajectory compression from `act`. Reuse function to call the LLM API
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import List, Literal, Optional, cast
|
||||
|
||||
from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlock
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_base import (
|
||||
@@ -28,11 +29,6 @@ logger = logging.getLogger(__name__)
|
||||
# --- Anthropic Operator Agent ---
|
||||
class AnthropicOperatorAgent(OperatorAgent):
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
client = get_anthropic_async_client(
|
||||
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
|
||||
)
|
||||
betas = self.model_default_headers()
|
||||
temperature = 1.0
|
||||
actions: List[OperatorAction] = []
|
||||
action_results: List[dict] = []
|
||||
self._commit_trace() # Commit trace before next action
|
||||
@@ -43,53 +39,23 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
if is_none_or_empty(self.messages):
|
||||
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||
|
||||
thinking: dict[str, str | int] = {"type": "disabled"}
|
||||
if is_reasoning_model(self.vision_model.name):
|
||||
thinking = {"type": "enabled", "budget_tokens": 1024}
|
||||
|
||||
# Trigger trajectory compression if exceed size limit
|
||||
if len(self.messages) > self.message_limit:
|
||||
# 1. Prepare messages for compression
|
||||
original_messages = self.messages
|
||||
self.messages = self.messages[: self.compress_length]
|
||||
# ensure last message isn't a tool call request
|
||||
if self.messages[-1].role == "assistant" and any(
|
||||
isinstance(block, BetaToolUseBlock) for block in self.messages[-1].content
|
||||
):
|
||||
self.messages.pop()
|
||||
# 2. Get summary of operation trajectory
|
||||
await self.summarize(current_state)
|
||||
# 3. Rebuild condensed trajectory
|
||||
primary_task = [original_messages.pop(0)]
|
||||
condensed_trajectory = self.messages[-2:] # extract summary request, response
|
||||
recent_trajectory = original_messages[self.compress_length :]
|
||||
self.messages = primary_task + condensed_trajectory + recent_trajectory
|
||||
logger.debug("Compacting operator trajectory.")
|
||||
await self._compress()
|
||||
|
||||
messages_for_api = self._format_message_for_api(self.messages)
|
||||
try:
|
||||
response = await client.beta.messages.create(
|
||||
messages=messages_for_api,
|
||||
model=self.vision_model.name,
|
||||
system=system_prompt,
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
thinking=thinking,
|
||||
max_tokens=4096, # TODO: Make configurable?
|
||||
temperature=temperature,
|
||||
)
|
||||
response_content = response.content
|
||||
except Exception as e:
|
||||
# create a response block with error message
|
||||
logger.error(f"Error during Anthropic API call: {e}")
|
||||
error_str = e.message if hasattr(e, "message") else str(e)
|
||||
response = None
|
||||
response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]
|
||||
else:
|
||||
logger.debug(f"Anthropic response: {response.model_dump_json()}")
|
||||
response_content = await self._call_model(
|
||||
messages=self.messages,
|
||||
model=self.vision_model,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
headers=self.model_default_headers(),
|
||||
)
|
||||
|
||||
self.messages.append(AgentMessage(role="assistant", content=response_content))
|
||||
rendered_response = self._render_response(response_content, current_state.screenshot)
|
||||
|
||||
# Parse actions from response
|
||||
for block in response_content:
|
||||
if block.type == "tool_use":
|
||||
content = None
|
||||
@@ -193,15 +159,6 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
}
|
||||
)
|
||||
|
||||
if response:
|
||||
self._update_usage(
|
||||
response.usage.input_tokens,
|
||||
response.usage.output_tokens,
|
||||
response.usage.cache_read_input_tokens,
|
||||
response.usage.cache_creation_input_tokens,
|
||||
)
|
||||
self.tracer["temperature"] = temperature
|
||||
|
||||
return AgentActResult(
|
||||
actions=actions,
|
||||
action_results=action_results,
|
||||
@@ -360,6 +317,104 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
|
||||
return render_payload
|
||||
|
||||
async def _call_model(
    self,
    messages: list[AgentMessage],
    model: ChatModel,
    system_prompt: str,
    tools: Optional[list[dict]] = None,
    headers: Optional[list[str]] = None,
    temperature: float = 1.0,
    max_tokens: int = 4096,
) -> list[BetaContentBlock]:
    """Send the messages to the Anthropic beta messages API and return the response content.

    Args:
        messages: Conversation to send. Formatted with `_format_message_for_api` before the call.
        model: Chat model whose name and API credentials are used for the request.
        system_prompt: System prompt passed to the API.
        tools: Tool definitions. Only forwarded to the API when non-empty.
        headers: Beta flags. Only forwarded to the API (as `betas`) when non-empty.
        temperature: Sampling temperature passed to the API.
        max_tokens: Maximum tokens to generate, passed to the API.

    Returns:
        The content blocks of the API response. If the API call raises, a single text block
        containing "Communication Error: <error>" is returned instead of propagating the error.

    On a successful call the token usage is recorded via `_update_usage` and the
    temperature is stored in the tracer.
    """
    client = get_anthropic_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)

    thinking: dict[str, str | int] = {"type": "disabled"}
    if is_reasoning_model(model.name):
        thinking = {"type": "enabled", "budget_tokens": 1024}

    # Optional arguments are left out of the request entirely when unset or empty
    kwargs: dict = {}
    if headers:
        kwargs["betas"] = headers
    if tools:
        kwargs["tools"] = tools

    messages_for_api = self._format_message_for_api(messages)
    try:
        response = await client.beta.messages.create(
            messages=messages_for_api,
            model=model.name,
            system=system_prompt,
            thinking=thinking,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )
        response_content = response.content
    except Exception as e:
        # Surface the failure to the caller as a text block instead of raising
        logger.error(f"Error during Anthropic API call: {e}")
        error_str = e.message if hasattr(e, "message") else str(e)
        response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]
    else:
        # Usage and tracing are only recorded for successful calls
        logger.debug(f"Anthropic response: {response.model_dump_json()}")
        self._update_usage(
            response.usage.input_tokens,
            response.usage.output_tokens,
            response.usage.cache_read_input_tokens,
            response.usage.cache_creation_input_tokens,
        )
        self.tracer["temperature"] = temperature

    return response_content
|
||||
|
||||
async def _compress(self):
    """Replace the older part of `self.messages` with a model-written summary.

    The first `self.compress_length` messages are sent to the model together with a
    request to summarize them. `self.messages` is then rebuilt as:
    the first (primary task) message, the summary request, the summary response,
    and the messages after the first `self.compress_length`.

    Does nothing when there are no messages to compress.
    """
    if not self.messages:
        return

    # 1. Prepare messages for compression
    all_messages = list(self.messages)
    messages_to_summarize = all_messages[: self.compress_length]
    # ensure last message isn't a tool call request
    if (
        messages_to_summarize
        and messages_to_summarize[-1].role == "assistant"
        and any(isinstance(block, BetaToolUseBlock) for block in messages_to_summarize[-1].content)
    ):
        messages_to_summarize.pop()

    summarize_prompt = f"Summarize your research and computer use till now to help answer my query:\n{self.query}"
    summarize_message = AgentMessage(role="user", content=summarize_prompt)
    system_prompt = dedent(
        """
        You are a computer operator with meticulous communication skills. You can condense your partial computer use traces and research into an appropriately detailed summary.
        When requested summarize your key actions, results and findings until now to achieve the user specified task.
        Your summary should help you remember the key information required to both complete the task and later generate a final report.
        """
    )

    # 2. Get summary of operation trajectory
    try:
        response_content = await self._call_model(
            messages=messages_to_summarize + [summarize_message],
            model=self.vision_model,
            system_prompt=system_prompt,
            max_tokens=8192,
        )
    except Exception as e:
        # create a response block with error message
        logger.error(f"Error during Anthropic API call: {e}")
        error_str = e.message if hasattr(e, "message") else str(e)
        response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]

    summary_message = AgentMessage(role="assistant", content=response_content)

    # 3. Rebuild message history with condensed trajectory
    primary_task = all_messages[:1]
    condensed_trajectory = [summarize_message, summary_message]
    # Start at index 1 or later so the primary task message is never repeated
    recent_trajectory = all_messages[max(1, self.compress_length) :]
    # ensure first message isn't a tool result
    # (the history may have nothing left after the compressed part)
    if (
        recent_trajectory
        and recent_trajectory[0].role == "environment"
        and any(block["type"] == "tool_result" for block in recent_trajectory[0].content)
    ):
        recent_trajectory.pop(0)

    self.messages = primary_task + condensed_trajectory + recent_trajectory
|
||||
|
||||
def get_coordinates(self, tool_input: dict, key: str = "coordinate") -> Optional[list | tuple]:
|
||||
"""Get coordinates from tool input."""
|
||||
raw_coord = tool_input.get(key)
|
||||
|
||||
Reference in New Issue
Block a user