Simplify operator loop. Make each OperatorAgent manage state internally.

Prevent OperatorAgent-specific code from leaking out into the
operator. The Operator just calls the standard OperatorAgent functions.

Any OperatorAgent-specific logic is handled by the OperatorAgent
internally.

This improves the separation of responsibilities between the operator,
the OperatorAgent, and the Environment.

- Make environment pass screenshot data in agent agnostic format
  - Have operator agents providers format image data to their AI model
    specific format
  - Add environment step type to distinguish image vs text content
- Clearly mark major steps in the operator iteration loop
- Handle anthropic models returning computer tool actions as normal
  tool calls by normalizing next action retrieval from response for it
- Remove unused ActionResults fields
- Remove unnecessary placeholders from action result content, e.g. for
  screenshot data
This commit is contained in:
Debanjum
2025-05-04 18:39:12 -06:00
parent a1c9c6b2e3
commit 4db888cd62

View File

@@ -6,11 +6,10 @@ import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from typing import Callable, List, Literal, Optional, Set, Union from typing import Any, Callable, List, Literal, Optional, Set, Union
import requests import requests
from anthropic.types.beta import BetaContentBlock, BetaMessage from anthropic.types.beta import BetaContentBlock, BetaMessage
from langchain.schema import ChatMessage
from openai.types.responses import Response, ResponseOutputItem from openai.types.responses import Response, ResponseOutputItem
from playwright.async_api import Browser, Page, Playwright, async_playwright from playwright.async_api import Browser, Page, Playwright, async_playwright
from pydantic import BaseModel from pydantic import BaseModel
@@ -24,6 +23,7 @@ from khoj.utils.helpers import (
get_anthropic_async_client, get_anthropic_async_client,
get_chat_usage_metrics, get_chat_usage_metrics,
get_openai_async_client, get_openai_async_client,
is_none_or_empty,
is_promptrace_enabled, is_promptrace_enabled,
timer, timer,
) )
@@ -69,7 +69,7 @@ class ScrollAction(BaseAction):
scroll_x: Optional[int] = None scroll_x: Optional[int] = None
scroll_y: Optional[int] = None scroll_y: Optional[int] = None
scroll_direction: Optional[Literal["up", "down", "left", "right"]] = None scroll_direction: Optional[Literal["up", "down", "left", "right"]] = None
scroll_amount: Optional[int] = 1 scroll_amount: Optional[int] = 5
class KeypressAction(BaseAction): class KeypressAction(BaseAction):
@@ -131,6 +131,13 @@ class BackAction(BaseAction):
type: Literal["back"] = "back" type: Literal["back"] = "back"
class RequestUserAction(BaseAction):
"""Request user action to confirm or provide input."""
type: Literal["request_user"] = "request_user"
request: str
BrowserAction = Union[ BrowserAction = Union[
ClickAction, ClickAction,
DoubleClickAction, DoubleClickAction,
@@ -156,20 +163,23 @@ class EnvState(BaseModel):
screenshot: Optional[str] = None screenshot: Optional[str] = None
class StepResult(BaseModel): class EnvStepResult(BaseModel):
output: Optional[str | List[dict]] = None # Output message or screenshot data type: Literal["text", "image"] = "text"
output: Optional[str | dict] = None
error: Optional[str] = None error: Optional[str] = None
current_url: Optional[str] = None current_url: Optional[str] = None
screenshot_base64: Optional[str] = None screenshot_base64: Optional[str] = None
class AgentActResult(BaseModel): class AgentActResult(BaseModel):
compiled_response: str
actions: List[BrowserAction] = [] actions: List[BrowserAction] = []
tool_results_for_history: List[dict] = [] # Model-specific format for history action_results: List[dict] = [] # Model-specific format
raw_agent_response: BetaMessage | Response # Store the raw response for history formatting rendered_response: Optional[str] = None
usage: dict = {}
safety_check_message: Optional[str] = None
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system", "environment"]
content: Union[str, List]
# --- Abstract Classes --- # --- Abstract Classes ---
@@ -179,7 +189,7 @@ class Environment(ABC):
pass pass
@abstractmethod @abstractmethod
async def step(self, action: BrowserAction) -> StepResult: async def step(self, action: BrowserAction) -> EnvStepResult:
pass pass
@abstractmethod @abstractmethod
@@ -196,10 +206,36 @@ class OperatorAgent(ABC):
self.chat_model = chat_model self.chat_model = chat_model
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.tracer = tracer self.tracer = tracer
self.compiled_operator_messages: List[ChatMessage] = [] self.messages: List[ChatMessage] = []
@abstractmethod @abstractmethod
async def act(self, messages: List[dict], current_state: EnvState) -> AgentActResult: async def act(self, query: str, current_state: EnvState) -> AgentActResult:
pass
@abstractmethod
def add_action_results(
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
) -> None:
"""Track results of agent actions on the environment."""
pass
async def summarize(self, query: str, current_state: EnvState) -> str:
"""Summarize the agent's actions and results."""
await self.act(query, current_state)
if not self.messages:
return "No actions to summarize."
return await self.compile_response(self.messages[-1].content)
@abstractmethod
def compile_response(self, response: List) -> str:
pass
@abstractmethod
def _render_response(self, response: List, screenshot: Optional[str]) -> Optional[str]:
pass
@abstractmethod
def _format_message_for_api(self, message: ChatMessage) -> List:
pass pass
def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0): def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0):
@@ -210,10 +246,15 @@ class OperatorAgent(ABC):
def _commit_trace(self): def _commit_trace(self):
self.tracer["chat_model"] = self.chat_model.name self.tracer["chat_model"] = self.chat_model.name
if is_promptrace_enabled() and self.compiled_operator_messages: if is_promptrace_enabled() and len(self.messages) > 1:
commit_conversation_trace( compiled_messages = [
self.compiled_operator_messages[:-1], self.compiled_operator_messages[-1].content, self.tracer ChatMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages
) ]
commit_conversation_trace(compiled_messages[:-1], compiled_messages[-1].content, self.tracer)
def reset(self):
"""Reset the agent state."""
self.messages = []
# --- Concrete BrowserEnvironment --- # --- Concrete BrowserEnvironment ---
@@ -293,11 +334,12 @@ class BrowserEnvironment(Environment):
screenshot = await self._get_screenshot() screenshot = await self._get_screenshot()
return EnvState(url=url, screenshot=screenshot) return EnvState(url=url, screenshot=screenshot)
async def step(self, action: BrowserAction) -> StepResult: async def step(self, action: BrowserAction) -> EnvStepResult:
if not self.page or self.page.is_closed(): if not self.page or self.page.is_closed():
return StepResult(error="Browser page is not available or closed.") return EnvStepResult(error="Browser page is not available or closed.")
output, error = None, None state = await self.get_state()
output, error, step_type = None, None, "text"
try: try:
match action.type: match action.type:
case "click": case "click":
@@ -389,8 +431,8 @@ class BrowserEnvironment(Environment):
logger.debug(f"Action: {action.type} for {duration}s") logger.debug(f"Action: {action.type} for {duration}s")
case "screenshot": case "screenshot":
# Screenshot is taken after every step, so this action might just confirm it step_type = "image"
output = "[Screenshot taken]" output = {"image": state.screenshot, "url": state.url}
logger.debug(f"Action: {action.type}") logger.debug(f"Action: {action.type}")
case "move": case "move":
@@ -462,28 +504,17 @@ class BrowserEnvironment(Environment):
error = f"Error executing action {action.type}: {e}" error = f"Error executing action {action.type}: {e}"
logger.exception(f"Error during step execution for action: {action.model_dump_json()}") logger.exception(f"Error during step execution for action: {action.model_dump_json()}")
state = await self.get_state() return EnvStepResult(
type=step_type,
# Special handling for screenshot action result to include image data
if action.type == "screenshot" and state.screenshot:
output = [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/webp",
"data": state.screenshot,
},
}
]
return StepResult(
output=output, output=output,
error=error, error=error,
current_url=state.url, current_url=state.url,
screenshot_base64=state.screenshot, screenshot_base64=state.screenshot,
) )
def reset(self) -> None:
self.visited_urls.clear()
async def close(self) -> None: async def close(self) -> None:
if self.browser: if self.browser:
await self.browser.close() await self.browser.close()
@@ -542,14 +573,14 @@ class BrowserEnvironment(Environment):
# --- Concrete Operator Agents --- # --- Concrete Operator Agents ---
class OpenAIOperatorAgent(OperatorAgent): class OpenAIOperatorAgent(OperatorAgent):
async def act(self, messages: List[dict], current_state: EnvState) -> AgentActResult: async def act(self, query: str, current_state: EnvState) -> AgentActResult:
client = get_openai_async_client( client = get_openai_async_client(
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
) )
safety_check_prefix = "The user needs to say 'continue' after resolving the following safety checks to proceed:" safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
safety_check_message = None safety_check_message = None
actions: List[BrowserAction] = [] actions: List[BrowserAction] = []
tool_results_for_history: List[dict] = [] action_results: List[dict] = []
self._commit_trace() # Commit trace before next action self._commit_trace() # Commit trace before next action
system_prompt = f"""<SYSTEM_CAPABILITY> system_prompt = f"""<SYSTEM_CAPABILITY>
@@ -597,9 +628,13 @@ class OpenAIOperatorAgent(OperatorAgent):
}, },
] ]
if is_none_or_empty(self.messages):
self.messages = [ChatMessage(role="user", content=query)]
messages_for_api = self._format_message_for_api(self.messages)
response: Response = await client.responses.create( response: Response = await client.responses.create(
model="computer-use-preview", model="computer-use-preview",
input=messages, input=messages_for_api,
instructions=system_prompt, instructions=system_prompt,
tools=tools, tools=tools,
parallel_tool_calls=False, # Keep sequential for now parallel_tool_calls=False, # Keep sequential for now
@@ -608,14 +643,12 @@ class OpenAIOperatorAgent(OperatorAgent):
) )
logger.debug(f"Openai response: {response.model_dump_json()}") logger.debug(f"Openai response: {response.model_dump_json()}")
compiled_response = self.compile_openai_response(response.output) self.messages += [ChatMessage(role="environment", content=response.output)]
self.compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response)) rendered_response = await self._render_response(response.output, current_state.screenshot)
last_call_id = None last_call_id = None
for block in response.output: for block in response.output:
action_to_run: Optional[BrowserAction] = None action_to_run: Optional[BrowserAction] = None
content_for_history: Optional[Union[str, dict]] = None
if block.type == "function_call": if block.type == "function_call":
last_call_id = block.call_id last_call_id = block.call_id
if block.name == "goto": if block.name == "goto":
@@ -624,31 +657,24 @@ class OpenAIOperatorAgent(OperatorAgent):
url = args.get("url") url = args.get("url")
if url: if url:
action_to_run = GotoAction(url=url) action_to_run = GotoAction(url=url)
content_for_history = (
f"Navigated to {url}" # Placeholder, actual result comes from env.step
)
else: else:
logger.warning("Goto function called without URL argument.") logger.warning("Goto function called without URL argument.")
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"Failed to parse arguments for goto: {block.arguments}") logger.warning(f"Failed to parse arguments for goto: {block.arguments}")
elif block.name == "back": elif block.name == "back":
action_to_run = BackAction() action_to_run = BackAction()
content_for_history = "Navigated back" # Placeholder
elif block.type == "computer_call": elif block.type == "computer_call":
last_call_id = block.call_id last_call_id = block.call_id
if block.pending_safety_checks: if block.pending_safety_checks:
for check in block.pending_safety_checks: safety_check_body = "\n- ".join([check.message for check in block.pending_safety_checks])
if safety_check_message: safety_check_message = f"{safety_check_prefix}\n- {safety_check_body}"
safety_check_message += f"\n- {check.message}" action_to_run = RequestUserAction(request=safety_check_message)
else: actions.append(action_to_run)
safety_check_message = f"{safety_check_prefix}\n- {check.message}"
break # Stop processing actions if safety check needed break # Stop processing actions if safety check needed
openai_action = block.action
content_for_history = "[placeholder for screenshot]" # Placeholder
# Convert OpenAI action to standardized BrowserAction # Convert OpenAI action to standardized BrowserAction
openai_action = block.action
action_type = openai_action.type action_type = openai_action.type
try: try:
if action_type == "click": if action_type == "click":
@@ -682,10 +708,10 @@ class OpenAIOperatorAgent(OperatorAgent):
if action_to_run: if action_to_run:
actions.append(action_to_run) actions.append(action_to_run)
# Prepare the result structure expected in the message history # Prepare the result structure expected in the message history
tool_results_for_history.append( action_results.append(
{ {
"type": f"{block.type}_output", "type": f"{block.type}_output",
"output": content_for_history, # This will be updated after env.step "output": None, # Updated by environment step
"call_id": last_call_id, "call_id": last_call_id,
} }
) )
@@ -693,22 +719,84 @@ class OpenAIOperatorAgent(OperatorAgent):
self._update_usage(response.usage.input_tokens, response.usage.output_tokens) self._update_usage(response.usage.input_tokens, response.usage.output_tokens)
return AgentActResult( return AgentActResult(
compiled_response=compiled_response,
actions=actions, actions=actions,
tool_results_for_history=tool_results_for_history, action_results=action_results,
raw_agent_response=response, rendered_response=rendered_response,
usage=self.tracer.get("usage", {}),
safety_check_message=safety_check_message,
) )
def add_action_results(
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
) -> None:
if not agent_action.action_results and not summarize_prompt:
return
# Update action results with results of applying suggested actions on the environment
for idx, env_step in enumerate(env_steps):
action_result = agent_action.action_results[idx]
result_content = env_step.error or env_step.output or "[Action completed]"
if env_step.type == "image":
# Add screenshot data in openai message format
action_result["output"] = {
"type": "input_image",
"image_url": f'data:image/webp;base64,{result_content["image"]}',
"current_url": result_content["url"],
}
elif action_result["type"] == "computer_call_output" and idx == len(env_steps) - 1:
# Always add screenshot, current url to last action result, when computer tool used
action_result["output"] = {
"type": "input_image",
"image_url": f"data:image/webp;base64,{env_step.screenshot_base64}",
"current_url": env_step.current_url,
}
else:
# Add text data
action_result["output"] = result_content
if agent_action.action_results:
self.messages += [ChatMessage(role="environment", content=agent_action.action_results)]
# Append summarize prompt as a user message after tool results
if summarize_prompt:
self.messages += [ChatMessage(role="user", content=summarize_prompt)]
def _format_message_for_api(self, messages: list[ChatMessage]) -> list:
"""Format the message for OpenAI API."""
formatted_messages = []
for message in messages:
if message.role == "environment":
formatted_messages.extend(message.content)
else:
formatted_messages.append(
{
"role": message.role,
"content": message.content,
}
)
return formatted_messages
@staticmethod @staticmethod
def compile_openai_response(response_content: list[ResponseOutputItem]) -> str: def compile_response(response_content: str | list[dict | ResponseOutputItem]) -> str:
"""Compile the response from Open AI model into a single string.""" """Compile the response from model into a single string."""
# Handle case where response content is a string.
# This is the case when response content is a user query
if is_none_or_empty(response_content) or isinstance(response_content, str):
return response_content
# Handle case where response_content is a dictionary and not ResponseOutputItem
# This is the case when response_content contains action results
if not hasattr(response_content[0], "type"):
return "**Action**: " + json.dumps(response_content[0]["output"])
compiled_response = [""] compiled_response = [""]
for block in deepcopy(response_content): for block in deepcopy(response_content):
if block.type == "message": if block.type == "message":
# Extract text content if available # Extract text content if available
text_content = block.text if hasattr(block, "text") else block.model_dump_json() for content in block.content:
text_content = ""
if hasattr(content, "text"):
text_content += content.text
elif hasattr(content, "refusal"):
text_content += f"Refusal: {content.refusal}"
else:
text_content += content.model_dump_json()
compiled_response.append(text_content) compiled_response.append(text_content)
elif block.type == "function_call": elif block.type == "function_call":
block_input = {"action": block.name} block_input = {"action": block.name}
@@ -721,8 +809,9 @@ class OpenAIOperatorAgent(OperatorAgent):
compiled_response.append(f"**Action**: {json.dumps(block_input)}") compiled_response.append(f"**Action**: {json.dumps(block_input)}")
elif block.type == "computer_call": elif block.type == "computer_call":
block_input = block.action block_input = block.action
# If it's a screenshot action and we have a screenshot, render it # If it's a screenshot action
if block_input.type == "screenshot": if block_input.type == "screenshot":
# Use a placeholder for screenshot data
block_input_render = block_input.model_dump() block_input_render = block_input.model_dump()
block_input_render["image"] = "[placeholder for screenshot]" block_input_render["image"] = "[placeholder for screenshot]"
compiled_response.append(f"**Action**: {json.dumps(block_input_render)}") compiled_response.append(f"**Action**: {json.dumps(block_input_render)}")
@@ -733,13 +822,13 @@ class OpenAIOperatorAgent(OperatorAgent):
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
@staticmethod @staticmethod
async def render_response(response_content: list[ResponseOutputItem], screenshot: Optional[str] = None) -> str: async def _render_response(response_content: list[ResponseOutputItem], screenshot: Optional[str] = None) -> str:
"""Render OpenAI response for display, potentially including screenshots.""" """Render OpenAI response for display, potentially including screenshots."""
compiled_response = [""] rendered_response = [""]
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
if block.type == "message": if block.type == "message":
text_content = block.text if hasattr(block, "text") else block.model_dump_json() text_content = block.text if hasattr(block, "text") else block.model_dump_json()
compiled_response.append(text_content) rendered_response.append(text_content)
elif block.type == "function_call": elif block.type == "function_call":
block_input = {"action": block.name} block_input = {"action": block.name}
if block.name == "goto": if block.name == "goto":
@@ -748,26 +837,27 @@ class OpenAIOperatorAgent(OperatorAgent):
block_input["url"] = args.get("url", "[Missing URL]") block_input["url"] = args.get("url", "[Missing URL]")
except json.JSONDecodeError: except json.JSONDecodeError:
block_input["arguments"] = block.arguments block_input["arguments"] = block.arguments
compiled_response.append(f"**Action**: {json.dumps(block_input)}") rendered_response.append(f"**Action**: {json.dumps(block_input)}")
elif block.type == "computer_call": elif block.type == "computer_call":
block_input = block.action block_input = block.action
# If it's a screenshot action and we have a screenshot, render it # If it's a screenshot action
if block_input.type == "screenshot": if block_input.type == "screenshot":
# Render screenshot if available
block_input_render = block_input.model_dump() block_input_render = block_input.model_dump()
if screenshot: if screenshot:
block_input_render["image"] = f"data:image/webp;base64,{screenshot}" block_input_render["image"] = f"data:image/webp;base64,{screenshot}"
else: else:
block_input_render["image"] = "[Failed to get screenshot]" block_input_render["image"] = "[Failed to get screenshot]"
compiled_response.append(f"**Action**: {json.dumps(block_input_render)}") rendered_response.append(f"**Action**: {json.dumps(block_input_render)}")
else: else:
compiled_response.append(f"**Action**: {block_input.model_dump_json()}") rendered_response.append(f"**Action**: {block_input.model_dump_json()}")
elif block.type == "reasoning" and block.summary: elif block.type == "reasoning" and block.summary:
compiled_response.append(f"**Thought**: {block.summary}") rendered_response.append(f"**Thought**: {block.summary}")
return "\n- ".join(filter(None, compiled_response)) return "\n- ".join(filter(None, rendered_response))
class AnthropicOperatorAgent(OperatorAgent): class AnthropicOperatorAgent(OperatorAgent):
async def act(self, messages: List[dict], current_state: EnvState) -> AgentActResult: async def act(self, query: str, current_state: EnvState) -> AgentActResult:
client = get_anthropic_async_client( client = get_anthropic_async_client(
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
) )
@@ -775,7 +865,7 @@ class AnthropicOperatorAgent(OperatorAgent):
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"] betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
temperature = 1.0 temperature = 1.0
actions: List[BrowserAction] = [] actions: List[BrowserAction] = []
tool_results_for_history: List[dict] = [] action_results: List[dict] = []
self._commit_trace() # Commit trace before next action self._commit_trace() # Commit trace before next action
system_prompt = f"""<SYSTEM_CAPABILITY> system_prompt = f"""<SYSTEM_CAPABILITY>
@@ -797,11 +887,8 @@ class AnthropicOperatorAgent(OperatorAgent):
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting * After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
</IMPORTANT> </IMPORTANT>
""" """
# Add latest screenshot if available if is_none_or_empty(self.messages):
if current_state.screenshot: self.messages = [ChatMessage(role="user", content=query)]
# Ensure last message content is a list
if not isinstance(messages[-1]["content"], list):
messages[-1]["content"] = [{"type": "text", "text": messages[-1]["content"]}]
tools = [ tools = [
{ {
@@ -830,8 +917,9 @@ class AnthropicOperatorAgent(OperatorAgent):
if self.chat_model.name.startswith("claude-3-7"): if self.chat_model.name.startswith("claude-3-7"):
thinking = {"type": "enabled", "budget_tokens": 1024} thinking = {"type": "enabled", "budget_tokens": 1024}
messages_for_api = self._format_message_for_api(self.messages)
response = await client.beta.messages.create( response = await client.beta.messages.create(
messages=messages, messages=messages_for_api,
model=self.chat_model.name, model=self.chat_model.name,
system=system_prompt, system=system_prompt,
tools=tools, tools=tools,
@@ -842,120 +930,102 @@ class AnthropicOperatorAgent(OperatorAgent):
) )
logger.debug(f"Anthropic response: {response.model_dump_json()}") logger.debug(f"Anthropic response: {response.model_dump_json()}")
compiled_response = self.compile_response(response.content) self.messages.append(ChatMessage(role="assistant", content=response.content))
self.compiled_operator_messages.append( rendered_response = await self._render_response(response.content, current_state.screenshot)
ChatMessage(role="assistant", content=compiled_response)
) # Add raw response text
for block in response.content: for block in response.content:
if block.type == "tool_use": if block.type == "tool_use":
action_to_run: Optional[BrowserAction] = None action_to_run: Optional[BrowserAction] = None
tool_input = block.input tool_input = block.input
tool_name = block.name tool_name = block.input.get("action") if block.name == "computer" else block.name
tool_use_id = block.id tool_use_id = block.id
content_for_history: Optional[Union[str, List[dict]]] = None
try: try:
if tool_name == "computer": if tool_name == "mouse_move":
action_type = tool_input.get("action")
content_for_history = "[placeholder for screenshot]" # Default placeholder
if action_type == "mouse_move":
coord = tool_input.get("coordinate") coord = tool_input.get("coordinate")
if coord: if coord:
action_to_run = MoveAction(x=coord[0], y=coord[1]) action_to_run = MoveAction(x=coord[0], y=coord[1])
elif action_type == "left_click": elif tool_name == "left_click":
coord = tool_input.get("coordinate") coord = tool_input.get("coordinate")
if coord: if coord:
action_to_run = ClickAction( action_to_run = ClickAction(
x=coord[0], y=coord[1], button="left", modifier=tool_input.get("text") x=coord[0], y=coord[1], button="left", modifier=tool_input.get("text")
) )
elif action_type == "right_click": elif tool_name == "right_click":
coord = tool_input.get("coordinate") coord = tool_input.get("coordinate")
if coord: if coord:
action_to_run = ClickAction(x=coord[0], y=coord[1], button="right") action_to_run = ClickAction(x=coord[0], y=coord[1], button="right")
elif action_type == "middle_click": elif tool_name == "middle_click":
coord = tool_input.get("coordinate") coord = tool_input.get("coordinate")
if coord: if coord:
action_to_run = ClickAction(x=coord[0], y=coord[1], button="middle") action_to_run = ClickAction(x=coord[0], y=coord[1], button="middle")
elif action_type == "double_click": elif tool_name == "double_click":
coord = tool_input.get("coordinate") coord = tool_input.get("coordinate")
if coord: if coord:
action_to_run = DoubleClickAction(x=coord[0], y=coord[1]) action_to_run = DoubleClickAction(x=coord[0], y=coord[1])
elif action_type == "triple_click": elif tool_name == "triple_click":
coord = tool_input.get("coordinate") coord = tool_input.get("coordinate")
if coord: if coord:
action_to_run = TripleClickAction(x=coord[0], y=coord[1]) action_to_run = TripleClickAction(x=coord[0], y=coord[1])
elif action_type == "left_click_drag": elif tool_name == "left_click_drag":
start_coord = tool_input.get("start_coordinate") start_coord = tool_input.get("start_coordinate")
end_coord = tool_input.get("coordinate") end_coord = tool_input.get("coordinate")
if start_coord and end_coord: if start_coord and end_coord:
action_to_run = DragAction( action_to_run = DragAction(path=[Point(x=p[0], y=p[1]) for p in [start_coord, end_coord]])
path=[Point(x=p[0], y=p[1]) for p in [start_coord, end_coord]] elif tool_name == "left_mouse_down":
)
elif action_type == "left_mouse_down":
action_to_run = MouseDownAction(button="left") action_to_run = MouseDownAction(button="left")
elif action_type == "left_mouse_up": elif tool_name == "left_mouse_up":
action_to_run = MouseUpAction(button="left") action_to_run = MouseUpAction(button="left")
elif action_type == "type": elif tool_name == "type":
text = tool_input.get("text") text = tool_input.get("text")
if text: if text:
action_to_run = TypeAction(text=text) action_to_run = TypeAction(text=text)
elif action_type == "scroll": elif tool_name == "scroll":
direction = tool_input.get("scroll_direction") direction = tool_input.get("scroll_direction")
amount = tool_input.get("scroll_amount", 1) amount = tool_input.get("scroll_amount", 5)
coord = tool_input.get("coordinate") coord = tool_input.get("coordinate")
x = coord[0] if coord else None x = coord[0] if coord else None
y = coord[1] if coord else None y = coord[1] if coord else None
if direction: if direction:
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y) action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y)
elif action_type == "key": elif tool_name == "key":
text: str = tool_input.get("text") text: str = tool_input.get("text")
if text: if text:
action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style
elif action_type == "hold_key": elif tool_name == "hold_key":
text = tool_input.get("text") text = tool_input.get("text")
duration = tool_input.get("duration", 1.0) duration = tool_input.get("duration", 1.0)
if text: if text:
action_to_run = HoldKeyAction(text=text, duration=duration) action_to_run = HoldKeyAction(text=text, duration=duration)
elif action_type == "wait": elif tool_name == "wait":
duration = tool_input.get("duration", 1.0) duration = tool_input.get("duration", 1.0)
action_to_run = WaitAction(duration=duration) action_to_run = WaitAction(duration=duration)
elif action_type == "screenshot": elif tool_name == "screenshot":
action_to_run = ScreenshotAction() action_to_run = ScreenshotAction()
content_for_history = [ elif tool_name == "cursor_position":
{
"type": "image",
"source": {"type": "base64", "media_type": "image/webp", "data": "[placeholder]"},
}
]
elif action_type == "cursor_position":
action_to_run = CursorPositionAction() action_to_run = CursorPositionAction()
else:
logger.warning(f"Unsupported Anthropic computer action type: {action_type}")
elif tool_name == "goto": elif tool_name == "goto":
url = tool_input.get("url") url = tool_input.get("url")
if url: if url:
action_to_run = GotoAction(url=url) action_to_run = GotoAction(url=url)
content_for_history = f"Navigated to {url}"
else: else:
logger.warning("Goto tool called without URL.") logger.warning("Goto tool called without URL.")
elif tool_name == "back": elif tool_name == "back":
action_to_run = BackAction() action_to_run = BackAction()
content_for_history = "Navigated back" else:
logger.warning(f"Unsupported Anthropic computer action type: {tool_name}")
except Exception as e: except Exception as e:
logger.error(f"Error converting Anthropic action {tool_name} ({tool_input}): {e}") logger.error(f"Error converting Anthropic action {tool_name} ({tool_input}): {e}")
if action_to_run: if action_to_run:
actions.append(action_to_run) actions.append(action_to_run)
# Prepare the result structure expected in the message history action_results.append(
tool_results_for_history.append(
{ {
"type": "tool_result", "type": "tool_result",
"tool_use_id": tool_use_id, "tool_use_id": tool_use_id,
"content": content_for_history, # This will be updated after env.step "content": None, # Updated by environment step
"is_error": False, # Will be updated after env.step "is_error": False, # Updated by environment step
} }
) )
@@ -968,16 +1038,79 @@ class AnthropicOperatorAgent(OperatorAgent):
self.tracer["temperature"] = temperature self.tracer["temperature"] = temperature
return AgentActResult( return AgentActResult(
compiled_response=compiled_response,
actions=actions, actions=actions,
tool_results_for_history=tool_results_for_history, action_results=action_results,
raw_agent_response=response, rendered_response=rendered_response,
usage=self.tracer.get("usage", {}),
safety_check_message=None, # Anthropic doesn't have this yet
) )
def compile_response(self, response_content: list[BetaContentBlock]) -> str: def add_action_results(
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
):
if not agent_action.action_results and not summarize_prompt:
return
elif not agent_action.action_results:
agent_action.action_results = []
# Update action results with results of applying suggested actions on the environment
for idx, env_step in enumerate(env_steps):
action_result = agent_action.action_results[idx]
result_content = env_step.error or env_step.output or "[Action completed]"
if env_step.type == "image":
# Add screenshot data in anthropic message format
action_result["content"] = [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/webp",
"data": result_content["image"],
},
}
]
else:
# Add text data
action_result["content"] = result_content
if env_step.error:
action_result["is_error"] = True
# If summarize prompt provided, append as text within the tool results user message
if summarize_prompt:
agent_action.action_results.append({"type": "text", "text": summarize_prompt})
# Append tool results to the message history
self.messages += [ChatMessage(role="environment", content=agent_action.action_results)]
# Mark the final tool result as a cache break point
agent_action.action_results[-1]["cache_control"] = {"type": "ephemeral"}
# Remove previous cache controls
for msg in self.messages:
if msg.role == "environment" and isinstance(msg.content, list):
for block in msg.content:
if isinstance(block, dict) and "cache_control" in block:
del block["cache_control"]
def _format_message_for_api(self, messages: list[ChatMessage]) -> list[dict]:
"""Format Anthropic response into a single string."""
formatted_messages = []
for message in messages:
role = "user" if message.role == "environment" else message.role
content = (
[{"type": "text", "text": message.content}]
if not isinstance(message.content, list)
else message.content
)
formatted_messages.append(
{
"role": role,
"content": content,
}
)
return formatted_messages
def compile_response(self, response_content: list[BetaContentBlock | Any]) -> str:
"""Compile Anthropic response into a single string.""" """Compile Anthropic response into a single string."""
if is_none_or_empty(response_content) or not all(hasattr(item, "type") for item in response_content):
return response_content
compiled_response = [""] compiled_response = [""]
for block in deepcopy(response_content): for block in deepcopy(response_content):
if block.type == "text": if block.type == "text":
@@ -1004,12 +1137,12 @@ class AnthropicOperatorAgent(OperatorAgent):
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
@staticmethod @staticmethod
async def render_response(response_content: list[BetaContentBlock], screenshot: Optional[str] = None) -> str: async def _render_response(response_content: list[BetaContentBlock], screenshot: Optional[str] = None) -> str:
"""Render Anthropic response, potentially including actual screenshots.""" """Render Anthropic response, potentially including actual screenshots."""
compiled_response = [""] rendered_response = [""]
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
if block.type == "text": if block.type == "text":
compiled_response.append(block.text) rendered_response.append(block.text)
elif block.type == "tool_use": elif block.type == "tool_use":
block_input = {"action": block.name} block_input = {"action": block.name}
if block.name == "computer": if block.name == "computer":
@@ -1017,20 +1150,21 @@ class AnthropicOperatorAgent(OperatorAgent):
elif block.name == "goto": elif block.name == "goto":
block_input["url"] = block.input.get("url", "[Missing URL]") block_input["url"] = block.input.get("url", "[Missing URL]")
# If it's a screenshot action and we have a page, get the actual screenshot # If it's a screenshot action
if isinstance(block_input, dict) and block_input.get("action") == "screenshot": if isinstance(block_input, dict) and block_input.get("action") == "screenshot":
# Render the screenshot data if available
if screenshot: if screenshot:
block_input["image"] = f"data:image/webp;base64,{screenshot}" block_input["image"] = f"data:image/webp;base64,{screenshot}"
else: else:
block_input["image"] = "[Failed to get screenshot]" block_input["image"] = "[Failed to get screenshot]"
compiled_response.append(f"**Action**: {json.dumps(block_input)}") rendered_response.append(f"**Action**: {json.dumps(block_input)}")
elif block.type == "thinking": elif block.type == "thinking":
thinking_content = getattr(block, "thinking", None) thinking_content = getattr(block, "thinking", None)
if thinking_content: if thinking_content:
compiled_response.append(f"**Thought**: {thinking_content}") rendered_response.append(f"**Thought**: {thinking_content}")
return "\n- ".join(filter(None, compiled_response)) return "\n- ".join(filter(None, rendered_response))
# --- Main Operator Function --- # --- Main Operator Function ---
@@ -1046,11 +1180,10 @@ async def operate_browser(
cancellation_event: Optional[asyncio.Event] = None, cancellation_event: Optional[asyncio.Event] = None,
tracer: dict = {}, tracer: dict = {},
): ):
response, safety_check_message = None, None response, summary_message, user_input_message = None, None, None
final_compiled_response = ""
environment: Optional[BrowserEnvironment] = None environment: Optional[BrowserEnvironment] = None
try: # Get the agent chat model
agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model) chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
supported_operator_model_types = [ChatModel.ModelType.OPENAI, ChatModel.ModelType.ANTHROPIC] supported_operator_model_types = [ChatModel.ModelType.OPENAI, ChatModel.ModelType.ANTHROPIC]
@@ -1060,14 +1193,6 @@ async def operate_browser(
f"Unsupported AI model. Configure and use chat model of type {supported_operator_model_types} to enable Browser use." f"Unsupported AI model. Configure and use chat model of type {supported_operator_model_types} to enable Browser use."
) )
if send_status_func:
async for event in send_status_func(f"**Launching Browser**"):
yield {ChatEvent.STATUS: event}
# Initialize Environment
environment = BrowserEnvironment()
await environment.start(width=1024, height=768)
# Initialize Agent # Initialize Agent
max_iterations = 40 # TODO: Configurable? max_iterations = 40 # TODO: Configurable?
operator_agent: OperatorAgent operator_agent: OperatorAgent
@@ -1078,139 +1203,76 @@ async def operate_browser(
else: # Should not happen due to check above, but satisfy type checker else: # Should not happen due to check above, but satisfy type checker
raise ValueError("Invalid model type for operator agent.") raise ValueError("Invalid model type for operator agent.")
messages = [{"role": "user", "content": query}] # Initialize Environment
run_summarize = False if send_status_func:
async for event in send_status_func(f"**Launching Browser**"):
yield {ChatEvent.STATUS: event}
environment = BrowserEnvironment()
await environment.start(width=1024, height=768)
# Start Operator Loop
try:
summarize_prompt = (
f"Collate all relevant information from your research so far to answer the target query:\n{query}."
)
task_completed = False task_completed = False
iterations = 0 iterations = 0
with timer(f"Operating browser with {chat_model.model_type} {chat_model.name}", logger): with timer(f"Operating browser with {chat_model.model_type} {chat_model.name}", logger):
while iterations < max_iterations: while iterations < max_iterations and not task_completed:
if cancellation_event and cancellation_event.is_set(): if cancellation_event and cancellation_event.is_set():
logger.info(f"Browser operator cancelled by client disconnect") logger.info(f"Browser operator cancelled by client disconnect")
break break
iterations += 1 iterations += 1
# Get current environment state # 1. Get current environment state
browser_state = await environment.get_state() browser_state = await environment.get_state()
# Agent decides action(s) # 2. Agent decides action(s)
agent_result = await operator_agent.act(deepcopy(messages), browser_state) agent_result = await operator_agent.act(query, browser_state)
final_compiled_response = agent_result.compiled_response # Update final response each turn
safety_check_message = agent_result.safety_check_message
# Update conversation history with agent's response (before tool results)
if chat_model.model_type == ChatModel.ModelType.OPENAI:
# OpenAI expects list of blocks in 'content' for assistant message with tool calls
messages += agent_result.raw_agent_response.output
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
messages.append({"role": "assistant", "content": agent_result.raw_agent_response.content})
# Render status update # Render status update
rendered_response = agent_result.compiled_response # Default rendering rendered_response = agent_result.rendered_response
if chat_model.model_type == ChatModel.ModelType.ANTHROPIC: if send_status_func and rendered_response:
rendered_response = await operator_agent.render_response(
agent_result.raw_agent_response.content, browser_state.screenshot
)
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
rendered_response = await operator_agent.render_response(
agent_result.raw_agent_response.output, browser_state.screenshot
)
if send_status_func:
async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"): async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
# Execute actions in the environment # 3. Execute actions in the environment
step_results: List[StepResult] = [] env_steps: List[EnvStepResult] = []
if not safety_check_message:
for action in agent_result.actions: for action in agent_result.actions:
if cancellation_event and cancellation_event.is_set(): if cancellation_event and cancellation_event.is_set():
break break
step_result = await environment.step(action) # Handle request for user action and break the loop
step_results.append(step_result) if isinstance(action, RequestUserAction):
user_input_message = action.request
# Gather results from actions if send_status_func:
for step_result in step_results: async for event in send_status_func(f"**Requesting User Input**:\n{action.request}"):
# Update the placeholder content in the history structure yield {ChatEvent.STATUS: event}
result_for_history = agent_result.tool_results_for_history[len(step_results) - 1] break
result_content = step_result.error or step_result.output or "[Action completed]" env_step = await environment.step(action)
if chat_model.model_type == ChatModel.ModelType.OPENAI: env_steps.append(env_step)
if result_for_history["type"] == "computer_call_output":
result_for_history["output"] = {
"type": "input_image",
"image_url": f"data:image/webp;base64,{step_result.screenshot_base64}",
}
result_for_history["output"]["current_url"] = step_result.current_url
else:
result_for_history["output"] = result_content
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
result_for_history["content"] = result_content
if step_result.error:
result_for_history["is_error"] = True
# Add browser message to compiled log for tracing
operator_agent.compiled_operator_messages.append(
ChatMessage(role="browser", content=str(result_content))
)
# Check summarization conditions
summarize_prompt = (
f"Collate all relevant information from your research so far to answer the target query:\n{query}."
)
task_completed = not agent_result.actions and not run_summarize # No actions requested by agent
trigger_iteration_limit = iterations == max_iterations and not run_summarize
# Check if termination conditions are met
task_completed = not agent_result.actions # No actions requested by agent
trigger_iteration_limit = iterations == max_iterations
if task_completed or trigger_iteration_limit: if task_completed or trigger_iteration_limit:
iterations = max_iterations - 1 # Ensure one more iteration for summarization # Summarize results of operator run on last iteration
run_summarize = True operator_agent.add_action_results(env_steps, agent_result, summarize_prompt)
logger.info( summary_message = await operator_agent.summarize(query, browser_state)
f"Triggering summarization. Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}" logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
)
# Append summarize prompt differently based on model
if chat_model.model_type == ChatModel.ModelType.OPENAI:
# Pop the last tool result if max iterations reached and agent attempted a tool call
if trigger_iteration_limit and agent_result.tool_results_for_history:
agent_result.tool_results_for_history.pop()
# Append summarize prompt as a user message after tool results
messages += agent_result.tool_results_for_history # Add results first
messages.append({"role": "user", "content": summarize_prompt})
agent_result.tool_results_for_history = [] # Clear results as they are now in messages
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# Append summarize prompt as text within the tool results user message
agent_result.tool_results_for_history.append({"type": "text", "text": summarize_prompt})
# Add tool results to messages for the next iteration (if not handled above for OpenAI summarize)
if agent_result.tool_results_for_history:
if chat_model.model_type == ChatModel.ModelType.OPENAI:
messages += agent_result.tool_results_for_history
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# Mark the final tool result as a cache break point
agent_result.tool_results_for_history[-1]["cache_control"] = {"type": "ephemeral"}
# Remove previous cache controls (Anthropic specific)
for msg in messages:
if msg["role"] == "user" and isinstance(msg["content"], list):
for block in msg["content"]:
if isinstance(block, dict) and "cache_control" in block:
del block["cache_control"]
messages.append({"role": "user", "content": agent_result.tool_results_for_history})
# Exit if safety checks are pending
if safety_check_message:
logger.warning(f"Safety check triggered: {safety_check_message}")
break break
# Determine final response message # 4. Update agent on the results of its action on the environment
if task_completed and not safety_check_message: operator_agent.add_action_results(env_steps, agent_result)
response = final_compiled_response
elif safety_check_message:
response = safety_check_message # Return safety message if that's why we stopped
else: # Hit iteration limit
response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{final_compiled_response}"
# Determine final response message
if user_input_message:
response = user_input_message
elif task_completed:
response = summary_message
else: # Hit iteration limit
response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
except requests.RequestException as e: except requests.RequestException as e:
error_msg = f"Browser use failed due to a network error: {e}" error_msg = f"Browser use failed due to a network error: {e}"
logger.error(error_msg) logger.error(error_msg)
@@ -1220,10 +1282,10 @@ async def operate_browser(
logger.exception(error_msg) # Log full traceback for unexpected errors logger.exception(error_msg) # Log full traceback for unexpected errors
raise ValueError(error_msg) raise ValueError(error_msg)
finally: finally:
if environment and not safety_check_message: # Don't close browser if safety check pending if environment and not user_input_message: # Don't close browser if user input required
await environment.close() await environment.close()
yield { yield {
"text": safety_check_message or response, "text": user_input_message or response,
"webpages": [{"link": url, "snippet": ""} for url in environment.visited_urls], "webpages": [{"link": url, "snippet": ""} for url in environment.visited_urls],
} }