mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 05:39:12 +00:00
Decouple trajectory compression from `act`. Reuse a single function to call the LLM API
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import List, Literal, Optional, cast
|
|||||||
|
|
||||||
from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlock
|
from anthropic.types.beta import BetaContentBlock, BetaTextBlock, BetaToolUseBlock
|
||||||
|
|
||||||
|
from khoj.database.models import ChatModel
|
||||||
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
|
from khoj.processor.conversation.anthropic.utils import is_reasoning_model
|
||||||
from khoj.processor.operator.operator_actions import *
|
from khoj.processor.operator.operator_actions import *
|
||||||
from khoj.processor.operator.operator_agent_base import (
|
from khoj.processor.operator.operator_agent_base import (
|
||||||
@@ -28,11 +29,6 @@ logger = logging.getLogger(__name__)
|
|||||||
# --- Anthropic Operator Agent ---
|
# --- Anthropic Operator Agent ---
|
||||||
class AnthropicOperatorAgent(OperatorAgent):
|
class AnthropicOperatorAgent(OperatorAgent):
|
||||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||||
client = get_anthropic_async_client(
|
|
||||||
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
|
|
||||||
)
|
|
||||||
betas = self.model_default_headers()
|
|
||||||
temperature = 1.0
|
|
||||||
actions: List[OperatorAction] = []
|
actions: List[OperatorAction] = []
|
||||||
action_results: List[dict] = []
|
action_results: List[dict] = []
|
||||||
self._commit_trace() # Commit trace before next action
|
self._commit_trace() # Commit trace before next action
|
||||||
@@ -43,53 +39,23 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
if is_none_or_empty(self.messages):
|
if is_none_or_empty(self.messages):
|
||||||
self.messages = [AgentMessage(role="user", content=self.query)]
|
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||||
|
|
||||||
thinking: dict[str, str | int] = {"type": "disabled"}
|
|
||||||
if is_reasoning_model(self.vision_model.name):
|
|
||||||
thinking = {"type": "enabled", "budget_tokens": 1024}
|
|
||||||
|
|
||||||
# Trigger trajectory compression if exceed size limit
|
# Trigger trajectory compression if exceed size limit
|
||||||
if len(self.messages) > self.message_limit:
|
if len(self.messages) > self.message_limit:
|
||||||
# 1. Prepare messages for compression
|
logger.debug("Compacting operator trajectory.")
|
||||||
original_messages = self.messages
|
await self._compress()
|
||||||
self.messages = self.messages[: self.compress_length]
|
|
||||||
# ensure last message isn't a tool call request
|
|
||||||
if self.messages[-1].role == "assistant" and any(
|
|
||||||
isinstance(block, BetaToolUseBlock) for block in self.messages[-1].content
|
|
||||||
):
|
|
||||||
self.messages.pop()
|
|
||||||
# 2. Get summary of operation trajectory
|
|
||||||
await self.summarize(current_state)
|
|
||||||
# 3. Rebuild condensed trajectory
|
|
||||||
primary_task = [original_messages.pop(0)]
|
|
||||||
condensed_trajectory = self.messages[-2:] # extract summary request, response
|
|
||||||
recent_trajectory = original_messages[self.compress_length :]
|
|
||||||
self.messages = primary_task + condensed_trajectory + recent_trajectory
|
|
||||||
|
|
||||||
messages_for_api = self._format_message_for_api(self.messages)
|
response_content = await self._call_model(
|
||||||
try:
|
messages=self.messages,
|
||||||
response = await client.beta.messages.create(
|
model=self.vision_model,
|
||||||
messages=messages_for_api,
|
system_prompt=system_prompt,
|
||||||
model=self.vision_model.name,
|
tools=tools,
|
||||||
system=system_prompt,
|
headers=self.model_default_headers(),
|
||||||
tools=tools,
|
)
|
||||||
betas=betas,
|
|
||||||
thinking=thinking,
|
|
||||||
max_tokens=4096, # TODO: Make configurable?
|
|
||||||
temperature=temperature,
|
|
||||||
)
|
|
||||||
response_content = response.content
|
|
||||||
except Exception as e:
|
|
||||||
# create a response block with error message
|
|
||||||
logger.error(f"Error during Anthropic API call: {e}")
|
|
||||||
error_str = e.message if hasattr(e, "message") else str(e)
|
|
||||||
response = None
|
|
||||||
response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]
|
|
||||||
else:
|
|
||||||
logger.debug(f"Anthropic response: {response.model_dump_json()}")
|
|
||||||
|
|
||||||
self.messages.append(AgentMessage(role="assistant", content=response_content))
|
self.messages.append(AgentMessage(role="assistant", content=response_content))
|
||||||
rendered_response = self._render_response(response_content, current_state.screenshot)
|
rendered_response = self._render_response(response_content, current_state.screenshot)
|
||||||
|
|
||||||
|
# Parse actions from response
|
||||||
for block in response_content:
|
for block in response_content:
|
||||||
if block.type == "tool_use":
|
if block.type == "tool_use":
|
||||||
content = None
|
content = None
|
||||||
@@ -193,15 +159,6 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if response:
|
|
||||||
self._update_usage(
|
|
||||||
response.usage.input_tokens,
|
|
||||||
response.usage.output_tokens,
|
|
||||||
response.usage.cache_read_input_tokens,
|
|
||||||
response.usage.cache_creation_input_tokens,
|
|
||||||
)
|
|
||||||
self.tracer["temperature"] = temperature
|
|
||||||
|
|
||||||
return AgentActResult(
|
return AgentActResult(
|
||||||
actions=actions,
|
actions=actions,
|
||||||
action_results=action_results,
|
action_results=action_results,
|
||||||
@@ -360,6 +317,104 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
|
|
||||||
return render_payload
|
return render_payload
|
||||||
|
|
||||||
|
async def _call_model(
    self,
    messages: list[AgentMessage],
    model: ChatModel,
    system_prompt: str,
    tools: Optional[list[dict]] = None,
    headers: Optional[list[str]] = None,
    temperature: float = 1.0,
    max_tokens: int = 4096,
) -> list[BetaContentBlock]:
    """Send messages to the Anthropic beta messages API and return the response content.

    Args:
        messages: Conversation to send. Converted with `_format_message_for_api` before the call.
        model: Chat model whose name and API credentials are used for the request.
        system_prompt: Passed as the `system` argument of the request.
        tools: Tool definitions. Only forwarded to the API when non-empty.
        headers: Beta feature flags. Forwarded as the `betas` argument when non-empty.
        temperature: Sampling temperature for the request.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The content blocks of the model response. If the API call raises, a single
        text block of the form "Communication Error: <details>" is returned instead.
    """
    client = get_anthropic_async_client(model.ai_model_api.api_key, model.ai_model_api.api_base_url)

    # Thinking is only enabled for models reported as reasoning models
    thinking: dict[str, str | int] = {"type": "disabled"}
    if is_reasoning_model(model.name):
        thinking = {"type": "enabled", "budget_tokens": 1024}

    # Optional arguments are omitted from the request entirely when empty
    optional_kwargs: dict = {}
    if headers:
        optional_kwargs["betas"] = headers
    if tools:
        optional_kwargs["tools"] = tools

    messages_for_api = self._format_message_for_api(messages)
    try:
        response = await client.beta.messages.create(
            messages=messages_for_api,
            model=model.name,
            system=system_prompt,
            thinking=thinking,
            max_tokens=max_tokens,
            temperature=temperature,
            **optional_kwargs,
        )
        response_content = response.content
    except Exception as e:
        # Surface the failure to the caller as a text block instead of raising.
        # Usage and tracer are left untouched as no response was received.
        logger.error(f"Error during Anthropic API call: {e}")
        error_str = e.message if hasattr(e, "message") else str(e)
        return [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]

    logger.debug(f"Anthropic response: {response.model_dump_json()}")
    self._update_usage(
        response.usage.input_tokens,
        response.usage.output_tokens,
        response.usage.cache_read_input_tokens,
        response.usage.cache_creation_input_tokens,
    )
    self.tracer["temperature"] = temperature
    return response_content
|
||||||
|
|
||||||
|
async def _compress(self):
    """Condense the oldest part of the message history into a model-written summary.

    The first `self.compress_length` messages are sent to the model together with a
    summarization request. `self.messages` is then rebuilt as:
    first message + [summary request, summary response] + messages after the compressed prefix.

    Does nothing when there are no messages to summarize.
    """
    # 1. Prepare messages for compression
    messages_to_summarize = self.messages[: self.compress_length]
    if not messages_to_summarize:
        # Nothing to compress. Avoids indexing into an empty list below.
        return

    # ensure last message isn't a tool call request
    last_message = messages_to_summarize[-1]
    if last_message.role == "assistant" and any(
        isinstance(block, BetaToolUseBlock) for block in last_message.content
    ):
        messages_to_summarize.pop()

    summarize_prompt = f"Summarize your research and computer use till now to help answer my query:\n{self.query}"
    summarize_message = AgentMessage(role="user", content=summarize_prompt)
    system_prompt = dedent(
        """
        You are a computer operator with meticulous communication skills. You can condense your partial computer use traces and research into an appropriately detailed summary.
        When requested summarize your key actions, results and findings until now to achieve the user specified task.
        Your summary should help you remember the key information required to both complete the task and later generate a final report.
        """
    )

    # 2. Get summary of operation trajectory
    try:
        response_content = await self._call_model(
            messages=messages_to_summarize + [summarize_message],
            model=self.vision_model,
            system_prompt=system_prompt,
            max_tokens=8192,
        )
    except Exception as e:
        # _call_model handles API call failures itself. This covers errors raised
        # outside its own try block, e.g. while preparing the request.
        logger.error(f"Error during Anthropic API call: {e}")
        error_str = e.message if hasattr(e, "message") else str(e)
        response_content = [BetaTextBlock(text=f"Communication Error: {error_str}", type="text")]

    summary_message = AgentMessage(role="assistant", content=response_content)

    # 3. Rebuild message history with condensed trajectory
    primary_task = self.messages[:1]
    remaining_messages = self.messages[1:]
    # -1 since the first message is excluded from remaining_messages
    recent_trajectory = remaining_messages[self.compress_length - 1 :]

    # ensure first message isn't a tool result
    # Guard against an empty recent trajectory and non-dict content blocks.
    if (
        recent_trajectory
        and recent_trajectory[0].role == "environment"
        and any(
            isinstance(block, dict) and block.get("type") == "tool_result"
            for block in recent_trajectory[0].content
        )
    ):
        recent_trajectory.pop(0)

    self.messages = primary_task + [summarize_message, summary_message] + recent_trajectory
|
||||||
|
|
||||||
def get_coordinates(self, tool_input: dict, key: str = "coordinate") -> Optional[list | tuple]:
|
def get_coordinates(self, tool_input: dict, key: str = "coordinate") -> Optional[list | tuple]:
|
||||||
"""Get coordinates from tool input."""
|
"""Get coordinates from tool input."""
|
||||||
raw_coord = tool_input.get(key)
|
raw_coord = tool_input.get(key)
|
||||||
|
|||||||
Reference in New Issue
Block a user