From 4db888cd62c2942bd75e3749af8940689649e50d Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sun, 4 May 2025 18:39:12 -0600 Subject: [PATCH] Simplify operator loop. Make each OperatorAgent manage state internally. Remove each OperatorAgent specific code from leaking out into the operator. The Operator just calls the standard OperatorAgent functions. Each OperatorAgent specific logic is handled by the OperatorAgent internally. This improves the separation of responsibility between the operator, OperatorAgent and the Environment. - Make environment pass screenshot data in agent agnostic format - Have operator agent providers format image data to their AI model specific format - Add environment step type to distinguish image vs text content - Clearly mark major steps in the operator iteration loop - Handle anthropic models returning computer tool actions as normal tool calls by normalizing next action retrieval from response for it - Remove unused ActionResults fields - Remove unnecessary placeholders to content of action results like for screenshot data --- .../processor/operator/browser_operator.py | 722 ++++++++++-------- 1 file changed, 392 insertions(+), 330 deletions(-) diff --git a/src/khoj/processor/operator/browser_operator.py b/src/khoj/processor/operator/browser_operator.py index 7ecaf0be..795fc761 100644 --- a/src/khoj/processor/operator/browser_operator.py +++ b/src/khoj/processor/operator/browser_operator.py @@ -6,11 +6,10 @@ import os from abc import ABC, abstractmethod from copy import deepcopy from datetime import datetime -from typing import Callable, List, Literal, Optional, Set, Union +from typing import Any, Callable, List, Literal, Optional, Set, Union import requests from anthropic.types.beta import BetaContentBlock, BetaMessage -from langchain.schema import ChatMessage from openai.types.responses import Response, ResponseOutputItem from playwright.async_api import Browser, Page, Playwright, async_playwright from pydantic import BaseModel @@ -24,6 
+23,7 @@ from khoj.utils.helpers import ( get_anthropic_async_client, get_chat_usage_metrics, get_openai_async_client, + is_none_or_empty, is_promptrace_enabled, timer, ) @@ -69,7 +69,7 @@ class ScrollAction(BaseAction): scroll_x: Optional[int] = None scroll_y: Optional[int] = None scroll_direction: Optional[Literal["up", "down", "left", "right"]] = None - scroll_amount: Optional[int] = 1 + scroll_amount: Optional[int] = 5 class KeypressAction(BaseAction): @@ -131,6 +131,13 @@ class BackAction(BaseAction): type: Literal["back"] = "back" +class RequestUserAction(BaseAction): + """Request user action to confirm or provide input.""" + + type: Literal["request_user"] = "request_user" + request: str + + BrowserAction = Union[ ClickAction, DoubleClickAction, @@ -156,20 +163,23 @@ class EnvState(BaseModel): screenshot: Optional[str] = None -class StepResult(BaseModel): - output: Optional[str | List[dict]] = None # Output message or screenshot data +class EnvStepResult(BaseModel): + type: Literal["text", "image"] = "text" + output: Optional[str | dict] = None error: Optional[str] = None current_url: Optional[str] = None screenshot_base64: Optional[str] = None class AgentActResult(BaseModel): - compiled_response: str actions: List[BrowserAction] = [] - tool_results_for_history: List[dict] = [] # Model-specific format for history - raw_agent_response: BetaMessage | Response # Store the raw response for history formatting - usage: dict = {} - safety_check_message: Optional[str] = None + action_results: List[dict] = [] # Model-specific format + rendered_response: Optional[str] = None + + +class ChatMessage(BaseModel): + role: Literal["user", "assistant", "system", "environment"] + content: Union[str, List] # --- Abstract Classes --- @@ -179,7 +189,7 @@ class Environment(ABC): pass @abstractmethod - async def step(self, action: BrowserAction) -> StepResult: + async def step(self, action: BrowserAction) -> EnvStepResult: pass @abstractmethod @@ -196,10 +206,36 @@ class 
OperatorAgent(ABC): self.chat_model = chat_model self.max_iterations = max_iterations self.tracer = tracer - self.compiled_operator_messages: List[ChatMessage] = [] + self.messages: List[ChatMessage] = [] @abstractmethod - async def act(self, messages: List[dict], current_state: EnvState) -> AgentActResult: + async def act(self, query: str, current_state: EnvState) -> AgentActResult: + pass + + @abstractmethod + def add_action_results( + self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None + ) -> None: + """Track results of agent actions on the environment.""" + pass + + async def summarize(self, query: str, current_state: EnvState) -> str: + """Summarize the agent's actions and results.""" + await self.act(query, current_state) + if not self.messages: + return "No actions to summarize." + return await self.compile_response(self.messages[-1].content) + + @abstractmethod + def compile_response(self, response: List) -> str: + pass + + @abstractmethod + def _render_response(self, response: List, screenshot: Optional[str]) -> Optional[str]: + pass + + @abstractmethod + def _format_message_for_api(self, message: ChatMessage) -> List: pass def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0): @@ -210,10 +246,15 @@ class OperatorAgent(ABC): def _commit_trace(self): self.tracer["chat_model"] = self.chat_model.name - if is_promptrace_enabled() and self.compiled_operator_messages: - commit_conversation_trace( - self.compiled_operator_messages[:-1], self.compiled_operator_messages[-1].content, self.tracer - ) + if is_promptrace_enabled() and len(self.messages) > 1: + compiled_messages = [ + ChatMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages + ] + commit_conversation_trace(compiled_messages[:-1], compiled_messages[-1].content, self.tracer) + + def reset(self): + """Reset the agent state.""" + self.messages = [] # --- Concrete 
BrowserEnvironment --- @@ -293,11 +334,12 @@ class BrowserEnvironment(Environment): screenshot = await self._get_screenshot() return EnvState(url=url, screenshot=screenshot) - async def step(self, action: BrowserAction) -> StepResult: + async def step(self, action: BrowserAction) -> EnvStepResult: if not self.page or self.page.is_closed(): - return StepResult(error="Browser page is not available or closed.") + return EnvStepResult(error="Browser page is not available or closed.") - output, error = None, None + state = await self.get_state() + output, error, step_type = None, None, "text" try: match action.type: case "click": @@ -389,8 +431,8 @@ class BrowserEnvironment(Environment): logger.debug(f"Action: {action.type} for {duration}s") case "screenshot": - # Screenshot is taken after every step, so this action might just confirm it - output = "[Screenshot taken]" + step_type = "image" + output = {"image": state.screenshot, "url": state.url} logger.debug(f"Action: {action.type}") case "move": @@ -462,28 +504,17 @@ class BrowserEnvironment(Environment): error = f"Error executing action {action.type}: {e}" logger.exception(f"Error during step execution for action: {action.model_dump_json()}") - state = await self.get_state() - - # Special handling for screenshot action result to include image data - if action.type == "screenshot" and state.screenshot: - output = [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/webp", - "data": state.screenshot, - }, - } - ] - - return StepResult( + return EnvStepResult( + type=step_type, output=output, error=error, current_url=state.url, screenshot_base64=state.screenshot, ) + def reset(self) -> None: + self.visited_urls.clear() + async def close(self) -> None: if self.browser: await self.browser.close() @@ -542,14 +573,14 @@ class BrowserEnvironment(Environment): # --- Concrete Operator Agents --- class OpenAIOperatorAgent(OperatorAgent): - async def act(self, messages: List[dict], current_state: 
EnvState) -> AgentActResult: + async def act(self, query: str, current_state: EnvState) -> AgentActResult: client = get_openai_async_client( self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url ) - safety_check_prefix = "The user needs to say 'continue' after resolving the following safety checks to proceed:" + safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:" safety_check_message = None actions: List[BrowserAction] = [] - tool_results_for_history: List[dict] = [] + action_results: List[dict] = [] self._commit_trace() # Commit trace before next action system_prompt = f""" @@ -597,9 +628,13 @@ class OpenAIOperatorAgent(OperatorAgent): }, ] + if is_none_or_empty(self.messages): + self.messages = [ChatMessage(role="user", content=query)] + + messages_for_api = self._format_message_for_api(self.messages) response: Response = await client.responses.create( model="computer-use-preview", - input=messages, + input=messages_for_api, instructions=system_prompt, tools=tools, parallel_tool_calls=False, # Keep sequential for now @@ -608,14 +643,12 @@ class OpenAIOperatorAgent(OperatorAgent): ) logger.debug(f"Openai response: {response.model_dump_json()}") - compiled_response = self.compile_openai_response(response.output) - self.compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response)) + self.messages += [ChatMessage(role="environment", content=response.output)] + rendered_response = await self._render_response(response.output, current_state.screenshot) last_call_id = None for block in response.output: action_to_run: Optional[BrowserAction] = None - content_for_history: Optional[Union[str, dict]] = None - if block.type == "function_call": last_call_id = block.call_id if block.name == "goto": @@ -624,31 +657,24 @@ class OpenAIOperatorAgent(OperatorAgent): url = args.get("url") if url: action_to_run = GotoAction(url=url) - content_for_history = ( - f"Navigated to {url}" # 
Placeholder, actual result comes from env.step - ) else: logger.warning("Goto function called without URL argument.") except json.JSONDecodeError: logger.warning(f"Failed to parse arguments for goto: {block.arguments}") elif block.name == "back": action_to_run = BackAction() - content_for_history = "Navigated back" # Placeholder elif block.type == "computer_call": last_call_id = block.call_id if block.pending_safety_checks: - for check in block.pending_safety_checks: - if safety_check_message: - safety_check_message += f"\n- {check.message}" - else: - safety_check_message = f"{safety_check_prefix}\n- {check.message}" + safety_check_body = "\n- ".join([check.message for check in block.pending_safety_checks]) + safety_check_message = f"{safety_check_prefix}\n- {safety_check_body}" + action_to_run = RequestUserAction(request=safety_check_message) + actions.append(action_to_run) break # Stop processing actions if safety check needed - openai_action = block.action - content_for_history = "[placeholder for screenshot]" # Placeholder - # Convert OpenAI action to standardized BrowserAction + openai_action = block.action action_type = openai_action.type try: if action_type == "click": @@ -682,10 +708,10 @@ class OpenAIOperatorAgent(OperatorAgent): if action_to_run: actions.append(action_to_run) # Prepare the result structure expected in the message history - tool_results_for_history.append( + action_results.append( { "type": f"{block.type}_output", - "output": content_for_history, # This will be updated after env.step + "output": None, # Updated by environment step "call_id": last_call_id, } ) @@ -693,22 +719,84 @@ class OpenAIOperatorAgent(OperatorAgent): self._update_usage(response.usage.input_tokens, response.usage.output_tokens) return AgentActResult( - compiled_response=compiled_response, actions=actions, - tool_results_for_history=tool_results_for_history, - raw_agent_response=response, - usage=self.tracer.get("usage", {}), - safety_check_message=safety_check_message, 
+ action_results=action_results, + rendered_response=rendered_response, ) + def add_action_results( + self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None + ) -> None: + if not agent_action.action_results and not summarize_prompt: + return + + # Update action results with results of applying suggested actions on the environment + for idx, env_step in enumerate(env_steps): + action_result = agent_action.action_results[idx] + result_content = env_step.error or env_step.output or "[Action completed]" + if env_step.type == "image": + # Add screenshot data in openai message format + action_result["output"] = { + "type": "input_image", + "image_url": f'data:image/webp;base64,{result_content["image"]}', + "current_url": result_content["url"], + } + elif action_result["type"] == "computer_call_output" and idx == len(env_steps) - 1: + # Always add screenshot, current url to last action result, when computer tool used + action_result["output"] = { + "type": "input_image", + "image_url": f"data:image/webp;base64,{env_step.screenshot_base64}", + "current_url": env_step.current_url, + } + else: + # Add text data + action_result["output"] = result_content + + if agent_action.action_results: + self.messages += [ChatMessage(role="environment", content=agent_action.action_results)] + # Append summarize prompt as a user message after tool results + if summarize_prompt: + self.messages += [ChatMessage(role="user", content=summarize_prompt)] + + def _format_message_for_api(self, messages: list[ChatMessage]) -> list: + """Format the message for OpenAI API.""" + formatted_messages = [] + for message in messages: + if message.role == "environment": + formatted_messages.extend(message.content) + else: + formatted_messages.append( + { + "role": message.role, + "content": message.content, + } + ) + return formatted_messages + @staticmethod - def compile_openai_response(response_content: list[ResponseOutputItem]) -> str: - """Compile the response 
from Open AI model into a single string.""" + def compile_response(response_content: str | list[dict | ResponseOutputItem]) -> str: + """Compile the response from model into a single string.""" + # Handle case where response content is a string. + # This is the case when response content is a user query + if is_none_or_empty(response_content) or isinstance(response_content, str): + return response_content + # Handle case where response_content is a dictionary and not ResponseOutputItem + # This is the case when response_content contains action results + if not hasattr(response_content[0], "type"): + return "**Action**: " + json.dumps(response_content[0]["output"]) + compiled_response = [""] for block in deepcopy(response_content): if block.type == "message": # Extract text content if available - text_content = block.text if hasattr(block, "text") else block.model_dump_json() + for content in block.content: + text_content = "" + if hasattr(content, "text"): + text_content += content.text + elif hasattr(content, "refusal"): + text_content += f"Refusal: {content.refusal}" + else: + text_content += content.model_dump_json() compiled_response.append(text_content) elif block.type == "function_call": block_input = {"action": block.name} @@ -721,8 +809,9 @@ class OpenAIOperatorAgent(OperatorAgent): compiled_response.append(f"**Action**: {json.dumps(block_input)}") elif block.type == "computer_call": block_input = block.action - # If it's a screenshot action and we have a screenshot, render it + # If it's a screenshot action if block_input.type == "screenshot": + # Use a placeholder for screenshot data block_input_render = block_input.model_dump() block_input_render["image"] = "[placeholder for screenshot]" compiled_response.append(f"**Action**: {json.dumps(block_input_render)}") @@ -733,13 +822,13 @@ class OpenAIOperatorAgent(OperatorAgent): return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings @staticmethod - async def 
render_response(response_content: list[ResponseOutputItem], screenshot: Optional[str] = None) -> str: + async def _render_response(response_content: list[ResponseOutputItem], screenshot: Optional[str] = None) -> str: """Render OpenAI response for display, potentially including screenshots.""" - compiled_response = [""] + rendered_response = [""] for block in deepcopy(response_content): # Use deepcopy to avoid modifying original if block.type == "message": text_content = block.text if hasattr(block, "text") else block.model_dump_json() - compiled_response.append(text_content) + rendered_response.append(text_content) elif block.type == "function_call": block_input = {"action": block.name} if block.name == "goto": @@ -748,26 +837,27 @@ class OpenAIOperatorAgent(OperatorAgent): block_input["url"] = args.get("url", "[Missing URL]") except json.JSONDecodeError: block_input["arguments"] = block.arguments - compiled_response.append(f"**Action**: {json.dumps(block_input)}") + rendered_response.append(f"**Action**: {json.dumps(block_input)}") elif block.type == "computer_call": block_input = block.action - # If it's a screenshot action and we have a screenshot, render it + # If it's a screenshot action if block_input.type == "screenshot": + # Render screenshot if available block_input_render = block_input.model_dump() if screenshot: block_input_render["image"] = f"data:image/webp;base64,{screenshot}" else: block_input_render["image"] = "[Failed to get screenshot]" - compiled_response.append(f"**Action**: {json.dumps(block_input_render)}") + rendered_response.append(f"**Action**: {json.dumps(block_input_render)}") else: - compiled_response.append(f"**Action**: {block_input.model_dump_json()}") + rendered_response.append(f"**Action**: {block_input.model_dump_json()}") elif block.type == "reasoning" and block.summary: - compiled_response.append(f"**Thought**: {block.summary}") - return "\n- ".join(filter(None, compiled_response)) + rendered_response.append(f"**Thought**: 
{block.summary}") + return "\n- ".join(filter(None, rendered_response)) class AnthropicOperatorAgent(OperatorAgent): - async def act(self, messages: List[dict], current_state: EnvState) -> AgentActResult: + async def act(self, query: str, current_state: EnvState) -> AgentActResult: client = get_anthropic_async_client( self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url ) @@ -775,7 +865,7 @@ class AnthropicOperatorAgent(OperatorAgent): betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"] temperature = 1.0 actions: List[BrowserAction] = [] - tool_results_for_history: List[dict] = [] + action_results: List[dict] = [] self._commit_trace() # Commit trace before next action system_prompt = f""" @@ -797,11 +887,8 @@ class AnthropicOperatorAgent(OperatorAgent): * After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting """ - # Add latest screenshot if available - if current_state.screenshot: - # Ensure last message content is a list - if not isinstance(messages[-1]["content"], list): - messages[-1]["content"] = [{"type": "text", "text": messages[-1]["content"]}] + if is_none_or_empty(self.messages): + self.messages = [ChatMessage(role="user", content=query)] tools = [ { @@ -830,8 +917,9 @@ class AnthropicOperatorAgent(OperatorAgent): if self.chat_model.name.startswith("claude-3-7"): thinking = {"type": "enabled", "budget_tokens": 1024} + messages_for_api = self._format_message_for_api(self.messages) response = await client.beta.messages.create( - messages=messages, + messages=messages_for_api, model=self.chat_model.name, system=system_prompt, tools=tools, @@ -842,120 +930,102 @@ class AnthropicOperatorAgent(OperatorAgent): ) logger.debug(f"Anthropic response: {response.model_dump_json()}") - compiled_response = self.compile_response(response.content) - self.compiled_operator_messages.append( - ChatMessage(role="assistant", content=compiled_response) - ) # Add raw 
response text + self.messages.append(ChatMessage(role="assistant", content=response.content)) + rendered_response = await self._render_response(response.content, current_state.screenshot) for block in response.content: if block.type == "tool_use": action_to_run: Optional[BrowserAction] = None tool_input = block.input - tool_name = block.name + tool_name = block.input.get("action") if block.name == "computer" else block.name tool_use_id = block.id - content_for_history: Optional[Union[str, List[dict]]] = None try: - if tool_name == "computer": - action_type = tool_input.get("action") - content_for_history = "[placeholder for screenshot]" # Default placeholder - if action_type == "mouse_move": - coord = tool_input.get("coordinate") - if coord: - action_to_run = MoveAction(x=coord[0], y=coord[1]) - elif action_type == "left_click": - coord = tool_input.get("coordinate") - if coord: - action_to_run = ClickAction( - x=coord[0], y=coord[1], button="left", modifier=tool_input.get("text") - ) - elif action_type == "right_click": - coord = tool_input.get("coordinate") - if coord: - action_to_run = ClickAction(x=coord[0], y=coord[1], button="right") - elif action_type == "middle_click": - coord = tool_input.get("coordinate") - if coord: - action_to_run = ClickAction(x=coord[0], y=coord[1], button="middle") - elif action_type == "double_click": - coord = tool_input.get("coordinate") - if coord: - action_to_run = DoubleClickAction(x=coord[0], y=coord[1]) - elif action_type == "triple_click": - coord = tool_input.get("coordinate") - if coord: - action_to_run = TripleClickAction(x=coord[0], y=coord[1]) - elif action_type == "left_click_drag": - start_coord = tool_input.get("start_coordinate") - end_coord = tool_input.get("coordinate") - if start_coord and end_coord: - action_to_run = DragAction( - path=[Point(x=p[0], y=p[1]) for p in [start_coord, end_coord]] - ) - elif action_type == "left_mouse_down": - action_to_run = MouseDownAction(button="left") - elif action_type == 
"left_mouse_up": - action_to_run = MouseUpAction(button="left") - elif action_type == "type": - text = tool_input.get("text") - if text: - action_to_run = TypeAction(text=text) - elif action_type == "scroll": - direction = tool_input.get("scroll_direction") - amount = tool_input.get("scroll_amount", 1) - coord = tool_input.get("coordinate") - x = coord[0] if coord else None - y = coord[1] if coord else None - if direction: - action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y) - elif action_type == "key": - text: str = tool_input.get("text") - if text: - action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style - elif action_type == "hold_key": - text = tool_input.get("text") - duration = tool_input.get("duration", 1.0) - if text: - action_to_run = HoldKeyAction(text=text, duration=duration) - elif action_type == "wait": - duration = tool_input.get("duration", 1.0) - action_to_run = WaitAction(duration=duration) - elif action_type == "screenshot": - action_to_run = ScreenshotAction() - content_for_history = [ - { - "type": "image", - "source": {"type": "base64", "media_type": "image/webp", "data": "[placeholder]"}, - } - ] - elif action_type == "cursor_position": - action_to_run = CursorPositionAction() - else: - logger.warning(f"Unsupported Anthropic computer action type: {action_type}") - + if tool_name == "mouse_move": + coord = tool_input.get("coordinate") + if coord: + action_to_run = MoveAction(x=coord[0], y=coord[1]) + elif tool_name == "left_click": + coord = tool_input.get("coordinate") + if coord: + action_to_run = ClickAction( + x=coord[0], y=coord[1], button="left", modifier=tool_input.get("text") + ) + elif tool_name == "right_click": + coord = tool_input.get("coordinate") + if coord: + action_to_run = ClickAction(x=coord[0], y=coord[1], button="right") + elif tool_name == "middle_click": + coord = tool_input.get("coordinate") + if coord: + action_to_run = ClickAction(x=coord[0], y=coord[1], 
button="middle") + elif tool_name == "double_click": + coord = tool_input.get("coordinate") + if coord: + action_to_run = DoubleClickAction(x=coord[0], y=coord[1]) + elif tool_name == "triple_click": + coord = tool_input.get("coordinate") + if coord: + action_to_run = TripleClickAction(x=coord[0], y=coord[1]) + elif tool_name == "left_click_drag": + start_coord = tool_input.get("start_coordinate") + end_coord = tool_input.get("coordinate") + if start_coord and end_coord: + action_to_run = DragAction(path=[Point(x=p[0], y=p[1]) for p in [start_coord, end_coord]]) + elif tool_name == "left_mouse_down": + action_to_run = MouseDownAction(button="left") + elif tool_name == "left_mouse_up": + action_to_run = MouseUpAction(button="left") + elif tool_name == "type": + text = tool_input.get("text") + if text: + action_to_run = TypeAction(text=text) + elif tool_name == "scroll": + direction = tool_input.get("scroll_direction") + amount = tool_input.get("scroll_amount", 5) + coord = tool_input.get("coordinate") + x = coord[0] if coord else None + y = coord[1] if coord else None + if direction: + action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y) + elif tool_name == "key": + text: str = tool_input.get("text") + if text: + action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style + elif tool_name == "hold_key": + text = tool_input.get("text") + duration = tool_input.get("duration", 1.0) + if text: + action_to_run = HoldKeyAction(text=text, duration=duration) + elif tool_name == "wait": + duration = tool_input.get("duration", 1.0) + action_to_run = WaitAction(duration=duration) + elif tool_name == "screenshot": + action_to_run = ScreenshotAction() + elif tool_name == "cursor_position": + action_to_run = CursorPositionAction() elif tool_name == "goto": url = tool_input.get("url") if url: action_to_run = GotoAction(url=url) - content_for_history = f"Navigated to {url}" else: logger.warning("Goto tool called without URL.") elif 
tool_name == "back": action_to_run = BackAction() - content_for_history = "Navigated back" + else: + logger.warning(f"Unsupported Anthropic computer action type: {tool_name}") except Exception as e: logger.error(f"Error converting Anthropic action {tool_name} ({tool_input}): {e}") if action_to_run: actions.append(action_to_run) - # Prepare the result structure expected in the message history - tool_results_for_history.append( + action_results.append( { "type": "tool_result", "tool_use_id": tool_use_id, - "content": content_for_history, # This will be updated after env.step - "is_error": False, # Will be updated after env.step + "content": None, # Updated by environment step + "is_error": False, # Updated by environment step } ) @@ -968,16 +1038,79 @@ class AnthropicOperatorAgent(OperatorAgent): self.tracer["temperature"] = temperature return AgentActResult( - compiled_response=compiled_response, actions=actions, - tool_results_for_history=tool_results_for_history, - raw_agent_response=response, - usage=self.tracer.get("usage", {}), - safety_check_message=None, # Anthropic doesn't have this yet + action_results=action_results, + rendered_response=rendered_response, ) - def compile_response(self, response_content: list[BetaContentBlock]) -> str: + def add_action_results( + self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None + ): + if not agent_action.action_results and not summarize_prompt: + return + elif not agent_action.action_results: + agent_action.action_results = [] + + # Update action results with results of applying suggested actions on the environment + for idx, env_step in enumerate(env_steps): + action_result = agent_action.action_results[idx] + result_content = env_step.error or env_step.output or "[Action completed]" + if env_step.type == "image": + # Add screenshot data in anthropic message format + action_result["content"] = [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": 
"image/webp", + "data": result_content["image"], + }, + } + ] + else: + # Add text data + action_result["content"] = result_content + if env_step.error: + action_result["is_error"] = True + + # If summarize prompt provided, append as text within the tool results user message + if summarize_prompt: + agent_action.action_results.append({"type": "text", "text": summarize_prompt}) + + # Append tool results to the message history + self.messages += [ChatMessage(role="environment", content=agent_action.action_results)] + + # Mark the final tool result as a cache break point + agent_action.action_results[-1]["cache_control"] = {"type": "ephemeral"} + # Remove previous cache controls + for msg in self.messages: + if msg.role == "environment" and isinstance(msg.content, list): + for block in msg.content: + if isinstance(block, dict) and "cache_control" in block: + del block["cache_control"] + + def _format_message_for_api(self, messages: list[ChatMessage]) -> list[dict]: + """Format Anthropic response into a single string.""" + formatted_messages = [] + for message in messages: + role = "user" if message.role == "environment" else message.role + content = ( + [{"type": "text", "text": message.content}] + if not isinstance(message.content, list) + else message.content + ) + formatted_messages.append( + { + "role": role, + "content": content, + } + ) + return formatted_messages + + def compile_response(self, response_content: list[BetaContentBlock | Any]) -> str: """Compile Anthropic response into a single string.""" + if is_none_or_empty(response_content) or not all(hasattr(item, "type") for item in response_content): + return response_content compiled_response = [""] for block in deepcopy(response_content): if block.type == "text": @@ -1004,12 +1137,12 @@ class AnthropicOperatorAgent(OperatorAgent): return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings @staticmethod - async def render_response(response_content: list[BetaContentBlock], screenshot: 
Optional[str] = None) -> str: + async def _render_response(response_content: list[BetaContentBlock], screenshot: Optional[str] = None) -> str: """Render Anthropic response, potentially including actual screenshots.""" - compiled_response = [""] + rendered_response = [""] for block in deepcopy(response_content): # Use deepcopy to avoid modifying original if block.type == "text": - compiled_response.append(block.text) + rendered_response.append(block.text) elif block.type == "tool_use": block_input = {"action": block.name} if block.name == "computer": @@ -1017,20 +1150,21 @@ class AnthropicOperatorAgent(OperatorAgent): elif block.name == "goto": block_input["url"] = block.input.get("url", "[Missing URL]") - # If it's a screenshot action and we have a page, get the actual screenshot + # If it's a screenshot action if isinstance(block_input, dict) and block_input.get("action") == "screenshot": + # Render the screenshot data if available if screenshot: block_input["image"] = f"data:image/webp;base64,{screenshot}" else: block_input["image"] = "[Failed to get screenshot]" - compiled_response.append(f"**Action**: {json.dumps(block_input)}") + rendered_response.append(f"**Action**: {json.dumps(block_input)}") elif block.type == "thinking": thinking_content = getattr(block, "thinking", None) if thinking_content: - compiled_response.append(f"**Thought**: {thinking_content}") + rendered_response.append(f"**Thought**: {thinking_content}") - return "\n- ".join(filter(None, compiled_response)) + return "\n- ".join(filter(None, rendered_response)) # --- Main Operator Function --- @@ -1046,171 +1180,99 @@ async def operate_browser( cancellation_event: Optional[asyncio.Event] = None, tracer: dict = {}, ): - response, safety_check_message = None, None - final_compiled_response = "" + response, summary_message, user_input_message = None, None, None environment: Optional[BrowserEnvironment] = None + # Get the agent chat model + agent_chat_model = await 
AgentAdapters.aget_agent_chat_model(agent, user) if agent else None + chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model) + supported_operator_model_types = [ChatModel.ModelType.OPENAI, ChatModel.ModelType.ANTHROPIC] + + if not chat_model or chat_model.model_type not in supported_operator_model_types: + raise ValueError( + f"Unsupported AI model. Configure and use chat model of type {supported_operator_model_types} to enable Browser use." + ) + + # Initialize Agent + max_iterations = 40 # TODO: Configurable? + operator_agent: OperatorAgent + if chat_model.model_type == ChatModel.ModelType.OPENAI: + operator_agent = OpenAIOperatorAgent(chat_model, max_iterations, tracer) + elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: + operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer) + else: # Should not happen due to check above, but satisfy type checker + raise ValueError("Invalid model type for operator agent.") + + # Initialize Environment + if send_status_func: + async for event in send_status_func(f"**Launching Browser**"): + yield {ChatEvent.STATUS: event} + environment = BrowserEnvironment() + await environment.start(width=1024, height=768) + + # Start Operator Loop try: - agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None - chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model) - supported_operator_model_types = [ChatModel.ModelType.OPENAI, ChatModel.ModelType.ANTHROPIC] - - if not chat_model or chat_model.model_type not in supported_operator_model_types: - raise ValueError( - f"Unsupported AI model. Configure and use chat model of type {supported_operator_model_types} to enable Browser use." 
- ) - - if send_status_func: - async for event in send_status_func(f"**Launching Browser**"): - yield {ChatEvent.STATUS: event} - - # Initialize Environment - environment = BrowserEnvironment() - await environment.start(width=1024, height=768) - - # Initialize Agent - max_iterations = 40 # TODO: Configurable? - operator_agent: OperatorAgent - if chat_model.model_type == ChatModel.ModelType.OPENAI: - operator_agent = OpenAIOperatorAgent(chat_model, max_iterations, tracer) - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer) - else: # Should not happen due to check above, but satisfy type checker - raise ValueError("Invalid model type for operator agent.") - - messages = [{"role": "user", "content": query}] - run_summarize = False + summarize_prompt = ( + f"Collate all relevant information from your research so far to answer the target query:\n{query}." + ) task_completed = False iterations = 0 with timer(f"Operating browser with {chat_model.model_type} {chat_model.name}", logger): - while iterations < max_iterations: + while iterations < max_iterations and not task_completed: if cancellation_event and cancellation_event.is_set(): logger.info(f"Browser operator cancelled by client disconnect") break iterations += 1 - # Get current environment state + # 1. 
Get current environment state browser_state = await environment.get_state() - # Agent decides action(s) - agent_result = await operator_agent.act(deepcopy(messages), browser_state) - - final_compiled_response = agent_result.compiled_response # Update final response each turn - safety_check_message = agent_result.safety_check_message - - # Update conversation history with agent's response (before tool results) - if chat_model.model_type == ChatModel.ModelType.OPENAI: - # OpenAI expects list of blocks in 'content' for assistant message with tool calls - messages += agent_result.raw_agent_response.output - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - messages.append({"role": "assistant", "content": agent_result.raw_agent_response.content}) + # 2. Agent decides action(s) + agent_result = await operator_agent.act(query, browser_state) # Render status update - rendered_response = agent_result.compiled_response # Default rendering - if chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - rendered_response = await operator_agent.render_response( - agent_result.raw_agent_response.content, browser_state.screenshot - ) - elif chat_model.model_type == ChatModel.ModelType.OPENAI: - rendered_response = await operator_agent.render_response( - agent_result.raw_agent_response.output, browser_state.screenshot - ) - if send_status_func: + rendered_response = agent_result.rendered_response + if send_status_func and rendered_response: async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"): yield {ChatEvent.STATUS: event} - # Execute actions in the environment - step_results: List[StepResult] = [] - if not safety_check_message: - for action in agent_result.actions: - if cancellation_event and cancellation_event.is_set(): - break - step_result = await environment.step(action) - step_results.append(step_result) - - # Gather results from actions - for step_result in step_results: - # Update the placeholder content in the history structure 
- result_for_history = agent_result.tool_results_for_history[len(step_results) - 1] - result_content = step_result.error or step_result.output or "[Action completed]" - if chat_model.model_type == ChatModel.ModelType.OPENAI: - if result_for_history["type"] == "computer_call_output": - result_for_history["output"] = { - "type": "input_image", - "image_url": f"data:image/webp;base64,{step_result.screenshot_base64}", - } - result_for_history["output"]["current_url"] = step_result.current_url - else: - result_for_history["output"] = result_content - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - result_for_history["content"] = result_content - if step_result.error: - result_for_history["is_error"] = True - - # Add browser message to compiled log for tracing - operator_agent.compiled_operator_messages.append( - ChatMessage(role="browser", content=str(result_content)) - ) - - # Check summarization conditions - summarize_prompt = ( - f"Collate all relevant information from your research so far to answer the target query:\n{query}." - ) - task_completed = not agent_result.actions and not run_summarize # No actions requested by agent - trigger_iteration_limit = iterations == max_iterations and not run_summarize + # 3. 
Execute actions in the environment + env_steps: List[EnvStepResult] = [] + for action in agent_result.actions: + if cancellation_event and cancellation_event.is_set(): + break + # Handle request for user action and break the loop + if isinstance(action, RequestUserAction): + user_input_message = action.request + if send_status_func: + async for event in send_status_func(f"**Requesting User Input**:\n{action.request}"): + yield {ChatEvent.STATUS: event} + break + env_step = await environment.step(action) + env_steps.append(env_step) + # Check if termination conditions are met + task_completed = not agent_result.actions # No actions requested by agent + trigger_iteration_limit = iterations == max_iterations if task_completed or trigger_iteration_limit: - iterations = max_iterations - 1 # Ensure one more iteration for summarization - run_summarize = True - logger.info( - f"Triggering summarization. Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}" - ) - - # Append summarize prompt differently based on model - if chat_model.model_type == ChatModel.ModelType.OPENAI: - # Pop the last tool result if max iterations reached and agent attempted a tool call - if trigger_iteration_limit and agent_result.tool_results_for_history: - agent_result.tool_results_for_history.pop() - - # Append summarize prompt as a user message after tool results - messages += agent_result.tool_results_for_history # Add results first - messages.append({"role": "user", "content": summarize_prompt}) - agent_result.tool_results_for_history = [] # Clear results as they are now in messages - - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - # Append summarize prompt as text within the tool results user message - agent_result.tool_results_for_history.append({"type": "text", "text": summarize_prompt}) - - # Add tool results to messages for the next iteration (if not handled above for OpenAI summarize) - if agent_result.tool_results_for_history: - if 
chat_model.model_type == ChatModel.ModelType.OPENAI: - messages += agent_result.tool_results_for_history - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: - # Mark the final tool result as a cache break point - agent_result.tool_results_for_history[-1]["cache_control"] = {"type": "ephemeral"} - # Remove previous cache controls (Anthropic specific) - for msg in messages: - if msg["role"] == "user" and isinstance(msg["content"], list): - for block in msg["content"]: - if isinstance(block, dict) and "cache_control" in block: - del block["cache_control"] - messages.append({"role": "user", "content": agent_result.tool_results_for_history}) - - # Exit if safety checks are pending - if safety_check_message: - logger.warning(f"Safety check triggered: {safety_check_message}") + # Summarize results of operator run on last iteration + operator_agent.add_action_results(env_steps, agent_result, summarize_prompt) + summary_message = await operator_agent.summarize(query, browser_state) + logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}") break - # Determine final response message - if task_completed and not safety_check_message: - response = final_compiled_response - elif safety_check_message: - response = safety_check_message # Return safety message if that's why we stopped - else: # Hit iteration limit - response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{final_compiled_response}" + # 4. Update agent on the results of its action on the environment + operator_agent.add_action_results(env_steps, agent_result) + # Determine final response message + if user_input_message: + response = user_input_message + elif task_completed: + response = summary_message + else: # Hit iteration limit + response = f"Operator hit iteration limit ({max_iterations}). 
If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}" except requests.RequestException as e: error_msg = f"Browser use failed due to a network error: {e}" logger.error(error_msg) @@ -1220,10 +1282,10 @@ async def operate_browser( logger.exception(error_msg) # Log full traceback for unexpected errors raise ValueError(error_msg) finally: - if environment and not safety_check_message: # Don't close browser if safety check pending + if environment and not user_input_message: # Don't close browser if user input required await environment.close() yield { - "text": safety_check_message or response, + "text": user_input_message or response, "webpages": [{"link": url, "snippet": ""} for url in environment.visited_urls], }