Generalize operator to operate multiple types of environment

Previously it could only operate a (playwright) browser. Now
- The operator logic and naming has been updated assuming
  multiple environment types can be operated
- The operator entrypoint is now at __init__.py to simplify imports
  and the entrypoint function is called operate_environment
- All operator agents have been updated to select their system prompts
  and tools based on the environment they'll operate
This commit is contained in:
Debanjum
2025-05-26 17:56:34 -07:00
parent c0689b2740
commit 7eab87bfdf
11 changed files with 639 additions and 395 deletions

View File

@@ -11,7 +11,10 @@ from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAg
from khoj.processor.operator.operator_agent_base import OperatorAgent from khoj.processor.operator.operator_agent_base import OperatorAgent
from khoj.processor.operator.operator_agent_binary import BinaryOperatorAgent from khoj.processor.operator.operator_agent_binary import BinaryOperatorAgent
from khoj.processor.operator.operator_agent_openai import OpenAIOperatorAgent from khoj.processor.operator.operator_agent_openai import OpenAIOperatorAgent
from khoj.processor.operator.operator_environment_base import EnvStepResult from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvStepResult,
)
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
from khoj.routers.helpers import ChatEvent from khoj.routers.helpers import ChatEvent
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
@@ -20,12 +23,13 @@ from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --- Browser Operator Function --- # --- Main Operator Entrypoint ---
async def operate_browser( async def operate_environment(
query: str, query: str,
user: KhojUser, user: KhojUser,
conversation_log: dict, conversation_log: dict,
location_data: LocationData, location_data: LocationData,
environment_type: EnvironmentType = EnvironmentType.COMPUTER,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None, # TODO: Handle query images query_images: Optional[List[str]] = None, # TODO: Handle query images
agent: Agent = None, agent: Agent = None,
@@ -34,7 +38,6 @@ async def operate_browser(
tracer: dict = {}, tracer: dict = {},
): ):
response, summary_message, user_input_message = None, None, None response, summary_message, user_input_message = None, None, None
environment: Optional[BrowserEnvironment] = None
# Get the agent chat model # Get the agent chat model
agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None
@@ -42,15 +45,15 @@ async def operate_browser(
if not reasoning_model or not reasoning_model.vision_enabled: if not reasoning_model or not reasoning_model.vision_enabled:
reasoning_model = await ConversationAdapters.aget_vision_enabled_config() reasoning_model = await ConversationAdapters.aget_vision_enabled_config()
if not reasoning_model: if not reasoning_model:
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate browser.") raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
# Initialize Agent # Initialize Agent
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 40)) max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 40))
operator_agent: OperatorAgent operator_agent: OperatorAgent
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI: if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
operator_agent = OpenAIOperatorAgent(query, reasoning_model, max_iterations, tracer) operator_agent = OpenAIOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC: elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
operator_agent = AnthropicOperatorAgent(query, reasoning_model, max_iterations, tracer) operator_agent = AnthropicOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
else: else:
grounding_model_name = "ui-tars-1.5" grounding_model_name = "ui-tars-1.5"
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name) grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
@@ -60,11 +63,13 @@ async def operate_browser(
or not grounding_model.model_type == ChatModel.ModelType.OPENAI or not grounding_model.model_type == ChatModel.ModelType.OPENAI
): ):
raise ValueError("No supported visual grounding model for binary operator agent found.") raise ValueError("No supported visual grounding model for binary operator agent found.")
operator_agent = BinaryOperatorAgent(query, reasoning_model, grounding_model, max_iterations, tracer) operator_agent = BinaryOperatorAgent(
query, reasoning_model, grounding_model, environment_type, max_iterations, tracer
)
# Initialize Environment # Initialize Environment
if send_status_func: if send_status_func:
async for event in send_status_func(f"**Launching Browser**"): async for event in send_status_func(f"**Launching {environment_type.value}**"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
environment = BrowserEnvironment() environment = BrowserEnvironment()
await environment.start(width=1024, height=768) await environment.start(width=1024, height=768)
@@ -75,25 +80,27 @@ async def operate_browser(
task_completed = False task_completed = False
iterations = 0 iterations = 0
with timer(f"Operating browser with {reasoning_model.model_type} {reasoning_model.name}", logger): with timer(
f"Operating {environment_type.value} with {reasoning_model.model_type} {reasoning_model.name}", logger
):
while iterations < max_iterations and not task_completed: while iterations < max_iterations and not task_completed:
if cancellation_event and cancellation_event.is_set(): if cancellation_event and cancellation_event.is_set():
logger.debug(f"Browser operator cancelled by client disconnect") logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
break break
iterations += 1 iterations += 1
# 1. Get current environment state # 1. Get current environment state
browser_state = await environment.get_state() env_state = await environment.get_state()
# 2. Agent decides action(s) # 2. Agent decides action(s)
agent_result = await operator_agent.act(browser_state) agent_result = await operator_agent.act(env_state)
# 3. Execute actions in the environment # 3. Execute actions in the environment
env_steps: List[EnvStepResult] = [] env_steps: List[EnvStepResult] = []
for action in agent_result.actions: for action in agent_result.actions:
if cancellation_event and cancellation_event.is_set(): if cancellation_event and cancellation_event.is_set():
logger.debug(f"Browser operator cancelled by client disconnect") logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
break break
# Handle request for user action and break the loop # Handle request for user action and break the loop
if isinstance(action, RequestUserAction): if isinstance(action, RequestUserAction):
@@ -106,12 +113,14 @@ async def operate_browser(
env_steps.append(env_step) env_steps.append(env_step)
# Render status update # Render status update
latest_screenshot = f"data:image/webp;base64,{env_steps[-1].screenshot_base64 if env_steps else browser_state.screenshot}" latest_screenshot = (
f"data:image/webp;base64,{env_steps[-1].screenshot_base64 if env_steps else env_state.screenshot}"
)
render_payload = agent_result.rendered_response render_payload = agent_result.rendered_response
render_payload["image"] = latest_screenshot render_payload["image"] = latest_screenshot
render_content = f"**Action**: {json.dumps(render_payload)}" render_content = f"**Action**: {json.dumps(render_payload)}"
if send_status_func: if send_status_func:
async for event in send_status_func(f"**Operating Browser**:\n{render_content}"): async for event in send_status_func(f"**Operating {environment_type.value}**:\n{render_content}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
# Check if termination conditions are met # Check if termination conditions are met
@@ -123,7 +132,7 @@ async def operate_browser(
if task_completed or trigger_iteration_limit: if task_completed or trigger_iteration_limit:
# Summarize results of operator run on last iteration # Summarize results of operator run on last iteration
operator_agent.add_action_results(env_steps, agent_result) operator_agent.add_action_results(env_steps, agent_result)
summary_message = await operator_agent.summarize(summarize_prompt, browser_state) summary_message = await operator_agent.summarize(summarize_prompt, env_state)
logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}") logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
break break
@@ -138,15 +147,19 @@ async def operate_browser(
else: # Hit iteration limit else: # Hit iteration limit
response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}" response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
finally: finally:
if environment and not user_input_message: # Don't close browser if user input required if environment and not user_input_message: # Don't close environment if user input required
await environment.close() await environment.close()
if operator_agent: if operator_agent:
operator_agent.reset() operator_agent.reset()
webpages = []
if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"):
webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls]
yield { yield {
"query": query, "query": query,
"result": user_input_message or response, "result": user_input_message or response,
"webpages": [{"link": url, "snippet": ""} for url in environment.visited_urls], "webpages": webpages,
} }

View File

@@ -1,5 +1,6 @@
import json import json
import logging import logging
from textwrap import dedent
from openai import AzureOpenAI, OpenAI from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat import ChatCompletion, ChatCompletionMessage
@@ -8,7 +9,7 @@ from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import construct_structured_message from khoj.processor.conversation.utils import construct_structured_message
from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_agent_base import AgentActResult from khoj.processor.operator.operator_agent_base import AgentActResult
from khoj.processor.operator.operator_environment_base import EnvState from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
from khoj.utils.helpers import get_chat_usage_metrics from khoj.utils.helpers import get_chat_usage_metrics
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -18,6 +19,7 @@ class GroundingAgent:
def __init__( def __init__(
self, self,
model: ChatModel, model: ChatModel,
environment_type: EnvironmentType,
client: OpenAI | AzureOpenAI, client: OpenAI | AzureOpenAI,
max_iterations: int, max_iterations: int,
tracer: dict = None, tracer: dict = None,
@@ -26,9 +28,211 @@ class GroundingAgent:
self.client = client self.client = client
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.tracer = tracer self.tracer = tracer
self.environment_type = environment_type
self.action_tools = self.get_tools(self.environment_type)
# Define tools for the grounding LLM (OpenAI format) async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
self.action_tools = [ """Call the grounding LLM to get the next action based on the current state and instruction."""
# Format the message for the API call
messages_for_api = self._format_message_for_api(instruction, current_state)
try:
grounding_response: ChatCompletion = await self.client.chat.completions.create(
messages=messages_for_api,
model=self.model.name,
tools=self.action_tools,
tool_choice="required",
temperature=0.0, # Grounding should be precise
max_completion_tokens=1000, # Allow for thoughts + actions
)
if not isinstance(grounding_response, ChatCompletion):
raise ValueError("Grounding LLM response is not of type ChatCompletion.")
logger.debug(f"Grounding LLM response: {grounding_response.model_dump_json()}")
# Parse tool calls
grounding_message = grounding_response.choices[0].message
rendered_response, actions = self._parse_action(grounding_message, instruction, current_state)
# Update usage by grounding model
self.tracer["usage"] = get_chat_usage_metrics(
self.model.name,
input_tokens=grounding_response.usage.prompt_tokens,
output_tokens=grounding_response.usage.completion_tokens,
usage=self.tracer.get("usage"),
)
except Exception as e:
logger.error(f"Error calling Grounding LLM: {e}")
rendered_response = f"**Error**: Error contacting Grounding LLM: {e}"
actions = []
return rendered_response, actions
def _format_message_for_api(self, instruction: str, current_state: EnvState) -> List:
"""Format the message for the API call."""
# Construct grounding LLM input (using only the latest user prompt + image)
# We don't pass the full history here, as grounding depends on the *current* state + NL action
grounding_user_prompt = self.get_instruction(instruction, self.environment_type)
screenshots = [f"data:image/webp;base64,{current_state.screenshot}"]
grounding_messages_content = construct_structured_message(
grounding_user_prompt, screenshots, self.model.name, vision_enabled=True
)
return [{"role": "user", "content": grounding_messages_content}]
def _parse_action(
self, grounding_message: ChatCompletionMessage, instruction: str, current_state: EnvState
) -> tuple[str, list[OperatorAction]]:
"""Parse the tool calls from the grounding LLM response and convert them to action objects."""
actions: List[OperatorAction] = []
action_results: List[dict] = []
if grounding_message.tool_calls:
rendered_parts = []
for tool_call in grounding_message.tool_calls:
function_name = tool_call.function.name
try:
arguments = json.loads(tool_call.function.arguments)
action_to_run: Optional[OperatorAction] = None
action_render_str = f"**Action ({function_name})**: {tool_call.function.arguments}"
if function_name == "click":
action_to_run = ClickAction(**arguments)
elif function_name == "left_double":
action_to_run = DoubleClickAction(**arguments)
elif function_name == "right_single":
action_to_run = ClickAction(button="right", **arguments)
elif function_name == "type":
content = arguments.get("content")
action_to_run = TypeAction(text=content)
elif function_name == "scroll":
direction = arguments.get("direction", "down")
amount = 3
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, **arguments)
elif function_name == "hotkey":
action_to_run = KeypressAction(**arguments)
elif function_name == "goto":
action_to_run = GotoAction(**arguments)
elif function_name == "back":
action_to_run = BackAction(**arguments)
elif function_name == "wait":
action_to_run = WaitAction(**arguments)
elif function_name == "screenshot":
action_to_run = ScreenshotAction(**arguments)
elif function_name == "drag":
# Need to convert list of dicts to list of Point objects
path_dicts = arguments.get("path", [])
path_points = [Point(**p) for p in path_dicts]
if path_points:
action_to_run = DragAction(path=path_points)
else:
logger.warning(f"Drag action called with empty path: {arguments}")
action_render_str += " [Skipped - empty path]"
elif function_name == "finished":
action_to_run = None
else:
logger.warning(f"Grounding LLM called unhandled tool: {function_name}")
action_render_str += " [Unhandled]"
if action_to_run:
actions.append(action_to_run)
action_results.append(
{
"type": "tool_result",
"tool_call_id": tool_call.id,
"content": None, # Updated after environment step
}
)
rendered_parts.append(action_render_str)
except (json.JSONDecodeError, TypeError, ValueError) as arg_err:
logger.error(
f"Error parsing arguments for tool {function_name}: {arg_err} - Args: {tool_call.function.arguments}"
)
rendered_parts.append(f"**Error**: Failed to parse arguments for {function_name}")
rendered_response = "\n- ".join(rendered_parts)
else:
# Grounding LLM responded but didn't call a tool
logger.warning("Grounding LLM did not produce a tool call.")
rendered_response = f"{grounding_message.content or 'No action required.'}"
# Render the response
return rendered_response, actions
def get_instruction(self, instruction: str, environment_type: EnvironmentType) -> str:
"""
Get the instruction for the agent based on the environment type.
"""
UITARS_COMPUTER_PREFIX_PROMPT = """
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
"""
UITARS_BROWSER_PREFIX_PROMPT = """
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to decide the next action to complete the task.
You control a single tab in a Chromium browser. You cannot access the OS, filesystem or the application window.
Always use the `goto` function to navigate to a specific URL. Ctrl+t, Ctrl+w, Ctrl+q, Ctrl+Shift+T, Ctrl+Shift+W are not allowed.
"""
UITARS_USR_COMPUTER_PROMPT_THOUGHT = f"""
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
## Note
- Use English in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
UITARS_USR_BROWSER_PROMPT_THOUGHT = f"""
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
goto(url='xxx') # Always use this to navigate to a specific URL. Use escape characters \\', \\", and \\n in url part to ensure we can parse the url in normal python string format.
back() # Use this to go back to the previous page.
## Note
- Use English in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
if environment_type == EnvironmentType.BROWSER:
return dedent(UITARS_BROWSER_PREFIX_PROMPT + UITARS_USR_BROWSER_PROMPT_THOUGHT).lstrip()
elif environment_type == EnvironmentType.COMPUTER:
return dedent(UITARS_COMPUTER_PREFIX_PROMPT + UITARS_USR_COMPUTER_PROMPT_THOUGHT).lstrip()
else:
raise ValueError(f"Expected environment type: Computer or Browser. Got {environment_type}.")
def get_tools(self, environment_type: EnvironmentType) -> list[dict]:
"""Get tools for the grounding LLM, in OpenAI API tool format"""
tools = [
{ {
"type": "function", "type": "function",
"function": { "function": {
@@ -163,182 +367,32 @@ class GroundingAgent:
}, },
}, },
}, },
{ ]
"type": "function", if environment_type == EnvironmentType.BROWSER:
"function": { tools += [
"name": "goto", {
"description": "Navigate to a specific URL.", "type": "function",
"parameters": { "function": {
"type": "object", "name": "goto",
"properties": {"url": {"type": "string", "description": "Fully qualified URL"}}, "description": "Navigate to a specific URL.",
"required": ["url"], "parameters": {
"type": "object",
"properties": {"url": {"type": "string", "description": "Fully qualified URL"}},
"required": ["url"],
},
}, },
}, },
}, {
{ "type": "function",
"type": "function", "function": {
"function": { "name": "back",
"name": "back", "description": "navigate back to the previous page.",
"description": "navigate back to the previous page.", "parameters": {"type": "object", "properties": {}},
"parameters": {"type": "object", "properties": {}}, },
}, },
}, ]
]
async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]: return tools
"""Call the grounding LLM to get the next action based on the current state and instruction."""
# Format the message for the API call
messages_for_api = self._format_message_for_api(instruction, current_state)
try:
grounding_response: ChatCompletion = await self.client.chat.completions.create(
messages=messages_for_api,
model=self.model.name,
tools=self.action_tools,
tool_choice="required",
temperature=0.0, # Grounding should be precise
max_completion_tokens=1000, # Allow for thoughts + actions
)
if not isinstance(grounding_response, ChatCompletion):
raise ValueError("Grounding LLM response is not of type ChatCompletion.")
logger.debug(f"Grounding LLM response: {grounding_response.model_dump_json()}")
# Parse tool calls
grounding_message = grounding_response.choices[0].message
rendered_response, actions = self._parse_action(grounding_message, instruction, current_state)
# Update usage by grounding model
self.tracer["usage"] = get_chat_usage_metrics(
self.model.name,
input_tokens=grounding_response.usage.prompt_tokens,
output_tokens=grounding_response.usage.completion_tokens,
usage=self.tracer.get("usage"),
)
except Exception as e:
logger.error(f"Error calling Grounding LLM: {e}")
rendered_response = f"**Error**: Error contacting Grounding LLM: {e}"
actions = []
return rendered_response, actions
def _format_message_for_api(self, instruction: str, current_state: EnvState) -> List:
"""Format the message for the API call."""
grounding_user_prompt = f"""
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to decide the next action to complete the task.
You control a single tab in a Chromium browser. You cannot access the OS, filesystem or the application window.
Always use the `goto` function to navigate to a specific URL. Ctrl+t, Ctrl+w, Ctrl+q, Ctrl+Shift+T, Ctrl+Shift+W are not allowed.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
goto(url='xxx') # Always use this to navigate to a specific URL. Use escape characters \\', \\", and \\n in url part to ensure we can parse the url in normal python string format.
back() # Use this to go back to the previous page.
## Note
- Use English in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
""".lstrip()
# Construct grounding LLM input (using only the latest user prompt + image)
# We don't pass the full history here, as grounding depends on the *current* state + NL action
screenshots = [f"data:image/webp;base64,{current_state.screenshot}"]
grounding_messages_content = construct_structured_message(
grounding_user_prompt, screenshots, self.model.name, vision_enabled=True
)
return [{"role": "user", "content": grounding_messages_content}]
def _parse_action(
self, grounding_message: ChatCompletionMessage, instruction: str, current_state: EnvState
) -> tuple[str, list[OperatorAction]]:
"""Parse the tool calls from the grounding LLM response and convert them to action objects."""
actions: List[OperatorAction] = []
action_results: List[dict] = []
if grounding_message.tool_calls:
rendered_parts = []
for tool_call in grounding_message.tool_calls:
function_name = tool_call.function.name
try:
arguments = json.loads(tool_call.function.arguments)
action_to_run: Optional[OperatorAction] = None
action_render_str = f"**Action ({function_name})**: {tool_call.function.arguments}"
if function_name == "click":
action_to_run = ClickAction(**arguments)
elif function_name == "left_double":
action_to_run = DoubleClickAction(**arguments)
elif function_name == "right_single":
action_to_run = ClickAction(button="right", **arguments)
elif function_name == "type":
content = arguments.get("content")
action_to_run = TypeAction(text=content)
elif function_name == "scroll":
direction = arguments.get("direction", "down")
amount = 3
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, **arguments)
elif function_name == "hotkey":
action_to_run = KeypressAction(**arguments)
elif function_name == "goto":
action_to_run = GotoAction(**arguments)
elif function_name == "back":
action_to_run = BackAction(**arguments)
elif function_name == "wait":
action_to_run = WaitAction(**arguments)
elif function_name == "screenshot":
action_to_run = ScreenshotAction(**arguments)
elif function_name == "drag":
# Need to convert list of dicts to list of Point objects
path_dicts = arguments.get("path", [])
path_points = [Point(**p) for p in path_dicts]
if path_points:
action_to_run = DragAction(path=path_points)
else:
logger.warning(f"Drag action called with empty path: {arguments}")
action_render_str += " [Skipped - empty path]"
elif function_name == "finished":
action_to_run = None
else:
logger.warning(f"Grounding LLM called unhandled tool: {function_name}")
action_render_str += " [Unhandled]"
if action_to_run:
actions.append(action_to_run)
action_results.append(
{
"type": "tool_result",
"tool_call_id": tool_call.id,
"content": None, # Updated after environment step
}
)
rendered_parts.append(action_render_str)
except (json.JSONDecodeError, TypeError, ValueError) as arg_err:
logger.error(
f"Error parsing arguments for tool {function_name}: {arg_err} - Args: {tool_call.function.arguments}"
)
rendered_parts.append(f"**Error**: Failed to parse arguments for {function_name}")
rendered_response = "\n- ".join(rendered_parts)
else:
# Grounding LLM responded but didn't call a tool
logger.warning("Grounding LLM did not produce a tool call.")
rendered_response = f"{grounding_message.content or 'No action required.'}"
# Render the response
return rendered_response, actions
def reset(self): def reset(self):
"""Reset the agent state.""" """Reset the agent state."""

