Pull out common iteration loop into main browser operator method

This commit is contained in:
Debanjum
2025-05-03 15:24:39 -06:00
parent 08e93c64ab
commit 7c60e04efb

View File

@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
async def operate_browser( async def operate_browser(
message: str, query: str,
user: KhojUser, user: KhojUser,
conversation_log: dict, conversation_log: dict,
location_data: LocationData, location_data: LocationData,
@@ -56,7 +56,7 @@ async def operate_browser(
) )
if send_status_func: if send_status_func:
async for event in send_status_func(f"**Launching Browser**:\n{message}"): async for event in send_status_func(f"**Launching Browser**:\n{query}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
# Start the browser # Start the browser
@@ -65,50 +65,136 @@ async def operate_browser(
# Operate the browser # Operate the browser
max_iterations = 40 max_iterations = 40
max_tokens = 4096
compiled_operator_messages: List[ChatMessage] = []
run_summarize = False
task_completed = False
iterations = 0
messages = [{"role": "user", "content": query}]
final_compiled_response = ""
error_msg_template = "Browser use with {model_type} model failed due to an {error_type} error: {e}"
with timer(f"Operating browser with {chat_model.model_type} {chat_model.name}", logger): with timer(f"Operating browser with {chat_model.model_type} {chat_model.name}", logger):
try: try:
while iterations < max_iterations:
# Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set():
logger.info(f"Browser operator cancelled by client disconnect")
break
iterations += 1
tool_results = []
compiled_response = ""
if chat_model.model_type == ChatModel.ModelType.OPENAI: if chat_model.model_type == ChatModel.ModelType.OPENAI:
async for result in browser_use_openai( (
message, agent_response,
chat_model, compiled_response,
page, tool_results,
width, safety_check_message,
height, ) = await _openai_iteration(
messages=messages,
chat_model=chat_model,
page=page,
width=width,
height=height,
max_tokens=max_tokens,
max_iterations=max_iterations, max_iterations=max_iterations,
send_status_func=send_status_func, compiled_operator_messages=compiled_operator_messages,
user=user,
agent=agent,
cancellation_event=cancellation_event,
tracer=tracer, tracer=tracer,
): )
if isinstance(result, dict) and ChatEvent.STATUS in result: messages += agent_response.output
yield result rendered_response = compiled_response
else: final_compiled_response = compiled_response
response, safety_check_message = result
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
async for result in browser_use_anthropic( (
message, agent_response,
chat_model, compiled_response,
page, tool_results,
width, safety_check_message,
height, ) = await _anthropic_iteration(
messages=messages,
chat_model=chat_model,
page=page,
width=width,
height=height,
max_tokens=max_tokens,
max_iterations=max_iterations, max_iterations=max_iterations,
send_status_func=send_status_func, compiled_operator_messages=compiled_operator_messages,
user=user,
agent=agent,
cancellation_event=cancellation_event,
tracer=tracer, tracer=tracer,
): )
if isinstance(result, dict) and ChatEvent.STATUS in result: messages.append({"role": "assistant", "content": agent_response.content})
yield result rendered_response = await render_claude_response(agent_response.content, page)
final_compiled_response = compiled_response
if send_status_func:
async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"):
yield {ChatEvent.STATUS: event}
# Check summarization conditions
summarize_prompt = (
f"Collate all relevant information from your research so far to answer the target query:\n{query}."
)
task_completed = not tool_results and not run_summarize
trigger_iteration_limit = iterations == max_iterations and not run_summarize
if task_completed or trigger_iteration_limit:
iterations = max_iterations - 1 # Ensure one more iteration for summarization
run_summarize = True
# Model specific handling for appending the summarize prompt
if chat_model.model_type == ChatModel.ModelType.OPENAI:
# Pop the last tool result if max iterations reached and agent attempted a tool call
any_tool_calls = any(
block.type in ["computer_call", "function_call"] for block in agent_response.output
)
if trigger_iteration_limit and any_tool_calls and tool_results:
tool_results.pop() # Remove the action that couldn't be processed due to limit
# Append summarize prompt
tool_results.append({"role": "user", "content": summarize_prompt})
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# No specific action needed for Anthropic on iteration limit besides setting task_completed = False
# Append summarize prompt
tool_results.append({"type": "text", "text": summarize_prompt})
# Add tool results to messages for the next iteration
if tool_results:
if chat_model.model_type == ChatModel.ModelType.OPENAI:
messages += tool_results
elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC:
# Mark the final tool result as a cache break point
tool_results[-1]["cache_control"] = {"type": "ephemeral"}
# Remove all previous cache break points
for msg in messages:
if isinstance(msg["content"], list):
for tool_result in msg["content"]:
if isinstance(tool_result, dict) and "cache_control" in tool_result:
del tool_result["cache_control"]
elif isinstance(msg["content"], dict) and "cache_control" in msg["content"]:
del msg["content"]["cache_control"]
messages.append({"role": "user", "content": tool_results})
# Exit if safety checks are pending
if safety_check_message:
break
# Determine final response message
if task_completed:
response = final_compiled_response
else: else:
response, safety_check_message = result response = f"Operator hit iteration limit. If the results seems incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{final_compiled_response}"
except requests.RequestException as e: except requests.RequestException as e:
raise ValueError(f"Browser use with {chat_model.model_type} model failed due to a network error: {e}") error_msg = error_msg_template.format(model_type=chat_model.model_type, error_type="network", e=e)
raise ValueError(error_msg)
except Exception as e: except Exception as e:
raise ValueError(f"Browser use with {chat_model.model_type} model failed due to an unknown error: {e}") error_msg = error_msg_template.format(model_type=chat_model.model_type, error_type="unknown", e=e)
logger.exception(error_msg)
raise ValueError(error_msg)
finally: finally:
# Keep the browser open if safety checks are pending to allow user to investigate and resolve them
if not safety_check_message: if not safety_check_message:
# Close the browser # Close the browser
await browser.close() await browser.close()
@@ -152,33 +238,19 @@ async def start_browser(width: int = 1024, height: int = 768):
return playwright, browser, page return playwright, browser, page
async def browser_use_openai( async def _openai_iteration(
query: str, messages: list,
chat_model: ChatModel, chat_model: ChatModel,
page: Page, page: Page,
width: int = 1024, width: int,
height: int = 768, height: int,
max_tokens: int = 4096, max_tokens: int,
max_iterations: int = 40, max_iterations: int,
send_status_func: Optional[Callable] = None, compiled_operator_messages: List[ChatMessage],
user: KhojUser = None,
agent: Agent = None,
cancellation_event: Optional[asyncio.Event] = None,
tracer: dict = {}, tracer: dict = {},
): ):
""" """Performs one iteration of the OpenAI browser interaction loop."""
A simple agent loop for openai browser use interactions.
This function handles the back-and-forth between:
1. Sending user messages to Openai
2. Openai requesting to use browser use tools
3. Playwright executing those tools in browser.
4. Sending tool results back to Openai
"""
# Setup tools and API parameters
client = get_openai_async_client(chat_model.ai_model_api.api_key, chat_model.ai_model_api.api_base_url) client = get_openai_async_client(chat_model.ai_model_api.api_key, chat_model.ai_model_api.api_base_url)
messages = [{"role": "user", "content": query}]
safety_check_prefix = "The user needs to say 'continue' after resolving the following safety checks to proceed:" safety_check_prefix = "The user needs to say 'continue' after resolving the following safety checks to proceed:"
safety_check = None safety_check = None
system_prompt = f"""<SYSTEM_CAPABILITY> system_prompt = f"""<SYSTEM_CAPABILITY>
@@ -227,21 +299,6 @@ async def browser_use_openai(
}, },
] ]
# Main agent loop (with iteration limit to prevent runaway API costs)
compiled_operator_messages: List[ChatMessage] = []
run_summarize = False
task_completed = False
last_call_id = None
iterations = 0
while iterations < max_iterations:
# Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set():
logger.info(f"Browser operator cancelled by client disconnect")
break
iterations += 1
# Send the screenshot back as a computer_call_output
response = await client.responses.create( response = await client.responses.create(
model="computer-use-preview", model="computer-use-preview",
input=messages, input=messages,
@@ -253,23 +310,20 @@ async def browser_use_openai(
) )
logger.debug(f"Openai response: {response.model_dump_json()}") logger.debug(f"Openai response: {response.model_dump_json()}")
messages += response.output
compiled_response = compile_openai_response(response.output) compiled_response = compile_openai_response(response.output)
# Add Openai's response to the tracer conversation history # Add Openai's response to the tracer conversation history
compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response)) compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
if send_status_func:
async for event in send_status_func(f"**Operating Browser**:\n{compiled_response}"):
yield {ChatEvent.STATUS: event}
# Check if any tools used # Check if any tools used
tool_call_blocks = [ tool_call_blocks = [
block for block in response.output if block.type == "computer_call" or block.type == "function_call" block for block in response.output if block.type == "computer_call" or block.type == "function_call"
] ]
tool_results: list[dict[str, str | dict]] = [] tool_results: list[dict[str, str | dict]] = []
block_input: ActionBack | ActionGoto | response_computer_tool_call.Action = None last_call_id = None
# Run the tool calls in order # Run the tool calls in order
for block in tool_call_blocks: for block in tool_call_blocks:
block_input: ActionBack | ActionGoto | response_computer_tool_call.Action = None
if block.type == "function_call": if block.type == "function_call":
last_call_id = block.call_id last_call_id = block.call_id
if hasattr(block, "name") and block.name == "goto": if hasattr(block, "name") and block.name == "goto":
@@ -277,7 +331,7 @@ async def browser_use_openai(
block_input = ActionGoto(type="goto", url=url) block_input = ActionGoto(type="goto", url=url)
elif hasattr(block, "name") and block.name == "back": elif hasattr(block, "name") and block.name == "back":
block_input = ActionBack(type="back") block_input = ActionBack(type="back")
# if user doesn't ack all safety checks exit with error # Exit tool processing if safety check needed
elif block.type == "computer_call" and block.pending_safety_checks: elif block.type == "computer_call" and block.pending_safety_checks:
for check in block.pending_safety_checks: for check in block.pending_safety_checks:
if safety_check: if safety_check:
@@ -313,9 +367,7 @@ async def browser_use_openai(
# Calculate cost of chat # Calculate cost of chat
input_tokens = response.usage.input_tokens input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens output_tokens = response.usage.output_tokens
tracer["usage"] = get_chat_usage_metrics( tracer["usage"] = get_chat_usage_metrics(chat_model.name, input_tokens, output_tokens, usage=tracer.get("usage"))
chat_model.name, input_tokens, output_tokens, usage=tracer.get("usage")
)
logger.debug(f"Operator usage by Openai: {tracer['usage']}") logger.debug(f"Operator usage by Openai: {tracer['usage']}")
# Save conversation trace # Save conversation trace
@@ -323,74 +375,27 @@ async def browser_use_openai(
if is_promptrace_enabled(): if is_promptrace_enabled():
commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer) commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer)
# Run one last iteration to collate results of browser use, if no tool use requested or iteration limit reached. return response, compiled_response, tool_results, safety_check
if not tool_results and not run_summarize:
iterations = max_iterations - 1
run_summarize = True
task_completed = True
tool_results.append(
{
"type": "message",
"role": "user",
"content": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
}
)
elif iterations == max_iterations and not run_summarize:
iterations = max_iterations - 1
run_summarize = True
task_completed = not tool_results
# Pop the last tool call if max iterations reached
if tool_call_blocks:
tool_results.pop()
tool_results.append(
{
"type": "message",
"role": "user",
"content": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
}
)
# Add tool results to messages for the next iteration with Openai
messages += tool_results
if task_completed:
final_response = compiled_response
else:
final_response = f"Operator hit iteration limit. If the results seems incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{compiled_response}"
yield (final_response, safety_check)
async def browser_use_anthropic( async def _anthropic_iteration(
query: str, messages: list,
chat_model: ChatModel, chat_model: ChatModel,
page: Page, page: Page,
width: int = 1024, width: int,
height: int = 768, height: int,
max_tokens: int = 4096, max_tokens: int,
max_iterations: int,
compiled_operator_messages: List[ChatMessage],
thinking_budget: int | None = 1024, thinking_budget: int | None = 1024,
max_iterations: int = 40, # Add iteration limit to prevent infinite loops
send_status_func: Optional[Callable] = None,
user: KhojUser = None,
agent: Agent = None,
cancellation_event: Optional[asyncio.Event] = None,
tracer: dict = {}, tracer: dict = {},
): ):
""" """Performs one iteration of the Anthropic browser interaction loop."""
A simple agent loop for Claude computer use interactions.
This function handles the back-and-forth between:
1. Sending user messages to Claude
2. Claude requesting to use browser use tools
3. Playwright executing those tools in browser.
4. Sending tool results back to Claude
"""
# Set up tools and API parameters
client = get_anthropic_async_client(chat_model.ai_model_api.api_key, chat_model.ai_model_api.api_base_url) client = get_anthropic_async_client(chat_model.ai_model_api.api_key, chat_model.ai_model_api.api_base_url)
messages = [{"role": "user", "content": query}]
tool_version = "2025-01-24" tool_version = "2025-01-24"
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"] betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
temperature = 1.0 temperature = 1.0
safety_check = None safety_check = None # Anthropic doesn't have explicit safety checks in the same way yet
system_prompt = f"""<SYSTEM_CAPABILITY> system_prompt = f"""<SYSTEM_CAPABILITY>
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser. * You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
* You operate a Chromium browser using Playwright. * You operate a Chromium browser using Playwright.
@@ -439,22 +444,7 @@ async def browser_use_anthropic(
# {"type": f"bash_20250124", "name": "bash"} # {"type": f"bash_20250124", "name": "bash"}
] ]
if agent: # Set up optional thinking parameter
agent.chat_model = chat_model
# Main agent loop (with iteration limit to prevent runaway API costs)
compiled_operator_messages: List[ChatMessage] = []
run_summarize = False
task_completed = False
iterations = 0
while iterations < max_iterations:
# Check for cancellation at the start of each iteration
if cancellation_event and cancellation_event.is_set():
logger.info(f"Browser operator cancelled by client disconnect")
break
iterations += 1
# Set up optional thinking parameter (for Claude 3.7 Sonnet)
thinking = {"type": "disabled"} thinking = {"type": "disabled"}
if chat_model.name.startswith("claude-3-7") and thinking_budget: if chat_model.name.startswith("claude-3-7") and thinking_budget:
thinking = {"type": "enabled", "budget_tokens": thinking_budget} thinking = {"type": "enabled", "budget_tokens": thinking_budget}
@@ -475,13 +465,8 @@ async def browser_use_anthropic(
compiled_response = compile_claude_response(response_content) compiled_response = compile_claude_response(response_content)
# Add Claude's response to the conversation history # Add Claude's response to the conversation history
messages.append({"role": "assistant", "content": response_content})
compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response)) compiled_operator_messages.append(ChatMessage(role="assistant", content=compiled_response))
logger.debug(f"Claude response: {response.model_dump_json()}") logger.debug(f"Claude response: {response.model_dump_json()}")
if send_status_func:
rendered_response = await render_claude_response(response_content, page)
async for event in send_status_func(f"**Operating Browser**:\n{rendered_response}"):
yield {ChatEvent.STATUS: event}
# Check if Claude used any tools # Check if Claude used any tools
tool_results = [] tool_results = []
@@ -497,11 +482,8 @@ async def browser_use_anthropic(
content = result.get("output") or result.get("error") content = result.get("output") or result.get("error")
if isinstance(content, str): if isinstance(content, str):
compiled_operator_messages.append(ChatMessage(role="browser", content=content)) compiled_operator_messages.append(ChatMessage(role="browser", content=content))
elif isinstance(content, list) and content[0]["type"] == "image": elif isinstance(content, list) and content and content[0]["type"] == "image":
# Handle the case where the content is an image compiled_operator_messages.append(ChatMessage(role="browser", content="[placeholder for screenshot]"))
compiled_operator_messages.append(
ChatMessage(role="browser", content="[placeholder for screenshot]")
)
# Format the result for Claude # Format the result for Claude
tool_results.append( tool_results.append(
@@ -526,7 +508,7 @@ async def browser_use_anthropic(
cache_write_tokens, cache_write_tokens,
usage=tracer.get("usage"), usage=tracer.get("usage"),
) )
logger.debug(f"Operator usage by Claude: {tracer['usage']}") logger.debug(f"Operator usage by {chat_model.model_type}: {tracer['usage']}")
# Save conversation trace # Save conversation trace
tracer["chat_model"] = chat_model.name tracer["chat_model"] = chat_model.name
@@ -534,49 +516,32 @@ async def browser_use_anthropic(
if is_promptrace_enabled(): if is_promptrace_enabled():
commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer) commit_conversation_trace(compiled_operator_messages[:-1], compiled_operator_messages[-1].content, tracer)
# Run one last iteration to collate results of browser use, if no tool use requested or iteration limit reached. return (
if not tool_results and not run_summarize: response,
iterations = max_iterations - 1 compiled_response,
run_summarize = True tool_results,
task_completed = True safety_check,
tool_results.append(
{
"type": "text",
"text": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
}
)
elif iterations == max_iterations and not run_summarize:
iterations = max_iterations - 1
run_summarize = True
task_completed = not tool_results
tool_results.append(
{
"type": "text",
"text": f"Collate all relevant information from your research so far to answer the target query:\n{query}.",
}
) )
# Mark the final tool result as a cache break point
if tool_results:
tool_results[-1]["cache_control"] = {"type": "ephemeral"}
# Remove all previous cache break point async def browser_use_openai(*args, **kwargs):
for message in messages: """
if isinstance(message["content"], list): Deprecated: Use operate_browser directly.
for tool_result in message["content"]: This function is kept for potential backward compatibility checks but should be removed.
if isinstance(tool_result, dict) and "cache_control" in tool_result: """
del tool_result["cache_control"] logger.warning("browser_use_openai is deprecated. Use operate_browser instead.")
elif isinstance(message["content"], dict) and "cache_control" in message["content"]: # The logic is now within operate_browser and _openai_iteration
del message["content"]["cache_control"] raise NotImplementedError("browser_use_openai is deprecated.")
# Add tool results to messages for the next iteration with Claude
messages.append({"role": "user", "content": tool_results})
if task_completed: async def browser_use_anthropic(*args, **kwargs):
final_response = compiled_response """
else: Deprecated: Use operate_browser directly.
final_response = f"Operator hit iteration limit. If the results seems incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{compiled_response}" This function is kept for potential backward compatibility checks but should be removed.
yield (final_response, safety_check) """
logger.warning("browser_use_anthropic is deprecated. Use operate_browser instead.")
# The logic is now within operate_browser and _anthropic_iteration
raise NotImplementedError("browser_use_anthropic is deprecated.")
# Mapping of CUA keys to Playwright keys # Mapping of CUA keys to Playwright keys