mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Simplify operator loop. Make each OperatorAgent manage state internally.
Remove each OperatorAgent specific code from leaking out into the
operator. The Oprator just calls the standard OperatorAgent functions.
Each AgentOperator specific logic is handled by the OperatorAgent
internally.
The improve the separation of responsibility between the operator,
OperatorAgent and the Environment.
- Make environment pass screenshot data in agent agnostic format
- Have operator agents providers format image data to their AI model
specific format
- Add environment step type to distinguish image vs text content
- Clearly mark major steps in the operator iteration loop
- Handle anthropic models returning computer tool actions as normal
tool calls by normalizing next action retrieval from response for it
- Remove unused ActionResults fields
- Remove unnnecessary placeholders to content of action results like
for screenshot data
This commit is contained in:
@@ -6,11 +6,10 @@ import os
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Callable, List, Literal, Optional, Set, Union
|
from typing import Any, Callable, List, Literal, Optional, Set, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from anthropic.types.beta import BetaContentBlock, BetaMessage
|
from anthropic.types.beta import BetaContentBlock, BetaMessage
|
||||||
from langchain.schema import ChatMessage
|
|
||||||
from openai.types.responses import Response, ResponseOutputItem
|
from openai.types.responses import Response, ResponseOutputItem
|
||||||
from playwright.async_api import Browser, Page, Playwright, async_playwright
|
from playwright.async_api import Browser, Page, Playwright, async_playwright
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -24,6 +23,7 @@ from khoj.utils.helpers import (
|
|||||||
get_anthropic_async_client,
|
get_anthropic_async_client,
|
||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
get_openai_async_client,
|
get_openai_async_client,
|
||||||
|
is_none_or_empty,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
timer,
|
timer,
|
||||||
)
|
)
|
||||||
@@ -69,7 +69,7 @@ class ScrollAction(BaseAction):
|
|||||||
scroll_x: Optional[int] = None
|
scroll_x: Optional[int] = None
|
||||||
scroll_y: Optional[int] = None
|
scroll_y: Optional[int] = None
|
||||||
scroll_direction: Optional[Literal["up", "down", "left", "right"]] = None
|
scroll_direction: Optional[Literal["up", "down", "left", "right"]] = None
|
||||||
scroll_amount: Optional[int] = 1
|
scroll_amount: Optional[int] = 5
|
||||||
|
|
||||||
|
|
||||||
class KeypressAction(BaseAction):
|
class KeypressAction(BaseAction):
|
||||||
@@ -131,6 +131,13 @@ class BackAction(BaseAction):
|
|||||||
type: Literal["back"] = "back"
|
type: Literal["back"] = "back"
|
||||||
|
|
||||||
|
|
||||||
|
class RequestUserAction(BaseAction):
|
||||||
|
"""Request user action to confirm or provide input."""
|
||||||
|
|
||||||
|
type: Literal["request_user"] = "request_user"
|
||||||
|
request: str
|
||||||
|
|
||||||
|
|
||||||
BrowserAction = Union[
|
BrowserAction = Union[
|
||||||
ClickAction,
|
ClickAction,
|
||||||
DoubleClickAction,
|
DoubleClickAction,
|
||||||
@@ -156,20 +163,23 @@ class EnvState(BaseModel):
|
|||||||
screenshot: Optional[str] = None
|
screenshot: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class StepResult(BaseModel):
|
class EnvStepResult(BaseModel):
|
||||||
output: Optional[str | List[dict]] = None # Output message or screenshot data
|
type: Literal["text", "image"] = "text"
|
||||||
|
output: Optional[str | dict] = None
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
current_url: Optional[str] = None
|
current_url: Optional[str] = None
|
||||||
screenshot_base64: Optional[str] = None
|
screenshot_base64: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class AgentActResult(BaseModel):
|
class AgentActResult(BaseModel):
|
||||||
compiled_response: str
|
|
||||||
actions: List[BrowserAction] = []
|
actions: List[BrowserAction] = []
|
||||||
tool_results_for_history: List[dict] = [] # Model-specific format for history
|
action_results: List[dict] = [] # Model-specific format
|
||||||
raw_agent_response: BetaMessage | Response # Store the raw response for history formatting
|
rendered_response: Optional[str] = None
|
||||||
usage: dict = {}
|
|
||||||
safety_check_message: Optional[str] = None
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: Literal["user", "assistant", "system", "environment"]
|
||||||
|
content: Union[str, List]
|
||||||
|
|
||||||
|
|
||||||
# --- Abstract Classes ---
|
# --- Abstract Classes ---
|
||||||
@@ -179,7 +189,7 @@ class Environment(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def step(self, action: BrowserAction) -> StepResult:
|
async def step(self, action: BrowserAction) -> EnvStepResult:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -196,10 +206,36 @@ class OperatorAgent(ABC):
|
|||||||
self.chat_model = chat_model
|
self.chat_model = chat_model
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
self.tracer = tracer
|
self.tracer = tracer
|
||||||
self.compiled_operator_messages: List[ChatMessage] = []
|
self.messages: List[ChatMessage] = []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def act(self, messages: List[dict], current_state: EnvState) -> AgentActResult:
|
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_action_results(
|
||||||
|
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
||||||
|
) -> None:
|
||||||
|
"""Track results of agent actions on the environment."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def summarize(self, query: str, current_state: EnvState) -> str:
|
||||||
|
"""Summarize the agent's actions and results."""
|
||||||
|
await self.act(query, current_state)
|
||||||
|
if not self.messages:
|
||||||
|
return "No actions to summarize."
|
||||||
|
return await self.compile_response(self.messages[-1].content)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compile_response(self, response: List) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _render_response(self, response: List, screenshot: Optional[str]) -> Optional[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _format_message_for_api(self, message: ChatMessage) -> List:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0):
|
def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0):
|
||||||
@@ -210,10 +246,15 @@ class OperatorAgent(ABC):
|
|||||||
|
|
||||||
def _commit_trace(self):
|
def _commit_trace(self):
|
||||||
self.tracer["chat_model"] = self.chat_model.name
|
self.tracer["chat_model"] = self.chat_model.name
|
||||||
if is_promptrace_enabled() and self.compiled_operator_messages:
|
if is_promptrace_enabled() and len(self.messages) > 1:
|
||||||
commit_conversation_trace(
|
compiled_messages = [
|
||||||
self.compiled_operator_messages[:-1], self.compiled_operator_messages[-1].content, self.tracer
|
ChatMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages
|
||||||
)
|
]
|
||||||
|
commit_conversation_trace(compiled_messages[:-1], compiled_messages[-1].content, self.tracer)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset the agent state."""
|
||||||
|
self.messages = []
|
||||||
|
|
||||||
|
|
||||||
# --- Concrete BrowserEnvironment ---
|
# --- Concrete BrowserEnvironment ---
|
||||||
@@ -293,11 +334,12 @@ class BrowserEnvironment(Environment):
|
|||||||
screenshot = await self._get_screenshot()
|
screenshot = await self._get_screenshot()
|
||||||
return EnvState(url=url, screenshot=screenshot)
|
return EnvState(url=url, screenshot=screenshot)
|
||||||
|
|
||||||
async def step(self, action: BrowserAction) -> StepResult:
|
async def step(self, action: BrowserAction) -> EnvStepResult:
|
||||||
if not self.page or self.page.is_closed():
|
if not self.page or self.page.is_closed():
|
||||||
return StepResult(error="Browser page is not available or closed.")
|
return EnvStepResult(error="Browser page is not available or closed.")
|
||||||
|
|
||||||
output, error = None, None
|
state = await self.get_state()
|
||||||
|
output, error, step_type = None, None, "text"
|
||||||
try:
|
try:
|
||||||
match action.type:
|
match action.type:
|
||||||
case "click":
|
case "click":
|
||||||
@@ -389,8 +431,8 @@ class BrowserEnvironment(Environment):
|
|||||||
logger.debug(f"Action: {action.type} for {duration}s")
|
logger.debug(f"Action: {action.type} for {duration}s")
|
||||||
|
|
||||||
case "screenshot":
|
case "screenshot":
|
||||||
# Screenshot is taken after every step, so this action might just confirm it
|
step_type = "image"
|
||||||
output = "[Screenshot taken]"
|
output = {"image": state.screenshot, "url": state.url}
|
||||||
logger.debug(f"Action: {action.type}")
|
logger.debug(f"Action: {action.type}")
|
||||||
|
|
||||||
case "move":
|
case "move":
|
||||||
@@ -462,28 +504,17 @@ class BrowserEnvironment(Environment):
|
|||||||
error = f"Error executing action {action.type}: {e}"
|
error = f"Error executing action {action.type}: {e}"
|
||||||
logger.exception(f"Error during step execution for action: {action.model_dump_json()}")
|
logger.exception(f"Error during step execution for action: {action.model_dump_json()}")
|
||||||
|
|
||||||
state = await self.get_state()
|
return EnvStepResult(
|
||||||
|
type=step_type,
|
||||||
# Special handling for screenshot action result to include image data
|
|
||||||
if action.type == "screenshot" and state.screenshot:
|
|
||||||
output = [
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"source": {
|
|
||||||
"type": "base64",
|
|
||||||
"media_type": "image/webp",
|
|
||||||
"data": state.screenshot,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
return StepResult(
|
|
||||||
output=output,
|
output=output,
|
||||||
error=error,
|
error=error,
|
||||||
current_url=state.url,
|
current_url=state.url,
|
||||||
screenshot_base64=state.screenshot,
|
screenshot_base64=state.screenshot,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.visited_urls.clear()
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if self.browser:
|
if self.browser:
|
||||||
await self.browser.close()
|
await self.browser.close()
|
||||||
@@ -542,14 +573,14 @@ class BrowserEnvironment(Environment):
|
|||||||
|
|
||||||
# --- Concrete Operator Agents ---
|
# --- Concrete Operator Agents ---
|
||||||
class OpenAIOperatorAgent(OperatorAgent):
|
class OpenAIOperatorAgent(OperatorAgent):
|
||||||
async def act(self, messages: List[dict], current_state: EnvState) -> AgentActResult:
|
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
||||||
client = get_openai_async_client(
|
client = get_openai_async_client(
|
||||||
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
|
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
|
||||||
)
|
)
|
||||||
safety_check_prefix = "The user needs to say 'continue' after resolving the following safety checks to proceed:"
|
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
|
||||||
safety_check_message = None
|
safety_check_message = None
|
||||||
actions: List[BrowserAction] = []
|
actions: List[BrowserAction] = []
|
||||||
tool_results_for_history: List[dict] = []
|
action_results: List[dict] = []
|
||||||
self._commit_trace() # Commit trace before next action
|
self._commit_trace() # Commit trace before next action
|
||||||
|
|
||||||
system_prompt = f"""<SYSTEM_CAPABILITY>
|
system_prompt = f"""<SYSTEM_CAPABILITY>
|
||||||
@@ -597,9 +628,13 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if is_none_or_empty(self.messages):
|
||||||
|
self.messages = [ChatMessage(role="user", content=query)]
|
||||||
|
|
||||||
|
messages_for_api = self._format_message_for_api(self.messages)
|
||||||
response: Response = await client.responses.create(
|
response: Response = await client.responses.create(
|
||||||
model="computer-use-preview",
|
model="computer-use-preview",
|
||||||
input=messages,
|
input=messages_for_api,
|
||||||
instructions=system_prompt,
|
instructions=system_prompt,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
parallel_tool_calls=False, # Keep sequential for now
|
parallel_tool_calls=False, # Keep sequential for now
|
||||||
@@ -608,14 +643,12 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Openai response: {response.model_dump_json()}")
|
logger.debug(f"Openai response: {response.model_dump_json()}")
|
||||||
compiled_response = self.compile_openai_response(response.output)
|
self.messages += [ChatMessage(role="environment", content=response.output)]
|
||||||
self.compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
|
rendered_response = await self._render_response(response.output, current_state.screenshot)
|
||||||
|
|
||||||
last_call_id = None
|
last_call_id = None
|
||||||
for block in response.output:
|
for block in response.output:
|
||||||
action_to_run: Optional[BrowserAction] = None
|
action_to_run: Optional[BrowserAction] = None
|
||||||
content_for_history: Optional[Union[str, dict]] = None
|
|
||||||
|
|
||||||
if block.type == "function_call":
|
if block.type == "function_call":
|
||||||
last_call_id = block.call_id
|
last_call_id = block.call_id
|
||||||
if block.name == "goto":
|
if block.name == "goto":
|
||||||
@@ -624,31 +657,24 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
url = args.get("url")
|
url = args.get("url")
|
||||||
if url:
|
if url:
|
||||||
action_to_run = GotoAction(url=url)
|
action_to_run = GotoAction(url=url)
|
||||||
content_for_history = (
|
|
||||||
f"Navigated to {url}" # Placeholder, actual result comes from env.step
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.warning("Goto function called without URL argument.")
|
logger.warning("Goto function called without URL argument.")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(f"Failed to parse arguments for goto: {block.arguments}")
|
logger.warning(f"Failed to parse arguments for goto: {block.arguments}")
|
||||||
elif block.name == "back":
|
elif block.name == "back":
|
||||||
action_to_run = BackAction()
|
action_to_run = BackAction()
|
||||||
content_for_history = "Navigated back" # Placeholder
|
|
||||||
|
|
||||||
elif block.type == "computer_call":
|
elif block.type == "computer_call":
|
||||||
last_call_id = block.call_id
|
last_call_id = block.call_id
|
||||||
if block.pending_safety_checks:
|
if block.pending_safety_checks:
|
||||||
for check in block.pending_safety_checks:
|
safety_check_body = "\n- ".join([check.message for check in block.pending_safety_checks])
|
||||||
if safety_check_message:
|
safety_check_message = f"{safety_check_prefix}\n- {safety_check_body}"
|
||||||
safety_check_message += f"\n- {check.message}"
|
action_to_run = RequestUserAction(request=safety_check_message)
|
||||||
else:
|
actions.append(action_to_run)
|
||||||
safety_check_message = f"{safety_check_prefix}\n- {check.message}"
|
|
||||||
break # Stop processing actions if safety check needed
|
break # Stop processing actions if safety check needed
|
||||||
|
|
||||||
openai_action = block.action
|
|
||||||
content_for_history = "[placeholder for screenshot]" # Placeholder
|
|
||||||
|
|
||||||
# Convert OpenAI action to standardized BrowserAction
|
# Convert OpenAI action to standardized BrowserAction
|
||||||
|
openai_action = block.action
|
||||||
action_type = openai_action.type
|
action_type = openai_action.type
|
||||||
try:
|
try:
|
||||||
if action_type == "click":
|
if action_type == "click":
|
||||||
@@ -682,10 +708,10 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
if action_to_run:
|
if action_to_run:
|
||||||
actions.append(action_to_run)
|
actions.append(action_to_run)
|
||||||
# Prepare the result structure expected in the message history
|
# Prepare the result structure expected in the message history
|
||||||
tool_results_for_history.append(
|
action_results.append(
|
||||||
{
|
{
|
||||||
"type": f"{block.type}_output",
|
"type": f"{block.type}_output",
|
||||||
"output": content_for_history, # This will be updated after env.step
|
"output": None, # Updated by environment step
|
||||||
"call_id": last_call_id,
|
"call_id": last_call_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -693,22 +719,84 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
self._update_usage(response.usage.input_tokens, response.usage.output_tokens)
|
self._update_usage(response.usage.input_tokens, response.usage.output_tokens)
|
||||||
|
|
||||||
return AgentActResult(
|
return AgentActResult(
|
||||||
compiled_response=compiled_response,
|
|
||||||
actions=actions,
|
actions=actions,
|
||||||
tool_results_for_history=tool_results_for_history,
|
action_results=action_results,
|
||||||
raw_agent_response=response,
|
rendered_response=rendered_response,
|
||||||
usage=self.tracer.get("usage", {}),
|
|
||||||
safety_check_message=safety_check_message,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def add_action_results(
|
||||||
|
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
||||||
|
) -> None:
|
||||||
|
if not agent_action.action_results and not summarize_prompt:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Update action results with results of applying suggested actions on the environment
|
||||||
|
for idx, env_step in enumerate(env_steps):
|
||||||
|
action_result = agent_action.action_results[idx]
|
||||||
|
result_content = env_step.error or env_step.output or "[Action completed]"
|
||||||
|
if env_step.type == "image":
|
||||||
|
# Add screenshot data in openai message format
|
||||||
|
action_result["output"] = {
|
||||||
|
"type": "input_image",
|
||||||
|
"image_url": f'data:image/webp;base64,{result_content["image"]}',
|
||||||
|
"current_url": result_content["url"],
|
||||||
|
}
|
||||||
|
elif action_result["type"] == "computer_call_output" and idx == len(env_steps) - 1:
|
||||||
|
# Always add screenshot, current url to last action result, when computer tool used
|
||||||
|
action_result["output"] = {
|
||||||
|
"type": "input_image",
|
||||||
|
"image_url": f"data:image/webp;base64,{env_step.screenshot_base64}",
|
||||||
|
"current_url": env_step.current_url,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Add text data
|
||||||
|
action_result["output"] = result_content
|
||||||
|
|
||||||
|
if agent_action.action_results:
|
||||||
|
self.messages += [ChatMessage(role="environment", content=agent_action.action_results)]
|
||||||
|
# Append summarize prompt as a user message after tool results
|
||||||
|
if summarize_prompt:
|
||||||
|
self.messages += [ChatMessage(role="user", content=summarize_prompt)]
|
||||||
|
|
||||||
|
def _format_message_for_api(self, messages: list[ChatMessage]) -> list:
|
||||||
|
"""Format the message for OpenAI API."""
|
||||||
|
formatted_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if message.role == "environment":
|
||||||
|
formatted_messages.extend(message.content)
|
||||||
|
else:
|
||||||
|
formatted_messages.append(
|
||||||
|
{
|
||||||
|
"role": message.role,
|
||||||
|
"content": message.content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return formatted_messages
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compile_openai_response(response_content: list[ResponseOutputItem]) -> str:
|
def compile_response(response_content: str | list[dict | ResponseOutputItem]) -> str:
|
||||||
"""Compile the response from Open AI model into a single string."""
|
"""Compile the response from model into a single string."""
|
||||||
|
# Handle case where response content is a string.
|
||||||
|
# This is the case when response content is a user query
|
||||||
|
if is_none_or_empty(response_content) or isinstance(response_content, str):
|
||||||
|
return response_content
|
||||||
|
# Handle case where response_content is a dictionary and not ResponseOutputItem
|
||||||
|
# This is the case when response_content contains action results
|
||||||
|
if not hasattr(response_content[0], "type"):
|
||||||
|
return "**Action**: " + json.dumps(response_content[0]["output"])
|
||||||
|
|
||||||
compiled_response = [""]
|
compiled_response = [""]
|
||||||
for block in deepcopy(response_content):
|
for block in deepcopy(response_content):
|
||||||
if block.type == "message":
|
if block.type == "message":
|
||||||
# Extract text content if available
|
# Extract text content if available
|
||||||
text_content = block.text if hasattr(block, "text") else block.model_dump_json()
|
for content in block.content:
|
||||||
|
text_content = ""
|
||||||
|
if hasattr(content, "text"):
|
||||||
|
text_content += content.text
|
||||||
|
elif hasattr(content, "refusal"):
|
||||||
|
text_content += f"Refusal: {content.refusal}"
|
||||||
|
else:
|
||||||
|
text_content += content.model_dump_json()
|
||||||
compiled_response.append(text_content)
|
compiled_response.append(text_content)
|
||||||
elif block.type == "function_call":
|
elif block.type == "function_call":
|
||||||
block_input = {"action": block.name}
|
block_input = {"action": block.name}
|
||||||
@@ -721,8 +809,9 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
compiled_response.append(f"**Action**: {json.dumps(block_input)}")
|
compiled_response.append(f"**Action**: {json.dumps(block_input)}")
|
||||||
elif block.type == "computer_call":
|
elif block.type == "computer_call":
|
||||||
block_input = block.action
|
block_input = block.action
|
||||||
# If it's a screenshot action and we have a screenshot, render it
|
# If it's a screenshot action
|
||||||
if block_input.type == "screenshot":
|
if block_input.type == "screenshot":
|
||||||
|
# Use a placeholder for screenshot data
|
||||||
block_input_render = block_input.model_dump()
|
block_input_render = block_input.model_dump()
|
||||||
block_input_render["image"] = "[placeholder for screenshot]"
|
block_input_render["image"] = "[placeholder for screenshot]"
|
||||||
compiled_response.append(f"**Action**: {json.dumps(block_input_render)}")
|
compiled_response.append(f"**Action**: {json.dumps(block_input_render)}")
|
||||||
@@ -733,13 +822,13 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
|
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def render_response(response_content: list[ResponseOutputItem], screenshot: Optional[str] = None) -> str:
|
async def _render_response(response_content: list[ResponseOutputItem], screenshot: Optional[str] = None) -> str:
|
||||||
"""Render OpenAI response for display, potentially including screenshots."""
|
"""Render OpenAI response for display, potentially including screenshots."""
|
||||||
compiled_response = [""]
|
rendered_response = [""]
|
||||||
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
|
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
|
||||||
if block.type == "message":
|
if block.type == "message":
|
||||||
text_content = block.text if hasattr(block, "text") else block.model_dump_json()
|
text_content = block.text if hasattr(block, "text") else block.model_dump_json()
|
||||||
compiled_response.append(text_content)
|
rendered_response.append(text_content)
|
||||||
elif block.type == "function_call":
|
elif block.type == "function_call":
|
||||||
block_input = {"action": block.name}
|
block_input = {"action": block.name}
|
||||||
if block.name == "goto":
|
if block.name == "goto":
|
||||||
@@ -748,26 +837,27 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
block_input["url"] = args.get("url", "[Missing URL]")
|
block_input["url"] = args.get("url", "[Missing URL]")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
block_input["arguments"] = block.arguments
|
block_input["arguments"] = block.arguments
|
||||||
compiled_response.append(f"**Action**: {json.dumps(block_input)}")
|
rendered_response.append(f"**Action**: {json.dumps(block_input)}")
|
||||||
elif block.type == "computer_call":
|
elif block.type == "computer_call":
|
||||||
block_input = block.action
|
block_input = block.action
|
||||||
# If it's a screenshot action and we have a screenshot, render it
|
# If it's a screenshot action
|
||||||
if block_input.type == "screenshot":
|
if block_input.type == "screenshot":
|
||||||
|
# Render screenshot if available
|
||||||
block_input_render = block_input.model_dump()
|
block_input_render = block_input.model_dump()
|
||||||
if screenshot:
|
if screenshot:
|
||||||
block_input_render["image"] = f"data:image/webp;base64,{screenshot}"
|
block_input_render["image"] = f"data:image/webp;base64,{screenshot}"
|
||||||
else:
|
else:
|
||||||
block_input_render["image"] = "[Failed to get screenshot]"
|
block_input_render["image"] = "[Failed to get screenshot]"
|
||||||
compiled_response.append(f"**Action**: {json.dumps(block_input_render)}")
|
rendered_response.append(f"**Action**: {json.dumps(block_input_render)}")
|
||||||
else:
|
else:
|
||||||
compiled_response.append(f"**Action**: {block_input.model_dump_json()}")
|
rendered_response.append(f"**Action**: {block_input.model_dump_json()}")
|
||||||
elif block.type == "reasoning" and block.summary:
|
elif block.type == "reasoning" and block.summary:
|
||||||
compiled_response.append(f"**Thought**: {block.summary}")
|
rendered_response.append(f"**Thought**: {block.summary}")
|
||||||
return "\n- ".join(filter(None, compiled_response))
|
return "\n- ".join(filter(None, rendered_response))
|
||||||
|
|
||||||
|
|
||||||
class AnthropicOperatorAgent(OperatorAgent):
|
class AnthropicOperatorAgent(OperatorAgent):
|
||||||
async def act(self, messages: List[dict], current_state: EnvState) -> AgentActResult:
|
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
||||||
client = get_anthropic_async_client(
|
client = get_anthropic_async_client(
|
||||||
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
|
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
|
||||||
)
|
)
|
||||||
@@ -775,7 +865,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
|
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
|
||||||
temperature = 1.0
|
temperature = 1.0
|
||||||
actions: List[BrowserAction] = []
|
actions: List[BrowserAction] = []
|
||||||
tool_results_for_history: List[dict] = []
|
action_results: List[dict] = []
|
||||||
self._commit_trace() # Commit trace before next action
|
self._commit_trace() # Commit trace before next action
|
||||||
|
|
||||||
system_prompt = f"""<SYSTEM_CAPABILITY>
|
system_prompt = f"""<SYSTEM_CAPABILITY>
|
||||||
@@ -797,11 +887,8 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
|
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
|
||||||
</IMPORTANT>
|
</IMPORTANT>
|
||||||
"""
|
"""
|
||||||
# Add latest screenshot if available
|
if is_none_or_empty(self.messages):
|
||||||
if current_state.screenshot:
|
self.messages = [ChatMessage(role="user", content=query)]
|
||||||
# Ensure last message content is a list
|
|
||||||
if not isinstance(messages[-1]["content"], list):
|
|
||||||
messages[-1]["content"] = [{"type": "text", "text": messages[-1]["content"]}]
|
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
@@ -830,8 +917,9 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
if self.chat_model.name.startswith("claude-3-7"):
|
if self.chat_model.name.startswith("claude-3-7"):
|
||||||
thinking = {"type": "enabled", "budget_tokens": 1024}
|
thinking = {"type": "enabled", "budget_tokens": 1024}
|
||||||
|
|
||||||
|
messages_for_api = self._format_message_for_api(self.messages)
|
||||||
response = await client.beta.messages.create(
|
response = await client.beta.messages.create(
|
||||||
messages=messages,
|
messages=messages_for_api,
|
||||||
model=self.chat_model.name,
|
model=self.chat_model.name,
|
||||||
system=system_prompt,
|
system=system_prompt,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -842,120 +930,102 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Anthropic response: {response.model_dump_json()}")
|
logger.debug(f"Anthropic response: {response.model_dump_json()}")
|
||||||
compiled_response = self.compile_response(response.content)
|
self.messages.append(ChatMessage(role="assistant", content=response.content))
|
||||||
self.compiled_operator_messages.append(
|
rendered_response = await self._render_response(response.content, current_state.screenshot)
|
||||||
ChatMessage(role="assistant", content=compiled_response)
|
|
||||||
) # Add raw response text
|
|
||||||
|
|
||||||
for block in response.content:
|
for block in response.content:
|
||||||
if block.type == "tool_use":
|
if block.type == "tool_use":
|
||||||
action_to_run: Optional[BrowserAction] = None
|
action_to_run: Optional[BrowserAction] = None
|
||||||
tool_input = block.input
|
tool_input = block.input
|
||||||
tool_name = block.name
|
tool_name = block.input.get("action") if block.name == "computer" else block.name
|
||||||
tool_use_id = block.id
|
tool_use_id = block.id
|
||||||
content_for_history: Optional[Union[str, List[dict]]] = None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if tool_name == "computer":
|
if tool_name == "mouse_move":
|
||||||
action_type = tool_input.get("action")
|
|
||||||
content_for_history = "[placeholder for screenshot]" # Default placeholder
|
|
||||||
if action_type == "mouse_move":
|
|
||||||
coord = tool_input.get("coordinate")
|
coord = tool_input.get("coordinate")
|
||||||
if coord:
|
if coord:
|
||||||
action_to_run = MoveAction(x=coord[0], y=coord[1])
|
action_to_run = MoveAction(x=coord[0], y=coord[1])
|
||||||
elif action_type == "left_click":
|
elif tool_name == "left_click":
|
||||||
coord = tool_input.get("coordinate")
|
coord = tool_input.get("coordinate")
|
||||||
if coord:
|
if coord:
|
||||||
action_to_run = ClickAction(
|
action_to_run = ClickAction(
|
||||||
x=coord[0], y=coord[1], button="left", modifier=tool_input.get("text")
|
x=coord[0], y=coord[1], button="left", modifier=tool_input.get("text")
|
||||||
)
|
)
|
||||||
elif action_type == "right_click":
|
elif tool_name == "right_click":
|
||||||
coord = tool_input.get("coordinate")
|
coord = tool_input.get("coordinate")
|
||||||
if coord:
|
if coord:
|
||||||
action_to_run = ClickAction(x=coord[0], y=coord[1], button="right")
|
action_to_run = ClickAction(x=coord[0], y=coord[1], button="right")
|
||||||
elif action_type == "middle_click":
|
elif tool_name == "middle_click":
|
||||||
coord = tool_input.get("coordinate")
|
coord = tool_input.get("coordinate")
|
||||||
if coord:
|
if coord:
|
||||||
action_to_run = ClickAction(x=coord[0], y=coord[1], button="middle")
|
action_to_run = ClickAction(x=coord[0], y=coord[1], button="middle")
|
||||||
elif action_type == "double_click":
|
elif tool_name == "double_click":
|
||||||
coord = tool_input.get("coordinate")
|
coord = tool_input.get("coordinate")
|
||||||
if coord:
|
if coord:
|
||||||
action_to_run = DoubleClickAction(x=coord[0], y=coord[1])
|
action_to_run = DoubleClickAction(x=coord[0], y=coord[1])
|
||||||
elif action_type == "triple_click":
|
elif tool_name == "triple_click":
|
||||||
coord = tool_input.get("coordinate")
|
coord = tool_input.get("coordinate")
|
||||||
if coord:
|
if coord:
|
||||||
action_to_run = TripleClickAction(x=coord[0], y=coord[1])
|
action_to_run = TripleClickAction(x=coord[0], y=coord[1])
|
||||||
elif action_type == "left_click_drag":
|
elif tool_name == "left_click_drag":
|
||||||
start_coord = tool_input.get("start_coordinate")
|
start_coord = tool_input.get("start_coordinate")
|
||||||
end_coord = tool_input.get("coordinate")
|
end_coord = tool_input.get("coordinate")
|
||||||
if start_coord and end_coord:
|
if start_coord and end_coord:
|
||||||
action_to_run = DragAction(
|
action_to_run = DragAction(path=[Point(x=p[0], y=p[1]) for p in [start_coord, end_coord]])
|
||||||
path=[Point(x=p[0], y=p[1]) for p in [start_coord, end_coord]]
|
elif tool_name == "left_mouse_down":
|
||||||
)
|
|
||||||
elif action_type == "left_mouse_down":
|
|
||||||
action_to_run = MouseDownAction(button="left")
|
action_to_run = MouseDownAction(button="left")
|
||||||
elif action_type == "left_mouse_up":
|
elif tool_name == "left_mouse_up":
|
||||||
action_to_run = MouseUpAction(button="left")
|
action_to_run = MouseUpAction(button="left")
|
||||||
elif action_type == "type":
|
elif tool_name == "type":
|
||||||
text = tool_input.get("text")
|
text = tool_input.get("text")
|
||||||
if text:
|
if text:
|
||||||
action_to_run = TypeAction(text=text)
|
action_to_run = TypeAction(text=text)
|
||||||
elif action_type == "scroll":
|
elif tool_name == "scroll":
|
||||||
direction = tool_input.get("scroll_direction")
|
direction = tool_input.get("scroll_direction")
|
||||||
amount = tool_input.get("scroll_amount", 1)
|
amount = tool_input.get("scroll_amount", 5)
|
||||||
coord = tool_input.get("coordinate")
|
coord = tool_input.get("coordinate")
|
||||||
x = coord[0] if coord else None
|
x = coord[0] if coord else None
|
||||||
y = coord[1] if coord else None
|
y = coord[1] if coord else None
|
||||||
if direction:
|
if direction:
|
||||||
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y)
|
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y)
|
||||||
elif action_type == "key":
|
elif tool_name == "key":
|
||||||
text: str = tool_input.get("text")
|
text: str = tool_input.get("text")
|
||||||
if text:
|
if text:
|
||||||
action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style
|
action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style
|
||||||
elif action_type == "hold_key":
|
elif tool_name == "hold_key":
|
||||||
text = tool_input.get("text")
|
text = tool_input.get("text")
|
||||||
duration = tool_input.get("duration", 1.0)
|
duration = tool_input.get("duration", 1.0)
|
||||||
if text:
|
if text:
|
||||||
action_to_run = HoldKeyAction(text=text, duration=duration)
|
action_to_run = HoldKeyAction(text=text, duration=duration)
|
||||||
elif action_type == "wait":
|
elif tool_name == "wait":
|
||||||
duration = tool_input.get("duration", 1.0)
|
duration = tool_input.get("duration", 1.0)
|
||||||
action_to_run = WaitAction(duration=duration)
|
action_to_run = WaitAction(duration=duration)
|
||||||
elif action_type == "screenshot":
|
elif tool_name == "screenshot":
|
||||||
action_to_run = ScreenshotAction()
|
action_to_run = ScreenshotAction()
|
||||||
content_for_history = [
|
elif tool_name == "cursor_position":
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"source": {"type": "base64", "media_type": "image/webp", "data": "[placeholder]"},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
elif action_type == "cursor_position":
|
|
||||||
action_to_run = CursorPositionAction()
|
action_to_run = CursorPositionAction()
|
||||||
else:
|
|
||||||
logger.warning(f"Unsupported Anthropic computer action type: {action_type}")
|
|
||||||
|
|
||||||
elif tool_name == "goto":
|
elif tool_name == "goto":
|
||||||
url = tool_input.get("url")
|
url = tool_input.get("url")
|
||||||
if url:
|
if url:
|
||||||
action_to_run = GotoAction(url=url)
|
action_to_run = GotoAction(url=url)
|
||||||
content_for_history = f"Navigated to {url}"
|
|
||||||
else:
|
else:
|
||||||
logger.warning("Goto tool called without URL.")
|
logger.warning("Goto tool called without URL.")
|
||||||
elif tool_name == "back":
|
elif tool_name == "back":
|
||||||
action_to_run = BackAction()
|
action_to_run = BackAction()
|
||||||
content_for_history = "Navigated back"
|
else:
|
||||||
|
logger.warning(f"Unsupported Anthropic computer action type: {tool_name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error converting Anthropic action {tool_name} ({tool_input}): {e}")
|
logger.error(f"Error converting Anthropic action {tool_name} ({tool_input}): {e}")
|
||||||
|
|
||||||
if action_to_run:
|
if action_to_run:
|
||||||
actions.append(action_to_run)
|
actions.append(action_to_run)
|
||||||
# Prepare the result structure expected in the message history
|
action_results.append(
|
||||||
tool_results_for_history.append(
|
|
||||||
{
|
{
|
||||||
"type": "tool_result",
|
"type": "tool_result",
|
||||||
"tool_use_id": tool_use_id,
|
"tool_use_id": tool_use_id,
|
||||||
"content": content_for_history, # This will be updated after env.step
|
"content": None, # Updated by environment step
|
||||||
"is_error": False, # Will be updated after env.step
|
"is_error": False, # Updated by environment step
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -968,16 +1038,79 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
self.tracer["temperature"] = temperature
|
self.tracer["temperature"] = temperature
|
||||||
|
|
||||||
return AgentActResult(
|
return AgentActResult(
|
||||||
compiled_response=compiled_response,
|
|
||||||
actions=actions,
|
actions=actions,
|
||||||
tool_results_for_history=tool_results_for_history,
|
action_results=action_results,
|
||||||
raw_agent_response=response,
|
rendered_response=rendered_response,
|
||||||
usage=self.tracer.get("usage", {}),
|
|
||||||
safety_check_message=None, # Anthropic doesn't have this yet
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def compile_response(self, response_content: list[BetaContentBlock]) -> str:
|
def add_action_results(
|
||||||
|
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
||||||
|
):
|
||||||
|
if not agent_action.action_results and not summarize_prompt:
|
||||||
|
return
|
||||||
|
elif not agent_action.action_results:
|
||||||
|
agent_action.action_results = []
|
||||||
|
|
||||||
|
# Update action results with results of applying suggested actions on the environment
|
||||||
|
for idx, env_step in enumerate(env_steps):
|
||||||
|
action_result = agent_action.action_results[idx]
|
||||||
|
result_content = env_step.error or env_step.output or "[Action completed]"
|
||||||
|
if env_step.type == "image":
|
||||||
|
# Add screenshot data in anthropic message format
|
||||||
|
action_result["content"] = [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/webp",
|
||||||
|
"data": result_content["image"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Add text data
|
||||||
|
action_result["content"] = result_content
|
||||||
|
if env_step.error:
|
||||||
|
action_result["is_error"] = True
|
||||||
|
|
||||||
|
# If summarize prompt provided, append as text within the tool results user message
|
||||||
|
if summarize_prompt:
|
||||||
|
agent_action.action_results.append({"type": "text", "text": summarize_prompt})
|
||||||
|
|
||||||
|
# Append tool results to the message history
|
||||||
|
self.messages += [ChatMessage(role="environment", content=agent_action.action_results)]
|
||||||
|
|
||||||
|
# Mark the final tool result as a cache break point
|
||||||
|
agent_action.action_results[-1]["cache_control"] = {"type": "ephemeral"}
|
||||||
|
# Remove previous cache controls
|
||||||
|
for msg in self.messages:
|
||||||
|
if msg.role == "environment" and isinstance(msg.content, list):
|
||||||
|
for block in msg.content:
|
||||||
|
if isinstance(block, dict) and "cache_control" in block:
|
||||||
|
del block["cache_control"]
|
||||||
|
|
||||||
|
def _format_message_for_api(self, messages: list[ChatMessage]) -> list[dict]:
|
||||||
|
"""Format Anthropic response into a single string."""
|
||||||
|
formatted_messages = []
|
||||||
|
for message in messages:
|
||||||
|
role = "user" if message.role == "environment" else message.role
|
||||||
|
content = (
|
||||||
|
[{"type": "text", "text": message.content}]
|
||||||
|
if not isinstance(message.content, list)
|
||||||
|
else message.content
|
||||||
|
)
|
||||||
|
formatted_messages.append(
|
||||||
|
{
|
||||||
|
"role": role,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return formatted_messages
|
||||||
|
|
||||||
|
def compile_response(self, response_content: list[BetaContentBlock | Any]) -> str:
|
||||||
"""Compile Anthropic response into a single string."""
|
"""Compile Anthropic response into a single string."""
|
||||||
|
if is_none_or_empty(response_content) or not all(hasattr(item, "type") for item in response_content):
|
||||||
|
return response_content
|
||||||
compiled_response = [""]
|
compiled_response = [""]
|
||||||
for block in deepcopy(response_content):
|
for block in deepcopy(response_content):
|
||||||
if block.type == "text":
|
if block.type == "text":
|
||||||
@@ -1004,12 +1137,12 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
|
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def render_response(response_content: list[BetaContentBlock], screenshot: Optional[str] = None) -> str:
|
async def _render_response(response_content: list[BetaContentBlock], screenshot: Optional[str] = None) -> str:
|
||||||
"""Render Anthropic response, potentially including actual screenshots."""
|
"""Render Anthropic response, potentially including actual screenshots."""
|
||||||
compiled_response = [""]
|
rendered_response = [""]
|
||||||
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
|
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
|
||||||
if block.type == "text":
|
if block.type == "text":
|
||||||
compiled_response.append(block.text)
|
rendered_response.append(block.text)
|
||||||
elif block.type == "tool_use":
|
elif block.type == "tool_use":
|
||||||
block_input = {"action": block.name}
|
block_input = {"action": block.name}
|
||||||
if block.name == "computer":
|
if block.name == "computer":
|
||||||
@@ -1017,20 +1150,21 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
elif block.name == "goto":
|
elif block.name == "goto":
|
||||||
block_input["url"] = block.input.get("url", "[Missing URL]")
|
block_input["url"] = block.input.get("url", "[Missing URL]")
|
||||||
|
|
||||||
# If it's a screenshot action and we have a page, get the actual screenshot
|
# If it's a screenshot action
|
||||||
if isinstance(block_input, dict) and block_input.get("action") == "screenshot":
|
if isinstance(block_input, dict) and block_input.get("action") == "screenshot":
|
||||||
|
# Render the screenshot data if available
|
||||||
if screenshot:
|
if screenshot:
|
||||||
block_input["image"] = f"data:image/webp;base64,{screenshot}"
|
block_input["image"] = f"data:image/webp;base64,{screenshot}"
|
||||||
else:
|
else:
|
||||||
block_input["image"] = "[Failed to get screenshot]"
|
block_input["image"] = "[Failed to get screenshot]"
|
||||||
|
|
||||||
compiled_response.append(f"**Action**: {json.dumps(block_input)}")
|
rendered_response.append(f"**Action**: {json.dumps(block_input)}")
|
||||||
elif block.type == "thinking":
|
elif block.type == "thinking":
|
||||||
thinking_content = getattr(block, "thinking", None)
|
thinking_content = getattr(block, "thinking", None)
|
||||||
if thinking_content:
|
if thinking_content:
|
||||||
compiled_response.append(f"**Thought**: {thinking_content}")
|
rendered_response.append(f"**Thought**: {thinking_content}")
|
||||||
|
|
||||||
return "\n- ".join(filter(None, compiled_response))
|
return "\n- ".join(filter(None, rendered_response))
|
||||||
|
|
||||||
|
|
||||||
# --- Main Operator Function ---
|
# --- Main Operator Function ---
|
||||||
@@ -1046,11 +1180,10 @@ async def operate_browser(
|
|||||||
cancellation_event: Optional[asyncio.Event] = None,
|
cancellation_event: Optional[asyncio.Event] = None,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
response, safety_check_message = None, None
|
response, summary_message, user_input_message = None, None, None
|
||||||
final_compiled_response = ""
|
|
||||||
environment: Optional[BrowserEnvironment] = None
|
environment: Optional[BrowserEnvironment] = None
|
||||||
|
|
||||||
try:
|
# Get the agent chat model
|
||||||
agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None
|
agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None
|
||||||
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
|
chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model)
|
||||||
supported_operator_model_types = [ChatModel.ModelType.OPENAI, ChatModel.ModelType.ANTHROPIC]
|
supported_operator_model_types = [ChatModel.ModelType.OPENAI, ChatModel.ModelType.ANTHROPIC]
|
||||||
@@ -1060,14 +1193,6 @@ async def operate_browser(
|
|||||||
f"Unsupported AI model. Configure and use chat model of type {supported_operator_model_types} to enable Browser use."
|
f"Unsupported AI model. Configure and use chat model of type {supported_operator_model_types} to enable Browser use."
|
||||||
)
|
)
|
||||||
|
|
||||||
if send_status_func:
|
|
||||||
async for event in send_status_func(f"**Launching Browser**"):
|
|
||||||
yield {ChatEvent.STATUS: event}
|
|
||||||
|
|
||||||
# Initialize Environment
|
|
||||||
environment = BrowserEnvironment()
|
|
||||||
await environment.start(width=1024, height=768)
|
|
||||||
|
|
||||||
# Initialize Agent
|
# Initialize Agent
|
||||||
max_iterations = 40 # TODO: Configurable?
|
max_iterations = 40 # TODO: Configurable?
|
||||||
operator_agent: OperatorAgent
|
operator_agent: OperatorAgent
|
||||||
@@ -1078,139 +1203,76 @@ async def operate_browser(
|
|||||||
else: # Should not happen due to check above, but satisfy type checker
|
else: # Should not happen due to check above, but satisfy type checker
|
||||||
raise ValueError("Invalid model type for operator agent.")
|
raise ValueError("Invalid model type for operator agent.")
|
||||||
|
|
||||||
messages = [{"role": "user", "content": query}]
|
# Initialize Environment
|
||||||
run_summarize = False
|
if send_status_func:
|
||||||
|
async for event in send_status_func(f"**Launching Browser**"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
|
environment = BrowserEnvironment()
|
||||||
|
await environment.start(width=1024, height=768)
|
||||||
|
|
||||||
|
# Start Operator Loop
|
||||||
|
try:
|
||||||
|
summarize_prompt = (
|
||||||
|
f"Collate all relevant information from your research so far to answer the target query:\n{query}."
|
||||||
|
)
|
||||||
task_completed = False
|
task_completed = False
|
||||||
iterations = 0
|
iterations = 0
|
||||||
|
|
||||||
with timer(f"Operating browser with {chat_model.model_type} {chat_model.name}", logger):
|
with timer(f"Operating browser with {chat_model.model_type} {chat_model.name}", logger):
|
||||||
while iterations < max_iterations:
|
while iterations < max_iterations and not task_completed:
|
||||||
if cancellation_event and cancellation_event.is_set():
|
if cancellation_event and cancellation_event.is_set():
|
||||||
logger.info(f"Browser operator cancelled by client disconnect")
|
logger.info(f"Browser operator cancelled by client disconnect")
|
||||||
break
|
break
|
||||||
|
|
||||||
iterations += 1
|
iterations += 1
|
||||||
|
|
||||||
# Get current environment state
|
# 1. Get current environment state
|
||||||
browser_state = await environment.get_state()
|
browser_state = await environment.get_state()
|
||||||
|
|
||||||
# Agent decides action(s)
|
# 2. Agent decides action(s)
|
||||||
agent_result = await operator_agent.act(deepcopy(messages), browser_state)
|
agent_result = await operator_agent.act(query, browser_state)
|
||||||
|
|
||||||
final_compiled_response = agent_result.compiled_response # Update final response each turn
|
|
||||||
safety_check_message = agent_result.safety_check_message
|
|
||||||
|
|
||||||
# Update conversation history with agent's response (before tool results)
|
|
||||||
if chat_model.model_type == ChatModel.ModelType.OPENAI:
|
|
||||||
# OpenAI expects list of blocks in 'content' for assistant message with tool calls
|
|
||||||
messages += agent_result.raw_agent_response.output
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
|
||||||
messages.append({"role": "assistant", "content": agent_result.raw_agent_response.content})
|
|
||||||
|
|
||||||
# Render status update
|
# Render status update
|
||||||
rendered_response = agent_result.compiled_response # Default rendering
|
rendered_response = agent_result.rendered_response
|
||||||
if chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
if send_status_func and rendered_response:
|
||||||
rendered_response = await operator_agent.render_response(
|
|
||||||
agent_result.raw_agent_response.content, browser_state.screenshot
|
|
||||||
)
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
|
||||||
rendered_response = await operator_agent.render_response(
|
|
||||||
agent_result.raw_agent_response.output, browser_state.screenshot
|
|
||||||
)
|
|
||||||
if send_status_func:
|
|
||||||
async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"):
|
async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
# Execute actions in the environment
|
# 3. Execute actions in the environment
|
||||||
step_results: List[StepResult] = []
|
env_steps: List[EnvStepResult] = []
|
||||||
if not safety_check_message:
|
|
||||||
for action in agent_result.actions:
|
for action in agent_result.actions:
|
||||||
if cancellation_event and cancellation_event.is_set():
|
if cancellation_event and cancellation_event.is_set():
|
||||||
break
|
break
|
||||||
step_result = await environment.step(action)
|
# Handle request for user action and break the loop
|
||||||
step_results.append(step_result)
|
if isinstance(action, RequestUserAction):
|
||||||
|
user_input_message = action.request
|
||||||
# Gather results from actions
|
if send_status_func:
|
||||||
for step_result in step_results:
|
async for event in send_status_func(f"**Requesting User Input**:\n{action.request}"):
|
||||||
# Update the placeholder content in the history structure
|
yield {ChatEvent.STATUS: event}
|
||||||
result_for_history = agent_result.tool_results_for_history[len(step_results) - 1]
|
break
|
||||||
result_content = step_result.error or step_result.output or "[Action completed]"
|
env_step = await environment.step(action)
|
||||||
if chat_model.model_type == ChatModel.ModelType.OPENAI:
|
env_steps.append(env_step)
|
||||||
if result_for_history["type"] == "computer_call_output":
|
|
||||||
result_for_history["output"] = {
|
|
||||||
"type": "input_image",
|
|
||||||
"image_url": f"data:image/webp;base64,{step_result.screenshot_base64}",
|
|
||||||
}
|
|
||||||
result_for_history["output"]["current_url"] = step_result.current_url
|
|
||||||
else:
|
|
||||||
result_for_history["output"] = result_content
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
|
||||||
result_for_history["content"] = result_content
|
|
||||||
if step_result.error:
|
|
||||||
result_for_history["is_error"] = True
|
|
||||||
|
|
||||||
# Add browser message to compiled log for tracing
|
|
||||||
operator_agent.compiled_operator_messages.append(
|
|
||||||
ChatMessage(role="browser", content=str(result_content))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check summarization conditions
|
|
||||||
summarize_prompt = (
|
|
||||||
f"Collate all relevant information from your research so far to answer the target query:\n{query}."
|
|
||||||
)
|
|
||||||
task_completed = not agent_result.actions and not run_summarize # No actions requested by agent
|
|
||||||
trigger_iteration_limit = iterations == max_iterations and not run_summarize
|
|
||||||
|
|
||||||
|
# Check if termination conditions are met
|
||||||
|
task_completed = not agent_result.actions # No actions requested by agent
|
||||||
|
trigger_iteration_limit = iterations == max_iterations
|
||||||
if task_completed or trigger_iteration_limit:
|
if task_completed or trigger_iteration_limit:
|
||||||
iterations = max_iterations - 1 # Ensure one more iteration for summarization
|
# Summarize results of operator run on last iteration
|
||||||
run_summarize = True
|
operator_agent.add_action_results(env_steps, agent_result, summarize_prompt)
|
||||||
logger.info(
|
summary_message = await operator_agent.summarize(query, browser_state)
|
||||||
f"Triggering summarization. Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}"
|
logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
|
||||||
)
|
|
||||||
|
|
||||||
# Append summarize prompt differently based on model
|
|
||||||
if chat_model.model_type == ChatModel.ModelType.OPENAI:
|
|
||||||
# Pop the last tool result if max iterations reached and agent attempted a tool call
|
|
||||||
if trigger_iteration_limit and agent_result.tool_results_for_history:
|
|
||||||
agent_result.tool_results_for_history.pop()
|
|
||||||
|
|
||||||
# Append summarize prompt as a user message after tool results
|
|
||||||
messages += agent_result.tool_results_for_history # Add results first
|
|
||||||
messages.append({"role": "user", "content": summarize_prompt})
|
|
||||||
agent_result.tool_results_for_history = [] # Clear results as they are now in messages
|
|
||||||
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
|
||||||
# Append summarize prompt as text within the tool results user message
|
|
||||||
agent_result.tool_results_for_history.append({"type": "text", "text": summarize_prompt})
|
|
||||||
|
|
||||||
# Add tool results to messages for the next iteration (if not handled above for OpenAI summarize)
|
|
||||||
if agent_result.tool_results_for_history:
|
|
||||||
if chat_model.model_type == ChatModel.ModelType.OPENAI:
|
|
||||||
messages += agent_result.tool_results_for_history
|
|
||||||
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
|
|
||||||
# Mark the final tool result as a cache break point
|
|
||||||
agent_result.tool_results_for_history[-1]["cache_control"] = {"type": "ephemeral"}
|
|
||||||
# Remove previous cache controls (Anthropic specific)
|
|
||||||
for msg in messages:
|
|
||||||
if msg["role"] == "user" and isinstance(msg["content"], list):
|
|
||||||
for block in msg["content"]:
|
|
||||||
if isinstance(block, dict) and "cache_control" in block:
|
|
||||||
del block["cache_control"]
|
|
||||||
messages.append({"role": "user", "content": agent_result.tool_results_for_history})
|
|
||||||
|
|
||||||
# Exit if safety checks are pending
|
|
||||||
if safety_check_message:
|
|
||||||
logger.warning(f"Safety check triggered: {safety_check_message}")
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# Determine final response message
|
# 4. Update agent on the results of its action on the environment
|
||||||
if task_completed and not safety_check_message:
|
operator_agent.add_action_results(env_steps, agent_result)
|
||||||
response = final_compiled_response
|
|
||||||
elif safety_check_message:
|
|
||||||
response = safety_check_message # Return safety message if that's why we stopped
|
|
||||||
else: # Hit iteration limit
|
|
||||||
response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{final_compiled_response}"
|
|
||||||
|
|
||||||
|
# Determine final response message
|
||||||
|
if user_input_message:
|
||||||
|
response = user_input_message
|
||||||
|
elif task_completed:
|
||||||
|
response = summary_message
|
||||||
|
else: # Hit iteration limit
|
||||||
|
response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
error_msg = f"Browser use failed due to a network error: {e}"
|
error_msg = f"Browser use failed due to a network error: {e}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
@@ -1220,10 +1282,10 @@ async def operate_browser(
|
|||||||
logger.exception(error_msg) # Log full traceback for unexpected errors
|
logger.exception(error_msg) # Log full traceback for unexpected errors
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
finally:
|
finally:
|
||||||
if environment and not safety_check_message: # Don't close browser if safety check pending
|
if environment and not user_input_message: # Don't close browser if user input required
|
||||||
await environment.close()
|
await environment.close()
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"text": safety_check_message or response,
|
"text": user_input_message or response,
|
||||||
"webpages": [{"link": url, "snippet": ""} for url in environment.visited_urls],
|
"webpages": [{"link": url, "snippet": ""} for url in environment.visited_urls],
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user