View File

@@ -10,6 +10,7 @@ import logging
import math import math
import re import re
from io import BytesIO from io import BytesIO
from textwrap import dedent
from typing import Any, List from typing import Any, List
import numpy as np import numpy as np
@@ -18,7 +19,7 @@ from openai.types.chat import ChatCompletion
from PIL import Image from PIL import Image
from khoj.processor.operator.operator_actions import * from khoj.processor.operator.operator_actions import *
from khoj.processor.operator.operator_environment_base import EnvState from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
from khoj.utils.helpers import get_chat_usage_metrics from khoj.utils.helpers import get_chat_usage_metrics
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,29 +36,8 @@ class GroundingAgentUitars:
MAX_PIXELS = 16384 * 28 * 28 MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200 MAX_RATIO = 200
UITARS_USR_PROMPT_THOUGHT = """ UITARS_NORMAL_ACTION_SPACE = dedent(
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to perform the next action to complete the task. """
You control a single tab in a Chromium browser. You cannot access the OS, filesystem, the application window or the addressbar.
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
{action_space}
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
UITARS_NORMAL_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>') click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
@@ -67,14 +47,15 @@ class GroundingAgentUitars:
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes. wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
""".lstrip() """
).lstrip()
def __init__( def __init__(
self, self,
model_name: str, model_name: str,
environment_type: EnvironmentType,
client: AsyncOpenAI | AsyncAzureOpenAI, client: AsyncOpenAI | AsyncAzureOpenAI,
max_iterations=50, max_iterations=50,
environment_type: Literal["computer", "web"] = "computer",
runtime_conf: dict = { runtime_conf: dict = {
"infer_mode": "qwen25vl_normal", "infer_mode": "qwen25vl_normal",
"prompt_style": "qwen25vl_normal", "prompt_style": "qwen25vl_normal",
@@ -94,7 +75,7 @@ class GroundingAgentUitars:
self.model_name = model_name self.model_name = model_name
self.client = client self.client = client
self.tracer = tracer self.tracer = tracer
self.environment_type = environment_type self.environment = environment_type
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.runtime_conf = runtime_conf self.runtime_conf = runtime_conf
@@ -116,7 +97,7 @@ class GroundingAgentUitars:
self.history_images: list[bytes] = [] self.history_images: list[bytes] = []
self.history_responses: list[str] = [] self.history_responses: list[str] = []
self.prompt_template = self.UITARS_USR_PROMPT_THOUGHT self.prompt_template = self.get_instruction(self.environment)
self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE
if "history_n" in self.runtime_conf: if "history_n" in self.runtime_conf:
@@ -126,11 +107,11 @@ class GroundingAgentUitars:
self.cur_callusr_count = 0 self.cur_callusr_count = 0
async def act(self, instruction: str, env_state: EnvState) -> tuple[str, list[OperatorAction]]: async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
""" """
Suggest the next action(s) based on the instruction and current environment. Suggest the next action(s) based on the instruction and current environment.
""" """
messages = self._format_messages_for_api(instruction, env_state) messages = self._format_messages_for_api(instruction, current_state)
recent_screenshot = Image.open(BytesIO(self.history_images[-1])) recent_screenshot = Image.open(BytesIO(self.history_images[-1]))
origin_resized_height = recent_screenshot.height origin_resized_height = recent_screenshot.height
@@ -145,9 +126,11 @@ class GroundingAgentUitars:
try_times = 3 try_times = 3
while not parsed_responses: while not parsed_responses:
if try_times <= 0: if try_times <= 0:
print(f"Reach max retry times to fetch response from client, as error flag.") logger.warning(f"Reach max retry times to fetch response from client, as error flag.")
return "client error\nFAIL", [] return "client error\nFAIL", []
try: try:
message_content = "\n".join([msg["content"][0].get("text") or "[image]" for msg in messages])
logger.debug(f"User message content: {message_content}")
response: ChatCompletion = await self.client.chat.completions.create( response: ChatCompletion = await self.client.chat.completions.create(
model="ui-tars", model="ui-tars",
messages=messages, messages=messages,
@@ -228,20 +211,9 @@ class GroundingAgentUitars:
self.actions.append(actions) self.actions.append(actions)
return f"{prediction}\nFAIL", [] return f"{prediction}\nFAIL", []
if self.environment_type == "web": actions.extend(
actions.extend( self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap) )
)
else:
pass
# TODO: Add PyautoguiAction when enable computer environment
# actions.append(
# PyautoguiAction(code=
# self.parsing_response_to_pyautogui_code(
# parsed_response, obs_image_height, obs_image_width, self.input_swap
# )
# )
# )
self.actions.append(actions) self.actions.append(actions)
@@ -252,13 +224,52 @@ class GroundingAgentUitars:
return prediction or "", actions return prediction or "", actions
def _format_messages_for_api(self, instruction: str, env_state: EnvState): def get_instruction(self, environment_type: EnvironmentType) -> str:
"""
Get the instruction for the agent based on the environment type.
"""
UITARS_COMPUTER_PREFIX_PROMPT = """
You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
"""
UITARS_BROWSER_PREFIX_PROMPT = """
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to perform the next action to complete the task.
You control a single tab in a Chromium browser. You cannot access the OS, filesystem, the application window or the addressbar.
"""
UITARS_USR_PROMPT_THOUGHT = """
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
{action_space}
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
if environment_type == EnvironmentType.BROWSER:
return dedent(UITARS_BROWSER_PREFIX_PROMPT + UITARS_USR_PROMPT_THOUGHT).lstrip()
elif environment_type == EnvironmentType.COMPUTER:
return dedent(UITARS_COMPUTER_PREFIX_PROMPT + UITARS_USR_PROMPT_THOUGHT).lstrip()
else:
raise ValueError(f"Unsupported environment type: {environment_type}")
def _format_messages_for_api(self, instruction: str, current_state: EnvState):
assert len(self.observations) == len(self.actions) and len(self.actions) == len( assert len(self.observations) == len(self.actions) and len(self.actions) == len(
self.thoughts self.thoughts
), "The number of observations and actions should be the same." ), "The number of observations and actions should be the same."
self.history_images.append(base64.b64decode(env_state.screenshot)) self.history_images.append(base64.b64decode(current_state.screenshot))
self.observations.append({"screenshot": env_state.screenshot, "accessibility_tree": None}) self.observations.append({"screenshot": current_state.screenshot, "accessibility_tree": None})
user_prompt = self.prompt_template.format( user_prompt = self.prompt_template.format(
instruction=instruction, action_space=self.prompt_action_space, language=self.language instruction=instruction, action_space=self.prompt_action_space, language=self.language

View File

@@ -3,7 +3,8 @@ import json
import logging import logging
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from typing import List, Optional, cast from textwrap import dedent
from typing import List, Literal, Optional, cast
from anthropic.types.beta import BetaContentBlock from anthropic.types.beta import BetaContentBlock
@@ -14,7 +15,11 @@ from khoj.processor.operator.operator_agent_base import (
AgentMessage, AgentMessage,
OperatorAgent, OperatorAgent,
) )
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
EnvStepResult,
)
from khoj.utils.helpers import get_anthropic_async_client, is_none_or_empty from khoj.utils.helpers import get_anthropic_async_client, is_none_or_empty
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -32,51 +37,12 @@ class AnthropicOperatorAgent(OperatorAgent):
action_results: List[dict] = [] action_results: List[dict] = []
self._commit_trace() # Commit trace before next action self._commit_trace() # Commit trace before next action
system_prompt = f"""<SYSTEM_CAPABILITY> system_prompt = self.get_instructions(self.environment_type, current_state)
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser. tools = self.get_tools(self.environment_type, current_state)
* You operate a Chromium browser using Playwright via the 'computer' tool.
* You cannot access the OS or filesystem.
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more.
* You can use the additional back() and goto() helper functions to ease navigating the browser. If you see nothing, try goto duckduckgo.com
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current URL is {current_state.url}.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* You are allowed upto {self.max_iterations} iterations to complete the task.
* Do not loop on wait, screenshot for too many turns without taking any action.
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
</IMPORTANT>
"""
if is_none_or_empty(self.messages): if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=self.query)] self.messages = [AgentMessage(role="user", content=self.query)]
tools = [
{
"type": self.model_default_tool("computer"),
"name": "computer",
"display_width_px": 1024,
"display_height_px": 768,
}, # TODO: Get from env
{
"name": "back",
"description": "Go back to the previous page.",
"input_schema": {"type": "object", "properties": {}},
},
{
"name": "goto",
"description": "Go to a specific URL.",
"input_schema": {
"type": "object",
"properties": {"url": {"type": "string", "description": "Fully qualified URL to navigate to."}},
"required": ["url"],
},
},
]
thinking: dict[str, str | int] = {"type": "disabled"} thinking: dict[str, str | int] = {"type": "disabled"}
if is_reasoning_model(self.vision_model.name): if is_reasoning_model(self.vision_model.name):
thinking = {"type": "enabled", "budget_tokens": 1024} thinking = {"type": "enabled", "budget_tokens": 1024}
@@ -400,3 +366,80 @@ class AnthropicOperatorAgent(OperatorAgent):
return ["computer-use-2025-01-24"] return ["computer-use-2025-01-24"]
else: else:
return [] return []
def get_instructions(self, environment_type: EnvironmentType, current_state: EnvState) -> str:
"""Return system instructions for the Anthropic operator."""
if environment_type == EnvironmentType.BROWSER:
return dedent(
f"""
<SYSTEM_CAPABILITY>
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
* You operate a Chromium browser using Playwright via the 'computer' tool.
* You cannot access the OS or filesystem.
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more.
* You can use the additional back() and goto() helper functions to ease navigating the browser. If you see nothing, try goto duckduckgo.com
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current URL is {current_state.url}.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* You are allowed upto {self.max_iterations} iterations to complete the task.
* Do not loop on wait, screenshot for too many turns without taking any action.
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
</IMPORTANT>
"""
).lstrip()
elif environment_type == EnvironmentType.COMPUTER:
return dedent(
f"""
<SYSTEM_CAPABILITY>
* You are Khoj, a smart computer operating assistant. You help the users accomplish tasks using a computer.
* You can interact with the computer to perform tasks like clicking, typing, scrolling, and more.
* When viewing a document or webpage it can be helpful to zoom out or scroll down to ensure you see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* Do not loop on wait, screenshot for too many turns without taking any action.
* You are allowed upto {self.max_iterations} iterations to complete the task.
</SYSTEM_CAPABILITY>
<CONTEXT>
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
</CONTEXT>
"""
).lstrip()
else:
raise ValueError(f"Unsupported environment type for Anthropic operator: {environment_type}")
def get_tools(self, environment: EnvironmentType, current_state: EnvState) -> list[dict]:
"""Return the tools available for the Anthropic operator."""
tools = [
{
"type": self.model_default_tool("computer"),
"name": "computer",
"display_width_px": current_state.width,
"display_height_px": current_state.height,
}
]
if environment == "browser":
tools += [
{
"name": "back",
"description": "Go back to the previous page.",
"input_schema": {"type": "object", "properties": {}},
},
{
"name": "goto",
"description": "Go to a specific URL.",
"input_schema": {
"type": "object",
"properties": {"url": {"type": "string", "description": "Fully qualified URL to navigate to."}},
"required": ["url"],
},
},
]
return tools

View File

@@ -7,7 +7,11 @@ from pydantic import BaseModel
from khoj.database.models import ChatModel from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import commit_conversation_trace from khoj.processor.conversation.utils import commit_conversation_trace
from khoj.processor.operator.operator_actions import OperatorAction from khoj.processor.operator.operator_actions import OperatorAction
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
EnvStepResult,
)
from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,9 +29,12 @@ class AgentMessage(BaseModel):
class OperatorAgent(ABC): class OperatorAgent(ABC):
def __init__(self, query: str, vision_model: ChatModel, max_iterations: int, tracer: dict): def __init__(
self, query: str, vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, tracer: dict
):
self.query = query self.query = query
self.vision_model = vision_model self.vision_model = vision_model
self.environment_type = environment_type
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.tracer = tracer self.tracer = tracer
self.messages: List[AgentMessage] = [] self.messages: List[AgentMessage] = []

View File

@@ -1,9 +1,8 @@
import json import json
import logging import logging
from datetime import datetime from datetime import datetime
from typing import List, Optional from textwrap import dedent
from typing import List
from openai.types.chat import ChatCompletion
from khoj.database.models import ChatModel from khoj.database.models import ChatModel
from khoj.processor.conversation.utils import construct_structured_message from khoj.processor.conversation.utils import construct_structured_message
@@ -15,7 +14,11 @@ from khoj.processor.operator.operator_agent_base import (
AgentMessage, AgentMessage,
OperatorAgent, OperatorAgent,
) )
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
EnvStepResult,
)
from khoj.routers.helpers import send_message_to_model_wrapper from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import get_openai_async_client, is_none_or_empty from khoj.utils.helpers import get_openai_async_client, is_none_or_empty
@@ -27,7 +30,7 @@ class BinaryOperatorAgent(OperatorAgent):
""" """
An OperatorAgent that uses two LLMs: An OperatorAgent that uses two LLMs:
1. Reasoning LLM: Determines the next high-level action based on the objective and current visual reasoning trajectory. 1. Reasoning LLM: Determines the next high-level action based on the objective and current visual reasoning trajectory.
2. Grounding LLM: Converts the high-level action into specific, executable browser actions. 2. Grounding LLM: Converts the high-level action into specific, actions executable on the environment.
""" """
def __init__( def __init__(
@@ -35,10 +38,13 @@ class BinaryOperatorAgent(OperatorAgent):
query: str, query: str,
reasoning_model: ChatModel, reasoning_model: ChatModel,
grounding_model: ChatModel, grounding_model: ChatModel,
environment_type: EnvironmentType,
max_iterations: int, max_iterations: int,
tracer: dict, tracer: dict,
): ):
super().__init__(query, reasoning_model, max_iterations, tracer) # Use reasoning model for primary tracking super().__init__(
query, reasoning_model, environment_type, max_iterations, tracer
) # Use reasoning model for primary tracking
self.reasoning_model = reasoning_model self.reasoning_model = reasoning_model
self.grounding_model = grounding_model self.grounding_model = grounding_model
# Initialize openai api compatible client for grounding model # Initialize openai api compatible client for grounding model
@@ -49,10 +55,12 @@ class BinaryOperatorAgent(OperatorAgent):
self.grounding_agent: GroundingAgent | GroundingAgentUitars = None self.grounding_agent: GroundingAgent | GroundingAgentUitars = None
if "ui-tars-1.5" in grounding_model.name: if "ui-tars-1.5" in grounding_model.name:
self.grounding_agent = GroundingAgentUitars( self.grounding_agent = GroundingAgentUitars(
grounding_model.name, grounding_client, max_iterations, environment_type="web", tracer=tracer grounding_model.name, self.environment_type, grounding_client, max_iterations, tracer=tracer
) )
else: else:
self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, max_iterations, tracer=tracer) self.grounding_agent = GroundingAgent(
grounding_model.name, self.environment_type, grounding_client, max_iterations, tracer=tracer
)
async def act(self, current_state: EnvState) -> AgentActResult: async def act(self, current_state: EnvState) -> AgentActResult:
""" """
@@ -84,48 +92,7 @@ class BinaryOperatorAgent(OperatorAgent):
""" """
Uses the reasoning LLM to determine the next high-level action based on the operation trajectory. Uses the reasoning LLM to determine the next high-level action based on the operation trajectory.
""" """
reasoning_system_prompt = f""" reasoning_system_prompt = self.get_instruction(self.environment_type, current_state)
# Introduction
* You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
* You are given the user's query and screenshots of the browser's state transitions.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current URL is {current_state.url}.
# Your Task
* First look at the screenshots carefully to notice all pertinent information.
* Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
* Make sure you scroll down to see everything before deciding something isn't available.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* Use your creativity to find alternate ways to make progress if you get stuck at any point.
# Tool AI Capabilities
* The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
* It can interact with the web browser with these actions: click, right click, double click, type, scroll, drag, wait, goto url and go back to previous page.
* It cannot access the OS, filesystem or application window. It just controls a single Chromium browser tab via Playwright.
# IMPORTANT
* You are allowed upto {self.max_iterations} iterations to complete the task.
* To navigate to a specific URL, put "GOTO <URL>" (without quotes) on the last line of your response.
* To navigate back to the previous page, end your response with "BACK" (without quotes).
* Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).
# Examples
## Example 1
GOTO https://example.com
## Example 2
click the blue login button located at the top right corner
## Example 3
scroll down the page
## Example 4
type the username example@email.com into the input field labeled Username
## Example 5
DONE
# Instructions
Now describe a single high-level action to take next to progress towards the user's goal in detail.
Focus on the visual action and provide all necessary context.
""".strip()
if is_none_or_empty(self.messages): if is_none_or_empty(self.messages):
query_text = f"**Main Objective**: {self.query}" query_text = f"**Main Objective**: {self.query}"
query_screenshot = [f"data:image/webp;base64,{current_state.screenshot}"] query_screenshot = [f"data:image/webp;base64,{current_state.screenshot}"]
@@ -330,6 +297,96 @@ Focus on the visual action and provide all necessary context.
] ]
return formatted_messages return formatted_messages
def get_instruction(self, environment_type: EnvironmentType, env_state: EnvState) -> str:
"""Get the system instruction for the reasoning agent."""
if environment_type == EnvironmentType.BROWSER:
return dedent(
f"""
# Introduction
* You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
* You are given the user's query and screenshots of the browser's state transitions.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current URL is {env_state.url}.
# Your Task
* First look at the screenshots carefully to notice all pertinent information.
* Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
* Make sure you scroll down to see everything before deciding something isn't available.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* Use your creativity to find alternate ways to make progress if you get stuck at any point.
# Tool AI Capabilities
* The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
* It can interact with the web browser with these actions: click, right click, double click, type, scroll, drag, wait, goto url and go back to previous page.
* It cannot access the OS, filesystem or application window. It just controls a single Chromium browser tab via Playwright.
# IMPORTANT
* You are allowed upto {self.max_iterations} iterations to complete the task.
* To navigate to a specific URL, put "GOTO <URL>" (without quotes) on the last line of your response.
* To navigate back to the previous page, end your response with "BACK" (without quotes).
* Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).
# Examples
## Example 1
GOTO https://example.com
## Example 2
click the blue login button located at the top right corner
## Example 3
scroll down the page
## Example 4
type the username example@email.com into the input field labeled Username
## Example 5
DONE
# Instructions
Now describe a single high-level action to take next to progress towards the user's goal in detail.
Focus on the visual action and provide all necessary context.
"""
).strip()
elif environment_type == EnvironmentType.COMPUTER:
return dedent(
f"""
# Introduction
* You are Khoj, a smart and resourceful computer assistant. You help the user accomplish their task using a computer.
* You are given the user's query and screenshots of the computer's state transitions.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
# Your Task
* First look at the screenshots carefully to notice all pertinent information.
* Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
* Make sure you scroll down to see everything before deciding something isn't available.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* Use your creativity to find alternate ways to make progress if you get stuck at any point.
# Tool AI Capabilities
* The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
* It can interact with the computer with these actions: click, right click, double click, type, scroll, drag, wait to previous page.
# IMPORTANT
* You are allowed upto {self.max_iterations} iterations to complete the task.
* Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).
# Examples
## Example 1
type https://example.com into the address bar and press Enter
## Example 2
click the blue login button located at the top right corner
## Example 3
scroll down the page
## Example 4
type the username example@email.com into the input field labeled Username
## Example 5
DONE
# Instructions
Now describe a single high-level action to take next to progress towards the user's goal in detail.
Focus on the visual action and provide all necessary context.
"""
).strip()
else:
raise ValueError(f"Expected environment type: Computer or Browser. Got {environment_type}.")
def reset(self): def reset(self):
"""Reset the agent state.""" """Reset the agent state."""
super().reset() super().reset()

View File

@@ -1,7 +1,9 @@
import json import json
import logging import logging
import platform
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from textwrap import dedent
from typing import List, Optional, cast from typing import List, Optional, cast
from openai.types.responses import Response, ResponseOutputItem from openai.types.responses import Response, ResponseOutputItem
@@ -12,7 +14,11 @@ from khoj.processor.operator.operator_agent_base import (
AgentMessage, AgentMessage,
OperatorAgent, OperatorAgent,
) )
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult from khoj.processor.operator.operator_environment_base import (
EnvironmentType,
EnvState,
EnvStepResult,
)
from khoj.utils.helpers import get_openai_async_client, is_none_or_empty from khoj.utils.helpers import get_openai_async_client, is_none_or_empty
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,55 +35,8 @@ class OpenAIOperatorAgent(OperatorAgent):
actions: List[OperatorAction] = [] actions: List[OperatorAction] = []
action_results: List[dict] = [] action_results: List[dict] = []
self._commit_trace() # Commit trace before next action self._commit_trace() # Commit trace before next action
system_prompt = f"""<SYSTEM_CAPABILITY> system_prompt = self.get_instructions(self.environment_type, current_state)
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser. tools = self.get_tools(self.environment_type, current_state)
* You operate a single Chromium browser page using Playwright.
* You cannot access the OS or filesystem.
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
* You can use the additional back() and goto() functions to navigate the browser.
* Always use the goto() function to navigate to a specific URL. If you see nothing, try goto duckduckgo.com
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current URL is {current_state.url}.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* You are allowed upto {self.max_iterations} iterations to complete the task.
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
</IMPORTANT>
"""
tools = [
{
"type": "computer_use_preview",
"display_width": 1024, # TODO: Get from env
"display_height": 768, # TODO: Get from env
"environment": "browser",
},
{
"type": "function",
"name": "back",
"description": "Go back to the previous page.",
"parameters": {},
},
{
"type": "function",
"name": "goto",
"description": "Go to a specific URL.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "Fully qualified URL to navigate to.",
},
},
"additionalProperties": False,
"required": ["url"],
},
},
]
if is_none_or_empty(self.messages): if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=self.query)] self.messages = [AgentMessage(role="user", content=self.query)]
@@ -347,3 +306,93 @@ class OpenAIOperatorAgent(OperatorAgent):
} }
return render_payload return render_payload
def get_instructions(self, environment_type: EnvironmentType, current_state: EnvState) -> str:
"""Return system instructions for the OpenAI operator."""
if environment_type == EnvironmentType.BROWSER:
return dedent(
f"""
<SYSTEM_CAPABILITY>
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
* You operate a single Chromium browser page using Playwright.
* You cannot access the OS or filesystem.
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
* You can use the additional back() and goto() functions to navigate the browser.
* Always use the goto() function to navigate to a specific URL. If you see nothing, try goto duckduckgo.com
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
* The current URL is {current_state.url}.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* You are allowed upto {self.max_iterations} iterations to complete the task.
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
</IMPORTANT>
"""
).lstrip()
elif environment_type == EnvironmentType.COMPUTER:
return dedent(
f"""
<SYSTEM_CAPABILITY>
* You are Khoj, a smart computer operating assistant. You help the users accomplish their tasks using a computer.
* You can interact with the computer to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
* When viewing a document or webpage it can be helpful to zoom out or scroll down to ensure you see everything before deciding something isn't available.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
* You are allowed upto {self.max_iterations} iterations to complete the task.
</SYSTEM_CAPABILITY>
<CONTEXT>
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
</CONTEXT>
"""
).lstrip()
else:
raise ValueError(f"Unsupported environment type: {environment_type}")
def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]:
"""Return the tools available for the OpenAI operator."""
if environment_type == EnvironmentType.COMPUTER:
# get os info of this computer. it can be mac, windows, linux
environment_os = (
"mac" if platform.system() == "Darwin" else "windows" if platform.system() == "Windows" else "linux"
)
else:
environment_os = "browser"
tools = [
{
"type": "computer_use_preview",
"display_width": current_state.width,
"display_height": current_state.height,
"environment": environment_os,
}
]
if environment_type == EnvironmentType.BROWSER:
tools += [
{
"type": "function",
"name": "back",
"description": "Go back to the previous page.",
"parameters": {},
},
{
"type": "function",
"name": "goto",
"description": "Go to a specific URL.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "Fully qualified URL to navigate to.",
},
},
"additionalProperties": False,
"required": ["url"],
},
},
]
return tools

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum
from typing import Literal, Optional from typing import Literal, Optional
from pydantic import BaseModel from pydantic import BaseModel
@@ -6,9 +7,18 @@ from pydantic import BaseModel
from khoj.processor.operator.operator_actions import OperatorAction from khoj.processor.operator.operator_actions import OperatorAction
class EnvironmentType(Enum):
"""Type of environment to operate."""
COMPUTER = "computer"
BROWSER = "browser"
class EnvState(BaseModel): class EnvState(BaseModel):
url: str height: int
width: int
screenshot: Optional[str] = None screenshot: Optional[str] = None
url: Optional[str] = None
class EnvStepResult(BaseModel): class EnvStepResult(BaseModel):

View File

@@ -124,10 +124,10 @@ class BrowserEnvironment(Environment):
async def get_state(self) -> EnvState: async def get_state(self) -> EnvState:
if not self.page or self.page.is_closed(): if not self.page or self.page.is_closed():
return EnvState(url="about:blank", screenshot=None) return EnvState(url="about:blank", screenshot=None, height=self.height, width=self.width)
url = self.page.url url = self.page.url
screenshot = await self._get_screenshot() screenshot = await self._get_screenshot()
return EnvState(url=url, screenshot=screenshot) return EnvState(url=url, screenshot=screenshot, height=self.height, width=self.width)
async def step(self, action: OperatorAction) -> EnvStepResult: async def step(self, action: OperatorAction) -> EnvStepResult:
if not self.page or self.page.is_closed(): if not self.page or self.page.is_closed():

View File

@@ -31,7 +31,7 @@ from khoj.processor.conversation.utils import (
save_to_conversation_log, save_to_conversation_log,
) )
from khoj.processor.image.generate import text_to_image from khoj.processor.image.generate import text_to_image
from khoj.processor.operator.operate_browser import operate_browser from khoj.processor.operator import operate_environment
from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import ( from khoj.processor.tools.online_search import (
deduplicate_organic_results, deduplicate_organic_results,
@@ -1292,7 +1292,7 @@ async def chat(
) )
if ConversationCommand.Operator in conversation_commands: if ConversationCommand.Operator in conversation_commands:
try: try:
async for result in operate_browser( async for result in operate_environment(
defiltered_query, defiltered_query,
user, user,
meta_log, meta_log,

View File

@@ -18,7 +18,7 @@ from khoj.processor.conversation.utils import (
construct_tool_chat_history, construct_tool_chat_history,
load_complex_json, load_complex_json,
) )
from khoj.processor.operator.operate_browser import operate_browser from khoj.processor.operator import operate_environment
from khoj.processor.tools.online_search import read_webpages, search_online from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code from khoj.processor.tools.run_code import run_code
from khoj.routers.api import extract_references_and_questions from khoj.routers.api import extract_references_and_questions
@@ -417,12 +417,12 @@ async def execute_information_collection(
elif this_iteration.tool == ConversationCommand.Operator: elif this_iteration.tool == ConversationCommand.Operator:
try: try:
async for result in operate_browser( async for result in operate_environment(
this_iteration.query, this_iteration.query,
user, user,
construct_tool_chat_history(previous_iterations, ConversationCommand.Operator), construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
location, location,
send_status_func, send_status_func=send_status_func,
query_images=query_images, query_images=query_images,
agent=agent, agent=agent,
query_files=query_files, query_files=query_files,