Pull out common iteration loop into main browser operator method

This commit is contained in:
Debanjum
2025-05-03 15:24:39 -06:00
parent 08e93c64ab
commit 7c60e04efb

View File

@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
async def operate_browser( async def operate_browser(
message: str, query: str,
user: KhojUser, user: KhojUser,
conversation_log: dict, conversation_log: dict,
location_data: LocationData, location_data: LocationData,
@@ -56,7 +56,7 @@ async def operate_browser(
) )
if send_status_func: if send_status_func:
async for event in send_status_func(f"**Launching Browser**:\n{message}"): async for event in send_status_func(f"**Launching Browser**:\n{query}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
# Start the browser # Start the browser
@@ -65,50 +65,136 @@ async def operate_browser(
# Operate the browser # Operate the browser
max_iterations = 40 max_iterations = 40
max_tokens = 4096
compiled_operator_messages: List[ChatMessage] = []
run_summarize = False
task_completed = False
iterations = 0
messages = [{"role": "user", "content": query}]
final_compiled_response = ""
error_msg_template = "Browser use with {model_type} model failed due to an {error_type} error: {e}"
with timer(f"Operating browser with {chat_model.model_type} {chat_model.name}", logger): with timer(f"Operating browser with {chat_model.model_type} {chat_model.name}", logger):
try: try:
if chat_model.model_type == ChatModel.ModelType.OPENAI: while iterations < max_iterations:
async for result in browser_use_openai( # Check for cancellation at the start of each iteration
message, if cancellation_event and cancellation_event.is_set():
chat_model, logger.info(f"Browser operator cancelled by client disconnect")
page, break
width,
height, iterations += 1
max_iterations=max_iterations, tool_results = []
send_status_func=send_status_func, compiled_response = ""
user=user,
agent=agent, if chat_model.model_type == ChatModel.ModelType.OPENAI:
cancellation_event=cancellation_event, (
tracer=tracer, agent_response,
): compiled_response,
if isinstance(result, dict) and ChatEvent.STATUS in result: tool_results,
yield result safety_check_message,
else: ) = await _openai_iteration(
response, safety_check_message = result messages=messages,
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: chat_model=chat_model,
async for result in browser_use_anthropic( page=page,
message, width=width,
chat_model, height=height,
page, max_tokens=max_tokens,
width, max_iterations=max_iterations,
height, compiled_operator_messages=compiled_operator_messages,
max_iterations=max_iterations, tracer=tracer,
send_status_func=send_status_func, )
user=user, messages += agent_response.output
agent=agent, rendered_response = compiled_response
cancellation_event=cancellation_event, final_compiled_response = compiled_response
tracer=tracer,
): elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
if isinstance(result, dict) and ChatEvent.STATUS in result: (
yield result agent_response,
else: compiled_response,
response, safety_check_message = result tool_results,
safety_check_message,
) = await _anthropic_iteration(
messages=messages,
chat_model=chat_model,
page=page,
width=width,
height=height,
max_tokens=max_tokens,
max_iterations=max_iterations,
compiled_operator_messages=compiled_operator_messages,
tracer=tracer,
)
messages.append({"role": "assistant", "content": agent_response.content})
rendered_response = await render_claude_response(agent_response.content, page)
final_compiled_response = compiled_response
if send_status_func:
async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"):
yield {ChatEvent.STATUS: event}
# Check summarization conditions
summarize_prompt = (
f"Collate all relevant information from your research so far to answer the target query:\n{query}."
)
task_completed = not tool_results and not run_summarize
trigger_iteration_limit = iterations == max_iterations and not run_summarize
if task_completed or trigger_iteration_limit:
iterations = max_iterations - 1 # Ensure one more iteration for summarization
run_summarize = True
# Model specific handling for appending the summarize prompt
if chat_model.model_type == ChatModel.ModelType.OPENAI:
# Pop the last tool result if max iterations reached and agent attempted a tool call
any_tool_calls = any(
block.type in ["computer_call", "function_call"] for block in agent_response.output
)
if trigger_iteration_limit and any_tool_calls and tool_results:
tool_results.pop() # Remove the action that couldn't be processed due to limit
# Append summarize prompt
tool_results.append({"role": "user", "content": summarize_prompt})
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# No specific action needed for Anthropic on iteration limit besides setting task_completed = False
# Append summarize prompt
tool_results.append({"type": "text", "text": summarize_prompt})
# Add tool results to messages for the next iteration
if tool_results:
if chat_model.model_type == ChatModel.ModelType.OPENAI:
messages += tool_results
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# Mark the final tool result as a cache break point
tool_results[-1]["cache_control"] = {"type": "ephemeral"}
# Remove all previous cache break points
for msg in messages:
if isinstance(msg["content"], list):
for tool_result in msg["content"]:
if isinstance(tool_result, dict) and "cache_control" in tool_result:
del tool_result["cache_control"]
elif isinstance(msg["content"], dict) and "cache_control" in msg["content"]:
del msg["content"]["cache_control"]
messages.append({"role": "user", "content": tool_results})
# Exit if safety checks are pending
if safety_check_message:
break
# Determine final response message
if task_completed:
response = final_compiled_response
else:
response = f"Operator hit iteration limit. If the results seems incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{final_compiled_response}"
except requests.RequestException as e: except requests.RequestException as e:
raise ValueError(f"Browser use with {chat_model.model_type} model failed due to a network error: {e}") error_msg = error_msg_template.format(model_type=chat_model.model_type, error_type="network", e=e)
raise ValueError(error_msg)
except Exception as e: except Exception as e:
raise ValueError(f"Browser use with {chat_model.model_type} model failed due to an unknown error: {e}") error_msg = error_msg_template.format(model_type=chat_model.model_type, error_type="unknown", e=e)
logger.exception(error_msg)
raise ValueError(error_msg)
finally: finally:
# Keep the browser open if safety checks are pending to allow user to investigate and resolve them
if not safety_check_message: if not safety_check_message:
# Close the browser # Close the browser
await browser.close() await browser.close()
@@ -152,33 +238,19 @@ async def start_browser(width: int = 1024, height: int = 768):
return playwright, browser, page return playwright, browser, page
async def browser_use_openai( async def _openai_iteration(
query: str, messages: list,
chat_model: ChatModel, chat_model: ChatModel,
page: Page, page: Page,
width: int = 1024, width: int,
height: int = 768, height: int,
max_tokens: int = 4096, max_tokens: int,
max_iterations: int = 40, max_iterations: int,
send_status_func: Optional[Callable] = None, compiled_operator_messages: List[ChatMessage],
user: KhojUser = None,
agent: Agent = None,
cancellation_event: Optional[asyncio.Event] = None,
tracer: dict = {}, tracer: dict = {},
): ):
""" """Performs one iteration of the OpenAI browser interaction loop."""
A simple agent loop for openai browser use interactions.
This function handles the back-and-forth between:
1. Sending user messages to Openai
2. Openai requesting to use browser use tools
3. Playwright executing those tools in browser.
4. Sending tool results back to Openai
"""
# Setup tools and API parameters
client = get_openai_async_client(chat_model.ai_model_api.api_key, chat_model.ai_model_api.api_base_url) client = get_openai_async_client(chat_model.ai_model_api.api_key, chat_model.ai_model_api.api_base_url)
messages = [{"role": "user", "content": query}]
safety_check_prefix = "The user needs to say 'continue' after resolving the following safety checks to proceed:" safety_check_prefix = "The user needs to say 'continue' after resolving the following safety checks to proceed:"
safety_check = None safety_check = None
system_prompt = f"""<SYSTEM_CAPABILITY> system_prompt = f"""<SYSTEM_CAPABILITY>
@@ -227,170 +299,103 @@ async def browser_use_openai(
}, },
] ]
# Main agent loop (with iteration limit to prevent runaway API costs) response = await client.responses.create(
compiled_operator_messages: List[ChatMessage] = [] model="computer-use-preview",
run_summarize = False input=messages,
task_completed = False instructions=system_prompt,
tools=tools,
parallel_tool_calls=False,
max_output_tokens=max_tokens,
truncation="auto",
)
logger.debug(f"Openai response: {response.model_dump_json()}")
compiled_response = compile_openai_response(response.output)
# Add Openai's response to the tracer conversation history
compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
# Check if any tools used
tool_call_blocks = [
block for block in response.output if block.type == "computer_call" or block.type == "function_call"
]
tool_results: list[dict[str, str | dict]] = []
last_call_id = None last_call_id = None
iterations = 0 # Run the tool calls in order
while iterations < max_iterations: for block in tool_call_blocks:
# Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set():
logger.info(f"Browser operator cancelled by client disconnect")
break
iterations += 1
# Send the screenshot back as a computer_call_output
response = await client.responses.create(
model="computer-use-preview",
input=messages,
instructions=system_prompt,
tools=tools,
parallel_tool_calls=False,
max_output_tokens=max_tokens,
truncation="auto",
)
logger.debug(f"Openai response: {response.model_dump_json()}")
messages += response.output
compiled_response = compile_openai_response(response.output)
# Add Openai's response to the tracer conversation history
compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
if send_status_func:
async for event in send_status_func(f"**Operating Browser**:\n{compiled_response}"):
yield {ChatEvent.STATUS: event}
# Check if any tools used
tool_call_blocks = [
block for block in response.output if block.type == "computer_call" or block.type == "function_call"
]
tool_results: list[dict[str, str | dict]] = []
block_input: ActionBack | ActionGoto | response_computer_tool_call.Action = None block_input: ActionBack | ActionGoto | response_computer_tool_call.Action = None
# Run the tool calls in order if block.type == "function_call":
for block in tool_call_blocks: last_call_id = block.call_id
if block.type == "function_call": if hasattr(block, "name") and block.name == "goto":
last_call_id = block.call_id url = json.loads(block.arguments).get("url")
if hasattr(block, "name") and block.name == "goto": block_input = ActionGoto(type="goto", url=url)
url = json.loads(block.arguments).get("url") elif hasattr(block, "name") and block.name == "back":
block_input = ActionGoto(type="goto", url=url) block_input = ActionBack(type="back")
elif hasattr(block, "name") and block.name == "back": # Exit tool processing if safety check needed
block_input = ActionBack(type="back") elif block.type == "computer_call" and block.pending_safety_checks:
# if user doesn't ack all safety checks exit with error for check in block.pending_safety_checks:
elif block.type == "computer_call" and block.pending_safety_checks: if safety_check:
for check in block.pending_safety_checks: safety_check += f"\n- {check.message}"
if safety_check: else:
safety_check += f"\n- {check.message}" safety_check = f"{safety_check_prefix}\n- {check.message}"
else: break
safety_check = f"{safety_check_prefix}\n- {check.message}" elif block.type == "computer_call":
break last_call_id = block.call_id
elif block.type == "computer_call": block_input = block.action
last_call_id = block.call_id
block_input = block.action
result = await handle_browser_action_openai(page, block_input) result = await handle_browser_action_openai(page, block_input)
content_text = result.get("output") or result.get("error") content_text = result.get("output") or result.get("error")
compiled_operator_messages.append(ChatMessage(role="browser", content=content_text)) compiled_operator_messages.append(ChatMessage(role="browser", content=content_text))
# Take a screenshot after computer action # Take a screenshot after computer action
if block.type == "computer_call": if block.type == "computer_call":
screenshot_base64 = await get_screenshot(page) screenshot_base64 = await get_screenshot(page)
content = {"type": "input_image", "image_url": f"data:image/webp;base64,{screenshot_base64}"} content = {"type": "input_image", "image_url": f"data:image/webp;base64,{screenshot_base64}"}
content["current_url"] = page.url if block.type == "computer_call" else None content["current_url"] = page.url if block.type == "computer_call" else None
elif block.type == "function_call": elif block.type == "function_call":
content = content_text content = content_text
# Format the tool call results # Format the tool call results
tool_results.append( tool_results.append(
{ {
"type": f"{block.type}_output", "type": f"{block.type}_output",
"output": content, "output": content,
"call_id": last_call_id, "call_id": last_call_id,
} }
)
# Calculate cost of chat
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
tracer["usage"] = get_chat_usage_metrics(
chat_model.name, input_tokens, output_tokens, usage=tracer.get("usage")
) )
logger.debug(f"Operator usage by Openai: {tracer['usage']}")
# Save conversation trace # Calculate cost of chat
tracer["chat_model"] = chat_model.name input_tokens = response.usage.input_tokens
if is_promptrace_enabled(): output_tokens = response.usage.output_tokens
commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer) tracer["usage"] = get_chat_usage_metrics(chat_model.name, input_tokens, output_tokens, usage=tracer.get("usage"))
logger.debug(f"Operator usage by Openai: {tracer['usage']}")
# Run one last iteration to collate results of browser use, if no tool use requested or iteration limit reached. # Save conversation trace
if not tool_results and not run_summarize: tracer["chat_model"] = chat_model.name
iterations = max_iterations - 1 if is_promptrace_enabled():
run_summarize = True commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer)
task_completed = True
tool_results.append(
{
"type": "message",
"role": "user",
"content": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
}
)
elif iterations == max_iterations and not run_summarize:
iterations = max_iterations - 1
run_summarize = True
task_completed = not tool_results
# Pop the last tool call if max iterations reached
if tool_call_blocks:
tool_results.pop()
tool_results.append(
{
"type": "message",
"role": "user",
"content": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
}
)
# Add tool results to messages for the next iteration with Openai return response, compiled_response, tool_results, safety_check
messages += tool_results
if task_completed:
final_response = compiled_response
else:
final_response = f"Operator hit iteration limit. If the results seems incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{compiled_response}"
yield (final_response, safety_check)
async def browser_use_anthropic( async def _anthropic_iteration(
query: str, messages: list,
chat_model: ChatModel, chat_model: ChatModel,
page: Page, page: Page,
width: int = 1024, width: int,
height: int = 768, height: int,
max_tokens: int = 4096, max_tokens: int,
max_iterations: int,
compiled_operator_messages: List[ChatMessage],
thinking_budget: int | None = 1024, thinking_budget: int | None = 1024,
max_iterations: int = 40, # Add iteration limit to prevent infinite loops
send_status_func: Optional[Callable] = None,
user: KhojUser = None,
agent: Agent = None,
cancellation_event: Optional[asyncio.Event] = None,
tracer: dict = {}, tracer: dict = {},
): ):
""" """Performs one iteration of the Anthropic browser interaction loop."""
A simple agent loop for Claude computer use interactions.
This function handles the back-and-forth between:
1. Sending user messages to Claude
2. Claude requesting to use browser use tools
3. Playwright executing those tools in browser.
4. Sending tool results back to Claude
"""
# Set up tools and API parameters
client = get_anthropic_async_client(chat_model.ai_model_api.api_key, chat_model.ai_model_api.api_base_url) client = get_anthropic_async_client(chat_model.ai_model_api.api_key, chat_model.ai_model_api.api_base_url)
messages = [{"role": "user", "content": query}]
tool_version = "2025-01-24" tool_version = "2025-01-24"
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"] betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
temperature = 1.0 temperature = 1.0
safety_check = None safety_check = None # Anthropic doesn't have explicit safety checks in the same way yet
system_prompt = f"""<SYSTEM_CAPABILITY> system_prompt = f"""<SYSTEM_CAPABILITY>
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser. * You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
* You operate a Chromium browser using Playwright. * You operate a Chromium browser using Playwright.
@@ -439,144 +444,104 @@ async def browser_use_anthropic(
# {"type": f"bash_20250124", "name": "bash"} # {"type": f"bash_20250124", "name": "bash"}
] ]
if agent: # Set up optional thinking parameter
agent.chat_model = chat_model thinking = {"type": "disabled"}
if chat_model.name.startswith("claude-3-7") and thinking_budget:
thinking = {"type": "enabled", "budget_tokens": thinking_budget}
# Main agent loop (with iteration limit to prevent runaway API costs) # Call the Claude API
compiled_operator_messages: List[ChatMessage] = [] response = await client.beta.messages.create(
run_summarize = False messages=messages,
task_completed = False model=chat_model.name,
iterations = 0 system=system_prompt,
while iterations < max_iterations: tools=tools,
# Check for cancellation at the start of each iteration betas=betas,
if cancellation_event and cancellation_event.is_set(): thinking=thinking,
logger.info(f"Browser operator cancelled by client disconnect") max_tokens=max_tokens,
break temperature=temperature,
)
iterations += 1 response_content = response.content
# Set up optional thinking parameter (for Claude 3.7 Sonnet) compiled_response = compile_claude_response(response_content)
thinking = {"type": "disabled"}
if chat_model.name.startswith("claude-3-7") and thinking_budget:
thinking = {"type": "enabled", "budget_tokens": thinking_budget}
# Call the Claude API # Add Claude's response to the conversation history
response = await client.beta.messages.create( compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
messages=messages, logger.debug(f"Claude response: {response.model_dump_json()}")
model=chat_model.name,
system=system_prompt,
tools=tools,
betas=betas,
thinking=thinking,
max_tokens=max_tokens,
temperature=temperature,
)
response_content = response.content # Check if Claude used any tools
compiled_response = compile_claude_response(response_content) tool_results = []
for block in response_content:
if block.type == "tool_use":
if hasattr(block, "name") and block.name == "goto":
block_input = {"action": block.name, "url": block.input.get("url")}
elif hasattr(block, "name") and block.name == "back":
block_input = {"action": block.name}
else:
block_input = block.input
result = await handle_browser_action_anthropic(page, block_input)
content = result.get("output") or result.get("error")
if isinstance(content, str):
compiled_operator_messages.append(ChatMessage(role="browser", content=content))
elif isinstance(content, list) and content and content[0]["type"] == "image":
compiled_operator_messages.append(ChatMessage(role="browser", content="[placeholder for screenshot]"))
# Add Claude's response to the conversation history # Format the result for Claude
messages.append({"role": "assistant", "content": response_content})
compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
logger.debug(f"Claude response: {response.model_dump_json()}")
if send_status_func:
rendered_response = await render_claude_response(response_content, page)
async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"):
yield {ChatEvent.STATUS: event}
# Check if Claude used any tools
tool_results = []
for block in response_content:
if block.type == "tool_use":
if hasattr(block, "name") and block.name == "goto":
block_input = {"action": block.name, "url": block.input.get("url")}
elif hasattr(block, "name") and block.name == "back":
block_input = {"action": block.name}
else:
block_input = block.input
result = await handle_browser_action_anthropic(page, block_input)
content = result.get("output") or result.get("error")
if isinstance(content, str):
compiled_operator_messages.append(ChatMessage(role="browser", content=content))
elif isinstance(content, list) and content[0]["type"] == "image":
# Handle the case where the content is an image
compiled_operator_messages.append(
ChatMessage(role="browser", content="[placeholder for screenshot]")
)
# Format the result for Claude
tool_results.append(
{
"type": "tool_result",
"tool_use_id": block.id,
"content": content,
"is_error": "error" in result,
}
)
# Calculate cost of chat
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
cache_read_tokens = response.usage.cache_read_input_tokens
cache_write_tokens = response.usage.cache_creation_input_tokens
tracer["usage"] = get_chat_usage_metrics(
chat_model.name,
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
usage=tracer.get("usage"),
)
logger.debug(f"Operator usage by Claude: {tracer['usage']}")
# Save conversation trace
tracer["chat_model"] = chat_model.name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer)
# Run one last iteration to collate results of browser use, if no tool use requested or iteration limit reached.
if not tool_results and not run_summarize:
iterations = max_iterations - 1
run_summarize = True
task_completed = True
tool_results.append( tool_results.append(
{ {
"type": "text", "type": "tool_result",
"text": f"Collate all relevant information from your research so far to answer the target query:\n{query}.", "tool_use_id": block.id,
} "content": content,
) "is_error": "error" in result,
elif iterations == max_iterations and not run_summarize:
iterations = max_iterations - 1
run_summarize = True
task_completed = not tool_results
tool_results.append(
{
"type": "text",
"text": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
} }
) )
# Mark the final tool result as a cache break point # Calculate cost of chat
if tool_results: input_tokens = response.usage.input_tokens
tool_results[-1]["cache_control"] = {"type": "ephemeral"} output_tokens = response.usage.output_tokens
cache_read_tokens = response.usage.cache_read_input_tokens
cache_write_tokens = response.usage.cache_creation_input_tokens
tracer["usage"] = get_chat_usage_metrics(
chat_model.name,
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
usage=tracer.get("usage"),
)
logger.debug(f"Operator usage by {chat_model.model_type}: {tracer['usage']}")
# Remove all previous cache break point # Save conversation trace
for message in messages: tracer["chat_model"] = chat_model.name
if isinstance(message["content"], list): tracer["temperature"] = temperature
for tool_result in message["content"]: if is_promptrace_enabled():
if isinstance(tool_result, dict) and "cache_control" in tool_result: commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer)
del tool_result["cache_control"]
elif isinstance(message["content"], dict) and "cache_control" in message["content"]:
del message["content"]["cache_control"]
# Add tool results to messages for the next iteration with Claude return (
messages.append({"role": "user", "content": tool_results}) response,
compiled_response,
tool_results,
safety_check,
)
if task_completed:
final_response = compiled_response async def browser_use_openai(*args, **kwargs):
else: """
final_response = f"Operator hit iteration limit. If the results seems incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{compiled_response}" Deprecated: Use operate_browser directly.
yield (final_response, safety_check) This function is kept for potential backward compatibility checks but should be removed.
"""
logger.warning("browser_use_openai is deprecated. Use operate_browser instead.")
# The logic is now within operate_browser and _openai_iteration
raise NotImplementedError("browser_use_openai is deprecated.")
async def browser_use_anthropic(*args, **kwargs):
"""
Deprecated: Use operate_browser directly.
This function is kept for potential backward compatibility checks but should be removed.
"""
logger.warning("browser_use_anthropic is deprecated. Use operate_browser instead.")
# The logic is now within operate_browser and _anthropic_iteration
raise NotImplementedError("browser_use_anthropic is deprecated.")
# Mapping of CUA keys to Playwright keys # Mapping of CUA keys to Playwright keys