Pull out common iteration loop into main browser operator method

This commit is contained in:
Debanjum
2025-05-03 15:24:39 -06:00
parent 08e93c64ab
commit 7c60e04efb

View File

@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
async def operate_browser(
message: str,
query: str,
user: KhojUser,
conversation_log: dict,
location_data: LocationData,
@@ -56,7 +56,7 @@ async def operate_browser(
)
if send_status_func:
async for event in send_status_func(f"**Launching Browser**:\n{message}"):
async for event in send_status_func(f"**Launching Browser**:\n{query}"):
yield {ChatEvent.STATUS: event}
# Start the browser
@@ -65,50 +65,136 @@ async def operate_browser(
# Operate the browser
max_iterations = 40
max_tokens = 4096
compiled_operator_messages: List[ChatMessage] = []
run_summarize = False
task_completed = False
iterations = 0
messages = [{"role": "user", "content": query}]
final_compiled_response = ""
error_msg_template = "Browser use with {model_type} model failed due to an {error_type} error: {e}"
with timer(f"Operating browser with {chat_model.model_type} {chat_model.name}", logger):
try:
if chat_model.model_type == ChatModel.ModelType.OPENAI:
async for result in browser_use_openai(
message,
chat_model,
page,
width,
height,
max_iterations=max_iterations,
send_status_func=send_status_func,
user=user,
agent=agent,
cancellation_event=cancellation_event,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result
else:
response, safety_check_message = result
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
async for result in browser_use_anthropic(
message,
chat_model,
page,
width,
height,
max_iterations=max_iterations,
send_status_func=send_status_func,
user=user,
agent=agent,
cancellation_event=cancellation_event,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result
else:
response, safety_check_message = result
while iterations < max_iterations:
# Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set():
logger.info(f"Browser operator cancelled by client disconnect")
break
iterations += 1
tool_results = []
compiled_response = ""
if chat_model.model_type == ChatModel.ModelType.OPENAI:
(
agent_response,
compiled_response,
tool_results,
safety_check_message,
) = await _openai_iteration(
messages=messages,
chat_model=chat_model,
page=page,
width=width,
height=height,
max_tokens=max_tokens,
max_iterations=max_iterations,
compiled_operator_messages=compiled_operator_messages,
tracer=tracer,
)
messages += agent_response.output
rendered_response = compiled_response
final_compiled_response = compiled_response
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
(
agent_response,
compiled_response,
tool_results,
safety_check_message,
) = await _anthropic_iteration(
messages=messages,
chat_model=chat_model,
page=page,
width=width,
height=height,
max_tokens=max_tokens,
max_iterations=max_iterations,
compiled_operator_messages=compiled_operator_messages,
tracer=tracer,
)
messages.append({"role": "assistant", "content": agent_response.content})
rendered_response = await render_claude_response(agent_response.content, page)
final_compiled_response = compiled_response
if send_status_func:
async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"):
yield {ChatEvent.STATUS: event}
# Check summarization conditions
summarize_prompt = (
f"Collate all relevant information from your research so far to answer the target query:\n{query}."
)
task_completed = not tool_results and not run_summarize
trigger_iteration_limit = iterations == max_iterations and not run_summarize
if task_completed or trigger_iteration_limit:
iterations = max_iterations - 1 # Ensure one more iteration for summarization
run_summarize = True
# Model specific handling for appending the summarize prompt
if chat_model.model_type == ChatModel.ModelType.OPENAI:
# Pop the last tool result if max iterations reached and agent attempted a tool call
any_tool_calls = any(
block.type in ["computer_call", "function_call"] for block in agent_response.output
)
if trigger_iteration_limit and any_tool_calls and tool_results:
tool_results.pop() # Remove the action that couldn't be processed due to limit
# Append summarize prompt
tool_results.append({"role": "user", "content": summarize_prompt})
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# No specific action needed for Anthropic on iteration limit besides setting task_completed = False
# Append summarize prompt
tool_results.append({"type": "text", "text": summarize_prompt})
# Add tool results to messages for the next iteration
if tool_results:
if chat_model.model_type == ChatModel.ModelType.OPENAI:
messages += tool_results
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# Mark the final tool result as a cache break point
tool_results[-1]["cache_control"] = {"type": "ephemeral"}
# Remove all previous cache break points
for msg in messages:
if isinstance(msg["content"], list):
for tool_result in msg["content"]:
if isinstance(tool_result, dict) and "cache_control" in tool_result:
del tool_result["cache_control"]
elif isinstance(msg["content"], dict) and "cache_control" in msg["content"]:
del msg["content"]["cache_control"]
messages.append({"role": "user", "content": tool_results})
# Exit if safety checks are pending
if safety_check_message:
break
# Determine final response message
if task_completed:
response = final_compiled_response
else:
response = f"Operator hit iteration limit. If the results seems incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{final_compiled_response}"
except requests.RequestException as e:
raise ValueError(f"Browser use with {chat_model.model_type} model failed due to a network error: {e}")
error_msg = error_msg_template.format(model_type=chat_model.model_type, error_type="network", e=e)
raise ValueError(error_msg)
except Exception as e:
raise ValueError(f"Browser use with {chat_model.model_type} model failed due to an unknown error: {e}")
error_msg = error_msg_template.format(model_type=chat_model.model_type, error_type="unknown", e=e)
logger.exception(error_msg)
raise ValueError(error_msg)
finally:
# Keep the browser open if safety checks are pending to allow user to investigate and resolve them
if not safety_check_message:
# Close the browser
await browser.close()
@@ -152,33 +238,19 @@ async def start_browser(width: int = 1024, height: int = 768):
return playwright, browser, page
async def browser_use_openai(
query: str,
async def _openai_iteration(
messages: list,
chat_model: ChatModel,
page: Page,
width: int = 1024,
height: int = 768,
max_tokens: int = 4096,
max_iterations: int = 40,
send_status_func: Optional[Callable] = None,
user: KhojUser = None,
agent: Agent = None,
cancellation_event: Optional[asyncio.Event] = None,
width: int,
height: int,
max_tokens: int,
max_iterations: int,
compiled_operator_messages: List[ChatMessage],
tracer: dict = {},
):
"""
A simple agent loop for openai browser use interactions.
This function handles the back-and-forth between:
1. Sending user messages to Openai
2. Openai requesting to use browser use tools
3. Playwright executing those tools in browser.
4. Sending tool results back to Openai
"""
# Setup tools and API parameters
"""Performs one iteration of the OpenAI browser interaction loop."""
client = get_openai_async_client(chat_model.ai_model_api.api_key, chat_model.ai_model_api.api_base_url)
messages = [{"role": "user", "content": query}]
safety_check_prefix = "The user needs to say 'continue' after resolving the following safety checks to proceed:"
safety_check = None
system_prompt = f"""<SYSTEM_CAPABILITY>
@@ -227,170 +299,103 @@ async def browser_use_openai(
},
]
# Main agent loop (with iteration limit to prevent runaway API costs)
compiled_operator_messages: List[ChatMessage] = []
run_summarize = False
task_completed = False
response = await client.responses.create(
model="computer-use-preview",
input=messages,
instructions=system_prompt,
tools=tools,
parallel_tool_calls=False,
max_output_tokens=max_tokens,
truncation="auto",
)
logger.debug(f"Openai response: {response.model_dump_json()}")
compiled_response = compile_openai_response(response.output)
# Add Openai's response to the tracer conversation history
compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
# Check if any tools used
tool_call_blocks = [
block for block in response.output if block.type == "computer_call" or block.type == "function_call"
]
tool_results: list[dict[str, str | dict]] = []
last_call_id = None
iterations = 0
while iterations < max_iterations:
# Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set():
logger.info(f"Browser operator cancelled by client disconnect")
break
iterations += 1
# Send the screenshot back as a computer_call_output
response = await client.responses.create(
model="computer-use-preview",
input=messages,
instructions=system_prompt,
tools=tools,
parallel_tool_calls=False,
max_output_tokens=max_tokens,
truncation="auto",
)
logger.debug(f"Openai response: {response.model_dump_json()}")
messages += response.output
compiled_response = compile_openai_response(response.output)
# Add Openai's response to the tracer conversation history
compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
if send_status_func:
async for event in send_status_func(f"**Operating Browser**:\n{compiled_response}"):
yield {ChatEvent.STATUS: event}
# Check if any tools used
tool_call_blocks = [
block for block in response.output if block.type == "computer_call" or block.type == "function_call"
]
tool_results: list[dict[str, str | dict]] = []
# Run the tool calls in order
for block in tool_call_blocks:
block_input: ActionBack | ActionGoto | response_computer_tool_call.Action = None
# Run the tool calls in order
for block in tool_call_blocks:
if block.type == "function_call":
last_call_id = block.call_id
if hasattr(block, "name") and block.name == "goto":
url = json.loads(block.arguments).get("url")
block_input = ActionGoto(type="goto", url=url)
elif hasattr(block, "name") and block.name == "back":
block_input = ActionBack(type="back")
# if user doesn't ack all safety checks exit with error
elif block.type == "computer_call" and block.pending_safety_checks:
for check in block.pending_safety_checks:
if safety_check:
safety_check += f"\n- {check.message}"
else:
safety_check = f"{safety_check_prefix}\n- {check.message}"
break
elif block.type == "computer_call":
last_call_id = block.call_id
block_input = block.action
if block.type == "function_call":
last_call_id = block.call_id
if hasattr(block, "name") and block.name == "goto":
url = json.loads(block.arguments).get("url")
block_input = ActionGoto(type="goto", url=url)
elif hasattr(block, "name") and block.name == "back":
block_input = ActionBack(type="back")
# Exit tool processing if safety check needed
elif block.type == "computer_call" and block.pending_safety_checks:
for check in block.pending_safety_checks:
if safety_check:
safety_check += f"\n- {check.message}"
else:
safety_check = f"{safety_check_prefix}\n- {check.message}"
break
elif block.type == "computer_call":
last_call_id = block.call_id
block_input = block.action
result = await handle_browser_action_openai(page, block_input)
content_text = result.get("output") or result.get("error")
compiled_operator_messages.append(ChatMessage(role="browser", content=content_text))
result = await handle_browser_action_openai(page, block_input)
content_text = result.get("output") or result.get("error")
compiled_operator_messages.append(ChatMessage(role="browser", content=content_text))
# Take a screenshot after computer action
if block.type == "computer_call":
screenshot_base64 = await get_screenshot(page)
content = {"type": "input_image", "image_url": f"data:image/webp;base64,{screenshot_base64}"}
content["current_url"] = page.url if block.type == "computer_call" else None
elif block.type == "function_call":
content = content_text
# Take a screenshot after computer action
if block.type == "computer_call":
screenshot_base64 = await get_screenshot(page)
content = {"type": "input_image", "image_url": f"data:image/webp;base64,{screenshot_base64}"}
content["current_url"] = page.url if block.type == "computer_call" else None
elif block.type == "function_call":
content = content_text
# Format the tool call results
tool_results.append(
{
"type": f"{block.type}_output",
"output": content,
"call_id": last_call_id,
}
)
# Calculate cost of chat
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
tracer["usage"] = get_chat_usage_metrics(
chat_model.name, input_tokens, output_tokens, usage=tracer.get("usage")
# Format the tool call results
tool_results.append(
{
"type": f"{block.type}_output",
"output": content,
"call_id": last_call_id,
}
)
logger.debug(f"Operator usage by Openai: {tracer['usage']}")
# Save conversation trace
tracer["chat_model"] = chat_model.name
if is_promptrace_enabled():
commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer)
# Calculate cost of chat
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
tracer["usage"] = get_chat_usage_metrics(chat_model.name, input_tokens, output_tokens, usage=tracer.get("usage"))
logger.debug(f"Operator usage by Openai: {tracer['usage']}")
# Run one last iteration to collate results of browser use, if no tool use requested or iteration limit reached.
if not tool_results and not run_summarize:
iterations = max_iterations - 1
run_summarize = True
task_completed = True
tool_results.append(
{
"type": "message",
"role": "user",
"content": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
}
)
elif iterations == max_iterations and not run_summarize:
iterations = max_iterations - 1
run_summarize = True
task_completed = not tool_results
# Pop the last tool call if max iterations reached
if tool_call_blocks:
tool_results.pop()
tool_results.append(
{
"type": "message",
"role": "user",
"content": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
}
)
# Save conversation trace
tracer["chat_model"] = chat_model.name
if is_promptrace_enabled():
commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer)
# Add tool results to messages for the next iteration with Openai
messages += tool_results
if task_completed:
final_response = compiled_response
else:
final_response = f"Operator hit iteration limit. If the results seems incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{compiled_response}"
yield (final_response, safety_check)
return response, compiled_response, tool_results, safety_check
async def browser_use_anthropic(
query: str,
async def _anthropic_iteration(
messages: list,
chat_model: ChatModel,
page: Page,
width: int = 1024,
height: int = 768,
max_tokens: int = 4096,
width: int,
height: int,
max_tokens: int,
max_iterations: int,
compiled_operator_messages: List[ChatMessage],
thinking_budget: int | None = 1024,
max_iterations: int = 40, # Add iteration limit to prevent infinite loops
send_status_func: Optional[Callable] = None,
user: KhojUser = None,
agent: Agent = None,
cancellation_event: Optional[asyncio.Event] = None,
tracer: dict = {},
):
"""
A simple agent loop for Claude computer use interactions.
This function handles the back-and-forth between:
1. Sending user messages to Claude
2. Claude requesting to use browser use tools
3. Playwright executing those tools in browser.
4. Sending tool results back to Claude
"""
# Set up tools and API parameters
"""Performs one iteration of the Anthropic browser interaction loop."""
client = get_anthropic_async_client(chat_model.ai_model_api.api_key, chat_model.ai_model_api.api_base_url)
messages = [{"role": "user", "content": query}]
tool_version = "2025-01-24"
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
temperature = 1.0
safety_check = None
safety_check = None # Anthropic doesn't have explicit safety checks in the same way yet
system_prompt = f"""<SYSTEM_CAPABILITY>
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
* You operate a Chromium browser using Playwright.
@@ -439,144 +444,104 @@ async def browser_use_anthropic(
# {"type": f"bash_20250124", "name": "bash"}
]
if agent:
agent.chat_model = chat_model
# Set up optional thinking parameter
thinking = {"type": "disabled"}
if chat_model.name.startswith("claude-3-7") and thinking_budget:
thinking = {"type": "enabled", "budget_tokens": thinking_budget}
# Main agent loop (with iteration limit to prevent runaway API costs)
compiled_operator_messages: List[ChatMessage] = []
run_summarize = False
task_completed = False
iterations = 0
while iterations < max_iterations:
# Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set():
logger.info(f"Browser operator cancelled by client disconnect")
break
# Call the Claude API
response = await client.beta.messages.create(
messages=messages,
model=chat_model.name,
system=system_prompt,
tools=tools,
betas=betas,
thinking=thinking,
max_tokens=max_tokens,
temperature=temperature,
)
iterations += 1
# Set up optional thinking parameter (for Claude 3.7 Sonnet)
thinking = {"type": "disabled"}
if chat_model.name.startswith("claude-3-7") and thinking_budget:
thinking = {"type": "enabled", "budget_tokens": thinking_budget}
response_content = response.content
compiled_response = compile_claude_response(response_content)
# Call the Claude API
response = await client.beta.messages.create(
messages=messages,
model=chat_model.name,
system=system_prompt,
tools=tools,
betas=betas,
thinking=thinking,
max_tokens=max_tokens,
temperature=temperature,
)
# Add Claude's response to the conversation history
compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
logger.debug(f"Claude response: {response.model_dump_json()}")
response_content = response.content
compiled_response = compile_claude_response(response_content)
# Check if Claude used any tools
tool_results = []
for block in response_content:
if block.type == "tool_use":
if hasattr(block, "name") and block.name == "goto":
block_input = {"action": block.name, "url": block.input.get("url")}
elif hasattr(block, "name") and block.name == "back":
block_input = {"action": block.name}
else:
block_input = block.input
result = await handle_browser_action_anthropic(page, block_input)
content = result.get("output") or result.get("error")
if isinstance(content, str):
compiled_operator_messages.append(ChatMessage(role="browser", content=content))
elif isinstance(content, list) and content and content[0]["type"] == "image":
compiled_operator_messages.append(ChatMessage(role="browser", content="[placeholder for screenshot]"))
# Add Claude's response to the conversation history
messages.append({"role": "assistant", "content": response_content})
compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
logger.debug(f"Claude response: {response.model_dump_json()}")
if send_status_func:
rendered_response = await render_claude_response(response_content, page)
async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"):
yield {ChatEvent.STATUS: event}
# Check if Claude used any tools
tool_results = []
for block in response_content:
if block.type == "tool_use":
if hasattr(block, "name") and block.name == "goto":
block_input = {"action": block.name, "url": block.input.get("url")}
elif hasattr(block, "name") and block.name == "back":
block_input = {"action": block.name}
else:
block_input = block.input
result = await handle_browser_action_anthropic(page, block_input)
content = result.get("output") or result.get("error")
if isinstance(content, str):
compiled_operator_messages.append(ChatMessage(role="browser", content=content))
elif isinstance(content, list) and content[0]["type"] == "image":
# Handle the case where the content is an image
compiled_operator_messages.append(
ChatMessage(role="browser", content="[placeholder for screenshot]")
)
# Format the result for Claude
tool_results.append(
{
"type": "tool_result",
"tool_use_id": block.id,
"content": content,
"is_error": "error" in result,
}
)
# Calculate cost of chat
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
cache_read_tokens = response.usage.cache_read_input_tokens
cache_write_tokens = response.usage.cache_creation_input_tokens
tracer["usage"] = get_chat_usage_metrics(
chat_model.name,
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
usage=tracer.get("usage"),
)
logger.debug(f"Operator usage by Claude: {tracer['usage']}")
# Save conversation trace
tracer["chat_model"] = chat_model.name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer)
# Run one last iteration to collate results of browser use, if no tool use requested or iteration limit reached.
if not tool_results and not run_summarize:
iterations = max_iterations - 1
run_summarize = True
task_completed = True
# Format the result for Claude
tool_results.append(
{
"type": "text",
"text": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
}
)
elif iterations == max_iterations and not run_summarize:
iterations = max_iterations - 1
run_summarize = True
task_completed = not tool_results
tool_results.append(
{
"type": "text",
"text": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
"type": "tool_result",
"tool_use_id": block.id,
"content": content,
"is_error": "error" in result,
}
)
# Mark the final tool result as a cache break point
if tool_results:
tool_results[-1]["cache_control"] = {"type": "ephemeral"}
# Calculate cost of chat
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
cache_read_tokens = response.usage.cache_read_input_tokens
cache_write_tokens = response.usage.cache_creation_input_tokens
tracer["usage"] = get_chat_usage_metrics(
chat_model.name,
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
usage=tracer.get("usage"),
)
logger.debug(f"Operator usage by {chat_model.model_type}: {tracer['usage']}")
# Remove all previous cache break point
for message in messages:
if isinstance(message["content"], list):
for tool_result in message["content"]:
if isinstance(tool_result, dict) and "cache_control" in tool_result:
del tool_result["cache_control"]
elif isinstance(message["content"], dict) and "cache_control" in message["content"]:
del message["content"]["cache_control"]
# Save conversation trace
tracer["chat_model"] = chat_model.name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer)
# Add tool results to messages for the next iteration with Claude
messages.append({"role": "user", "content": tool_results})
return (
response,
compiled_response,
tool_results,
safety_check,
)
if task_completed:
final_response = compiled_response
else:
final_response = f"Operator hit iteration limit. If the results seems incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{compiled_response}"
yield (final_response, safety_check)
async def browser_use_openai(*args, **kwargs):
"""
Deprecated: Use operate_browser directly.
This function is kept for potential backward compatibility checks but should be removed.
"""
logger.warning("browser_use_openai is deprecated. Use operate_browser instead.")
# The logic is now within operate_browser and _openai_iteration
raise NotImplementedError("browser_use_openai is deprecated.")
async def browser_use_anthropic(*args, **kwargs):
"""
Deprecated: Use operate_browser directly.
This function is kept for potential backward compatibility checks but should be removed.
"""
logger.warning("browser_use_anthropic is deprecated. Use operate_browser instead.")
# The logic is now within operate_browser and _anthropic_iteration
raise NotImplementedError("browser_use_anthropic is deprecated.")
# Mapping of CUA keys to Playwright keys