mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Generalize operator to operate multiple types of environments
Previously it could only operate a (playwright) browser. Now:
- The operator logic and naming have been updated on the assumption that multiple environment types can be operated.
- The operator entrypoint is now in __init__.py to simplify imports, and the entrypoint function is called operate_environment.
- All operator agents have been updated to select their system prompts and tools based on the environment they'll operate.
This commit is contained in:
@@ -11,7 +11,10 @@ from khoj.processor.operator.operator_agent_anthropic import AnthropicOperatorAg
|
||||
from khoj.processor.operator.operator_agent_base import OperatorAgent
|
||||
from khoj.processor.operator.operator_agent_binary import BinaryOperatorAgent
|
||||
from khoj.processor.operator.operator_agent_openai import OpenAIOperatorAgent
|
||||
from khoj.processor.operator.operator_environment_base import EnvStepResult
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_browser import BrowserEnvironment
|
||||
from khoj.routers.helpers import ChatEvent
|
||||
from khoj.utils.helpers import timer
|
||||
@@ -20,12 +23,13 @@ from khoj.utils.rawconfig import LocationData
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Browser Operator Function ---
|
||||
async def operate_browser(
|
||||
# --- Main Operator Entrypoint ---
|
||||
async def operate_environment(
|
||||
query: str,
|
||||
user: KhojUser,
|
||||
conversation_log: dict,
|
||||
location_data: LocationData,
|
||||
environment_type: EnvironmentType = EnvironmentType.COMPUTER,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None, # TODO: Handle query images
|
||||
agent: Agent = None,
|
||||
@@ -34,7 +38,6 @@ async def operate_browser(
|
||||
tracer: dict = {},
|
||||
):
|
||||
response, summary_message, user_input_message = None, None, None
|
||||
environment: Optional[BrowserEnvironment] = None
|
||||
|
||||
# Get the agent chat model
|
||||
agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None
|
||||
@@ -42,15 +45,15 @@ async def operate_browser(
|
||||
if not reasoning_model or not reasoning_model.vision_enabled:
|
||||
reasoning_model = await ConversationAdapters.aget_vision_enabled_config()
|
||||
if not reasoning_model:
|
||||
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate browser.")
|
||||
raise ValueError(f"No vision enabled chat model found. Configure a vision chat model to operate environment.")
|
||||
|
||||
# Initialize Agent
|
||||
max_iterations = int(os.getenv("KHOJ_OPERATOR_ITERATIONS", 40))
|
||||
operator_agent: OperatorAgent
|
||||
if is_operator_model(reasoning_model.name) == ChatModel.ModelType.OPENAI:
|
||||
operator_agent = OpenAIOperatorAgent(query, reasoning_model, max_iterations, tracer)
|
||||
operator_agent = OpenAIOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
|
||||
elif is_operator_model(reasoning_model.name) == ChatModel.ModelType.ANTHROPIC:
|
||||
operator_agent = AnthropicOperatorAgent(query, reasoning_model, max_iterations, tracer)
|
||||
operator_agent = AnthropicOperatorAgent(query, reasoning_model, environment_type, max_iterations, tracer)
|
||||
else:
|
||||
grounding_model_name = "ui-tars-1.5"
|
||||
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
|
||||
@@ -60,11 +63,13 @@ async def operate_browser(
|
||||
or not grounding_model.model_type == ChatModel.ModelType.OPENAI
|
||||
):
|
||||
raise ValueError("No supported visual grounding model for binary operator agent found.")
|
||||
operator_agent = BinaryOperatorAgent(query, reasoning_model, grounding_model, max_iterations, tracer)
|
||||
operator_agent = BinaryOperatorAgent(
|
||||
query, reasoning_model, grounding_model, environment_type, max_iterations, tracer
|
||||
)
|
||||
|
||||
# Initialize Environment
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Launching Browser**"):
|
||||
async for event in send_status_func(f"**Launching {environment_type.value}**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
environment = BrowserEnvironment()
|
||||
await environment.start(width=1024, height=768)
|
||||
@@ -75,25 +80,27 @@ async def operate_browser(
|
||||
task_completed = False
|
||||
iterations = 0
|
||||
|
||||
with timer(f"Operating browser with {reasoning_model.model_type} {reasoning_model.name}", logger):
|
||||
with timer(
|
||||
f"Operating {environment_type.value} with {reasoning_model.model_type} {reasoning_model.name}", logger
|
||||
):
|
||||
while iterations < max_iterations and not task_completed:
|
||||
if cancellation_event and cancellation_event.is_set():
|
||||
logger.debug(f"Browser operator cancelled by client disconnect")
|
||||
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
|
||||
break
|
||||
|
||||
iterations += 1
|
||||
|
||||
# 1. Get current environment state
|
||||
browser_state = await environment.get_state()
|
||||
env_state = await environment.get_state()
|
||||
|
||||
# 2. Agent decides action(s)
|
||||
agent_result = await operator_agent.act(browser_state)
|
||||
agent_result = await operator_agent.act(env_state)
|
||||
|
||||
# 3. Execute actions in the environment
|
||||
env_steps: List[EnvStepResult] = []
|
||||
for action in agent_result.actions:
|
||||
if cancellation_event and cancellation_event.is_set():
|
||||
logger.debug(f"Browser operator cancelled by client disconnect")
|
||||
logger.debug(f"{environment_type.value} operator cancelled by client disconnect")
|
||||
break
|
||||
# Handle request for user action and break the loop
|
||||
if isinstance(action, RequestUserAction):
|
||||
@@ -106,12 +113,14 @@ async def operate_browser(
|
||||
env_steps.append(env_step)
|
||||
|
||||
# Render status update
|
||||
latest_screenshot = f"data:image/webp;base64,{env_steps[-1].screenshot_base64 if env_steps else browser_state.screenshot}"
|
||||
latest_screenshot = (
|
||||
f"data:image/webp;base64,{env_steps[-1].screenshot_base64 if env_steps else env_state.screenshot}"
|
||||
)
|
||||
render_payload = agent_result.rendered_response
|
||||
render_payload["image"] = latest_screenshot
|
||||
render_content = f"**Action**: {json.dumps(render_payload)}"
|
||||
if send_status_func:
|
||||
async for event in send_status_func(f"**Operating Browser**:\n{render_content}"):
|
||||
async for event in send_status_func(f"**Operating {environment_type.value}**:\n{render_content}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
# Check if termination conditions are met
|
||||
@@ -123,7 +132,7 @@ async def operate_browser(
|
||||
if task_completed or trigger_iteration_limit:
|
||||
# Summarize results of operator run on last iteration
|
||||
operator_agent.add_action_results(env_steps, agent_result)
|
||||
summary_message = await operator_agent.summarize(summarize_prompt, browser_state)
|
||||
summary_message = await operator_agent.summarize(summarize_prompt, env_state)
|
||||
logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
|
||||
break
|
||||
|
||||
@@ -138,15 +147,19 @@ async def operate_browser(
|
||||
else: # Hit iteration limit
|
||||
response = f"Operator hit iteration limit ({max_iterations}). If the results seem incomplete try again, assign a smaller task or try a different approach.\nThese were the results till now:\n{summary_message}"
|
||||
finally:
|
||||
if environment and not user_input_message: # Don't close browser if user input required
|
||||
if environment and not user_input_message: # Don't close environment if user input required
|
||||
await environment.close()
|
||||
if operator_agent:
|
||||
operator_agent.reset()
|
||||
|
||||
webpages = []
|
||||
if environment_type == EnvironmentType.BROWSER and hasattr(environment, "visited_urls"):
|
||||
webpages = [{"link": url, "snippet": ""} for url in environment.visited_urls]
|
||||
|
||||
yield {
|
||||
"query": query,
|
||||
"result": user_input_message or response,
|
||||
"webpages": [{"link": url, "snippet": ""} for url in environment.visited_urls],
|
||||
"webpages": webpages,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from textwrap import dedent
|
||||
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
@@ -8,7 +9,7 @@ from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import construct_structured_message
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_base import AgentActResult
|
||||
from khoj.processor.operator.operator_environment_base import EnvState
|
||||
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
|
||||
from khoj.utils.helpers import get_chat_usage_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,6 +19,7 @@ class GroundingAgent:
|
||||
def __init__(
|
||||
self,
|
||||
model: ChatModel,
|
||||
environment_type: EnvironmentType,
|
||||
client: OpenAI | AzureOpenAI,
|
||||
max_iterations: int,
|
||||
tracer: dict = None,
|
||||
@@ -26,9 +28,211 @@ class GroundingAgent:
|
||||
self.client = client
|
||||
self.max_iterations = max_iterations
|
||||
self.tracer = tracer
|
||||
self.environment_type = environment_type
|
||||
self.action_tools = self.get_tools(self.environment_type)
|
||||
|
||||
# Define tools for the grounding LLM (OpenAI format)
|
||||
self.action_tools = [
|
||||
async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
    """Call the grounding LLM to get the next action(s) for the instruction and current state.

    Args:
        instruction: Natural language instruction to ground into concrete actions.
        current_state: Current environment state; its screenshot is sent to the model.

    Returns:
        A tuple of the rendered response text and the list of parsed actions.
        If the request or the parsing fails, the text is an error message and the list is empty.
    """
    # Format the message for the API call
    messages_for_api = self._format_message_for_api(instruction, current_state)
    try:
        grounding_response: ChatCompletion = await self.client.chat.completions.create(
            messages=messages_for_api,
            model=self.model.name,
            tools=self.action_tools,
            tool_choice="required",
            temperature=0.0,  # Grounding should be precise
            max_completion_tokens=1000,  # Allow for thoughts + actions
        )
        if not isinstance(grounding_response, ChatCompletion):
            raise ValueError("Grounding LLM response is not of type ChatCompletion.")
        logger.debug(f"Grounding LLM response: {grounding_response.model_dump_json()}")

        # Parse tool calls
        grounding_message = grounding_response.choices[0].message
        rendered_response, actions = self._parse_action(grounding_message, instruction, current_state)

        # Update usage by grounding model.
        # The tracer defaults to None and the API may omit usage. Skip the bookkeeping in
        # those cases so it cannot raise and throw away the actions parsed above.
        response_usage = grounding_response.usage
        if self.tracer is not None and response_usage is not None:
            self.tracer["usage"] = get_chat_usage_metrics(
                self.model.name,
                input_tokens=response_usage.prompt_tokens,
                output_tokens=response_usage.completion_tokens,
                usage=self.tracer.get("usage"),
            )
    except Exception as e:
        # Best effort: report the failure to the caller as rendered text with no actions.
        logger.error(f"Error calling Grounding LLM: {e}")
        rendered_response = f"**Error**: Error contacting Grounding LLM: {e}"
        actions = []

    return rendered_response, actions
|
||||
|
||||
def _format_message_for_api(self, instruction: str, current_state: EnvState) -> List:
    """Build the single user message sent to the grounding LLM.

    Only the latest prompt and screenshot are included, not the full history,
    because grounding depends on the *current* state plus the natural language action.
    """
    prompt_text = self.get_instruction(instruction, self.environment_type)
    screenshot_data_uri = f"data:image/webp;base64,{current_state.screenshot}"
    message_content = construct_structured_message(
        prompt_text, [screenshot_data_uri], self.model.name, vision_enabled=True
    )
    user_message = {"role": "user", "content": message_content}
    return [user_message]
|
||||
|
||||
def _parse_action(
    self, grounding_message: ChatCompletionMessage, instruction: str, current_state: EnvState
) -> tuple[str, list[OperatorAction]]:
    """Parse the tool calls from the grounding LLM response and convert them to action objects.

    Args:
        grounding_message: Assistant message returned by the grounding LLM.
        instruction: Instruction that was grounded. Not read by this method.
        current_state: Environment state that was grounded against. Not read by this method.

    Returns:
        A tuple of the rendered response text and the list of actions to run.
        A tool call whose arguments cannot be parsed is rendered as an error and skipped;
        the remaining tool calls are still processed.
    """
    actions: List[OperatorAction] = []

    if not grounding_message.tool_calls:
        # Grounding LLM responded but didn't call a tool
        logger.warning("Grounding LLM did not produce a tool call.")
        rendered_response = f"{grounding_message.content or 'No action required.'}"
        return rendered_response, actions

    rendered_parts = []
    for tool_call in grounding_message.tool_calls:
        function_name = tool_call.function.name
        try:
            arguments = json.loads(tool_call.function.arguments)
            # Valid JSON that is not an object (e.g. null or a list) cannot be used as
            # keyword arguments. Reject it here so it is handled like any other bad argument
            # instead of raising an unhandled AttributeError on `.get` below.
            if not isinstance(arguments, dict):
                raise TypeError(f"expected a JSON object, got {type(arguments).__name__}")
            action_to_run: Optional[OperatorAction] = None
            action_render_str = f"**Action ({function_name})**: {tool_call.function.arguments}"

            if function_name == "click":
                action_to_run = ClickAction(**arguments)
            elif function_name == "left_double":
                action_to_run = DoubleClickAction(**arguments)
            elif function_name == "right_single":
                action_to_run = ClickAction(button="right", **arguments)
            elif function_name == "type":
                content = arguments.get("content")
                action_to_run = TypeAction(text=content)
            elif function_name == "scroll":
                direction = arguments.get("direction", "down")
                amount = 3
                action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, **arguments)
            elif function_name == "hotkey":
                action_to_run = KeypressAction(**arguments)
            elif function_name == "goto":
                action_to_run = GotoAction(**arguments)
            elif function_name == "back":
                action_to_run = BackAction(**arguments)
            elif function_name == "wait":
                action_to_run = WaitAction(**arguments)
            elif function_name == "screenshot":
                action_to_run = ScreenshotAction(**arguments)
            elif function_name == "drag":
                # Need to convert list of dicts to list of Point objects
                path_dicts = arguments.get("path", [])
                path_points = [Point(**p) for p in path_dicts]
                if path_points:
                    action_to_run = DragAction(path=path_points)
                else:
                    logger.warning(f"Drag action called with empty path: {arguments}")
                    action_render_str += " [Skipped - empty path]"
            elif function_name == "finished":
                # Nothing to execute for a finished signal; it is only rendered.
                action_to_run = None
            else:
                logger.warning(f"Grounding LLM called unhandled tool: {function_name}")
                action_render_str += " [Unhandled]"

            if action_to_run:
                actions.append(action_to_run)
            rendered_parts.append(action_render_str)
        except (json.JSONDecodeError, TypeError, ValueError) as arg_err:
            logger.error(
                f"Error parsing arguments for tool {function_name}: {arg_err} - Args: {tool_call.function.arguments}"
            )
            rendered_parts.append(f"**Error**: Failed to parse arguments for {function_name}")

    # Render the response
    rendered_response = "\n- ".join(rendered_parts)
    return rendered_response, actions
|
||||
|
||||
def get_instruction(self, instruction: str, environment_type: EnvironmentType) -> str:
    """Get the grounding prompt for the agent based on the environment type.

    Args:
        instruction: User instruction appended under the "User Instruction" heading.
        environment_type: Environment the agent operates; selects the prompt prefix and action space.

    Returns:
        The prompt text with the instruction on the line after the heading and a trailing newline.

    Raises:
        ValueError: If the environment type is neither browser nor computer.
    """
    UITARS_COMPUTER_PREFIX_PROMPT = """
    You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
    """
    UITARS_BROWSER_PREFIX_PROMPT = """
    You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to decide the next action to complete the task.
    You control a single tab in a Chromium browser. You cannot access the OS, filesystem or the application window.
    Always use the `goto` function to navigate to a specific URL. Ctrl+t, Ctrl+w, Ctrl+q, Ctrl+Shift+T, Ctrl+Shift+W are not allowed.
    """

    # The templates stop at the "User Instruction" heading; the instruction is appended after dedenting.
    UITARS_USR_COMPUTER_PROMPT_THOUGHT = """
    Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.

    ## Output Format
    ```
    Thought: ...
    Action: ...
    ```

    ## Action Space
    click(start_box='<|box_start|>(x1,y1)<|box_end|>')
    left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
    right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
    drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
    hotkey(key='')
    type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
    scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
    wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.

    ## Note
    - Use English in `Thought` part.
    - Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.

    ## User Instruction
    """
    UITARS_USR_BROWSER_PROMPT_THOUGHT = """
    Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.

    ## Output Format
    ```
    Thought: ...
    Action: ...
    ```

    ## Action Space
    click(start_box='<|box_start|>(x1,y1)<|box_end|>')
    left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
    right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
    drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
    hotkey(key='')
    type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
    scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
    wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
    goto(url='xxx') # Always use this to navigate to a specific URL. Use escape characters \\', \\", and \\n in url part to ensure we can parse the url in normal python string format.
    back() # Use this to go back to the previous page.

    ## Note
    - Use English in `Thought` part.
    - Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.

    ## User Instruction
    """

    if environment_type == EnvironmentType.BROWSER:
        template = UITARS_BROWSER_PREFIX_PROMPT + UITARS_USR_BROWSER_PROMPT_THOUGHT
    elif environment_type == EnvironmentType.COMPUTER:
        template = UITARS_COMPUTER_PREFIX_PROMPT + UITARS_USR_COMPUTER_PROMPT_THOUGHT
    else:
        raise ValueError(f"Expected environment type: Computer or Browser. Got {environment_type}.")

    # Dedent the static template before adding the instruction. dedent() strips only the
    # margin common to every line, so a multi-line instruction with unindented lines would
    # otherwise stop the prompt from being dedented at all.
    return f"{dedent(template).lstrip()}{instruction}\n"
|
||||
|
||||
def get_tools(self, environment_type: EnvironmentType) -> list[dict]:
|
||||
"""Get tools for the grounding LLM, in OpenAI API tool format"""
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@@ -163,182 +367,32 @@ class GroundingAgent:
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "goto",
|
||||
"description": "Navigate to a specific URL.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"url": {"type": "string", "description": "Fully qualified URL"}},
|
||||
"required": ["url"],
|
||||
]
|
||||
if environment_type == EnvironmentType.BROWSER:
|
||||
tools += [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "goto",
|
||||
"description": "Navigate to a specific URL.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"url": {"type": "string", "description": "Fully qualified URL"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "back",
|
||||
"description": "navigate back to the previous page.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "back",
|
||||
"description": "navigate back to the previous page.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
|
||||
"""Call the grounding LLM to get the next action based on the current state and instruction."""
|
||||
# Format the message for the API call
|
||||
messages_for_api = self._format_message_for_api(instruction, current_state)
|
||||
try:
|
||||
grounding_response: ChatCompletion = await self.client.chat.completions.create(
|
||||
messages=messages_for_api,
|
||||
model=self.model.name,
|
||||
tools=self.action_tools,
|
||||
tool_choice="required",
|
||||
temperature=0.0, # Grounding should be precise
|
||||
max_completion_tokens=1000, # Allow for thoughts + actions
|
||||
)
|
||||
if not isinstance(grounding_response, ChatCompletion):
|
||||
raise ValueError("Grounding LLM response is not of type ChatCompletion.")
|
||||
logger.debug(f"Grounding LLM response: {grounding_response.model_dump_json()}")
|
||||
|
||||
# Parse tool calls
|
||||
grounding_message = grounding_response.choices[0].message
|
||||
rendered_response, actions = self._parse_action(grounding_message, instruction, current_state)
|
||||
|
||||
# Update usage by grounding model
|
||||
self.tracer["usage"] = get_chat_usage_metrics(
|
||||
self.model.name,
|
||||
input_tokens=grounding_response.usage.prompt_tokens,
|
||||
output_tokens=grounding_response.usage.completion_tokens,
|
||||
usage=self.tracer.get("usage"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Grounding LLM: {e}")
|
||||
rendered_response = f"**Error**: Error contacting Grounding LLM: {e}"
|
||||
actions = []
|
||||
|
||||
return rendered_response, actions
|
||||
|
||||
def _format_message_for_api(self, instruction: str, current_state: EnvState) -> List:
|
||||
"""Format the message for the API call."""
|
||||
grounding_user_prompt = f"""
|
||||
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to decide the next action to complete the task.
|
||||
You control a single tab in a Chromium browser. You cannot access the OS, filesystem or the application window.
|
||||
Always use the `goto` function to navigate to a specific URL. Ctrl+t, Ctrl+w, Ctrl+q, Ctrl+Shift+T, Ctrl+Shift+W are not allowed.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes.
|
||||
goto(url='xxx') # Always use this to navigate to a specific URL. Use escape characters \\', \\", and \\n in url part to ensure we can parse the url in normal python string format.
|
||||
back() # Use this to go back to the previous page.
|
||||
|
||||
## Note
|
||||
- Use English in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
""".lstrip()
|
||||
|
||||
# Construct grounding LLM input (using only the latest user prompt + image)
|
||||
# We don't pass the full history here, as grounding depends on the *current* state + NL action
|
||||
screenshots = [f"data:image/webp;base64,{current_state.screenshot}"]
|
||||
grounding_messages_content = construct_structured_message(
|
||||
grounding_user_prompt, screenshots, self.model.name, vision_enabled=True
|
||||
)
|
||||
return [{"role": "user", "content": grounding_messages_content}]
|
||||
|
||||
def _parse_action(
|
||||
self, grounding_message: ChatCompletionMessage, instruction: str, current_state: EnvState
|
||||
) -> tuple[str, list[OperatorAction]]:
|
||||
"""Parse the tool calls from the grounding LLM response and convert them to action objects."""
|
||||
actions: List[OperatorAction] = []
|
||||
action_results: List[dict] = []
|
||||
|
||||
if grounding_message.tool_calls:
|
||||
rendered_parts = []
|
||||
for tool_call in grounding_message.tool_calls:
|
||||
function_name = tool_call.function.name
|
||||
try:
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
action_to_run: Optional[OperatorAction] = None
|
||||
action_render_str = f"**Action ({function_name})**: {tool_call.function.arguments}"
|
||||
|
||||
if function_name == "click":
|
||||
action_to_run = ClickAction(**arguments)
|
||||
elif function_name == "left_double":
|
||||
action_to_run = DoubleClickAction(**arguments)
|
||||
elif function_name == "right_single":
|
||||
action_to_run = ClickAction(button="right", **arguments)
|
||||
elif function_name == "type":
|
||||
content = arguments.get("content")
|
||||
action_to_run = TypeAction(text=content)
|
||||
elif function_name == "scroll":
|
||||
direction = arguments.get("direction", "down")
|
||||
amount = 3
|
||||
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, **arguments)
|
||||
elif function_name == "hotkey":
|
||||
action_to_run = KeypressAction(**arguments)
|
||||
elif function_name == "goto":
|
||||
action_to_run = GotoAction(**arguments)
|
||||
elif function_name == "back":
|
||||
action_to_run = BackAction(**arguments)
|
||||
elif function_name == "wait":
|
||||
action_to_run = WaitAction(**arguments)
|
||||
elif function_name == "screenshot":
|
||||
action_to_run = ScreenshotAction(**arguments)
|
||||
elif function_name == "drag":
|
||||
# Need to convert list of dicts to list of Point objects
|
||||
path_dicts = arguments.get("path", [])
|
||||
path_points = [Point(**p) for p in path_dicts]
|
||||
if path_points:
|
||||
action_to_run = DragAction(path=path_points)
|
||||
else:
|
||||
logger.warning(f"Drag action called with empty path: {arguments}")
|
||||
action_render_str += " [Skipped - empty path]"
|
||||
elif function_name == "finished":
|
||||
action_to_run = None
|
||||
else:
|
||||
logger.warning(f"Grounding LLM called unhandled tool: {function_name}")
|
||||
action_render_str += " [Unhandled]"
|
||||
|
||||
if action_to_run:
|
||||
actions.append(action_to_run)
|
||||
action_results.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": None, # Updated after environment step
|
||||
}
|
||||
)
|
||||
rendered_parts.append(action_render_str)
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as arg_err:
|
||||
logger.error(
|
||||
f"Error parsing arguments for tool {function_name}: {arg_err} - Args: {tool_call.function.arguments}"
|
||||
)
|
||||
rendered_parts.append(f"**Error**: Failed to parse arguments for {function_name}")
|
||||
rendered_response = "\n- ".join(rendered_parts)
|
||||
else:
|
||||
# Grounding LLM responded but didn't call a tool
|
||||
logger.warning("Grounding LLM did not produce a tool call.")
|
||||
rendered_response = f"{grounding_message.content or 'No action required.'}"
|
||||
|
||||
# Render the response
|
||||
return rendered_response, actions
|
||||
return tools
|
||||
|
||||
def reset(self):
|
||||
"""Reset the agent state."""
|
||||
|
||||
@@ -10,6 +10,7 @@ import logging
|
||||
import math
|
||||
import re
|
||||
from io import BytesIO
|
||||
from textwrap import dedent
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
@@ -18,7 +19,7 @@ from openai.types.chat import ChatCompletion
|
||||
from PIL import Image
|
||||
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_environment_base import EnvState
|
||||
from khoj.processor.operator.operator_environment_base import EnvironmentType, EnvState
|
||||
from khoj.utils.helpers import get_chat_usage_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,29 +36,8 @@ class GroundingAgentUitars:
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
UITARS_USR_PROMPT_THOUGHT = """
|
||||
You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to perform the next action to complete the task.
|
||||
You control a single tab in a Chromium browser. You cannot access the OS, filesystem, the application window or the addressbar.
|
||||
Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
{action_space}
|
||||
|
||||
## Note
|
||||
- Use {language} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
UITARS_NORMAL_ACTION_SPACE = """
|
||||
UITARS_NORMAL_ACTION_SPACE = dedent(
|
||||
"""
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
@@ -67,14 +47,15 @@ class GroundingAgentUitars:
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
""".lstrip()
|
||||
"""
|
||||
).lstrip()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
environment_type: EnvironmentType,
|
||||
client: AsyncOpenAI | AsyncAzureOpenAI,
|
||||
max_iterations=50,
|
||||
environment_type: Literal["computer", "web"] = "computer",
|
||||
runtime_conf: dict = {
|
||||
"infer_mode": "qwen25vl_normal",
|
||||
"prompt_style": "qwen25vl_normal",
|
||||
@@ -94,7 +75,7 @@ class GroundingAgentUitars:
|
||||
self.model_name = model_name
|
||||
self.client = client
|
||||
self.tracer = tracer
|
||||
self.environment_type = environment_type
|
||||
self.environment = environment_type
|
||||
|
||||
self.max_iterations = max_iterations
|
||||
self.runtime_conf = runtime_conf
|
||||
@@ -116,7 +97,7 @@ class GroundingAgentUitars:
|
||||
self.history_images: list[bytes] = []
|
||||
self.history_responses: list[str] = []
|
||||
|
||||
self.prompt_template = self.UITARS_USR_PROMPT_THOUGHT
|
||||
self.prompt_template = self.get_instruction(self.environment)
|
||||
self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE
|
||||
|
||||
if "history_n" in self.runtime_conf:
|
||||
@@ -126,11 +107,11 @@ class GroundingAgentUitars:
|
||||
|
||||
self.cur_callusr_count = 0
|
||||
|
||||
async def act(self, instruction: str, env_state: EnvState) -> tuple[str, list[OperatorAction]]:
|
||||
async def act(self, instruction: str, current_state: EnvState) -> tuple[str, list[OperatorAction]]:
|
||||
"""
|
||||
Suggest the next action(s) based on the instruction and current environment.
|
||||
"""
|
||||
messages = self._format_messages_for_api(instruction, env_state)
|
||||
messages = self._format_messages_for_api(instruction, current_state)
|
||||
|
||||
recent_screenshot = Image.open(BytesIO(self.history_images[-1]))
|
||||
origin_resized_height = recent_screenshot.height
|
||||
@@ -145,9 +126,11 @@ class GroundingAgentUitars:
|
||||
try_times = 3
|
||||
while not parsed_responses:
|
||||
if try_times <= 0:
|
||||
print(f"Reach max retry times to fetch response from client, as error flag.")
|
||||
logger.warning(f"Reach max retry times to fetch response from client, as error flag.")
|
||||
return "client error\nFAIL", []
|
||||
try:
|
||||
message_content = "\n".join([msg["content"][0].get("text") or "[image]" for msg in messages])
|
||||
logger.debug(f"User message content: {message_content}")
|
||||
response: ChatCompletion = await self.client.chat.completions.create(
|
||||
model="ui-tars",
|
||||
messages=messages,
|
||||
@@ -228,20 +211,9 @@ class GroundingAgentUitars:
|
||||
self.actions.append(actions)
|
||||
return f"{prediction}\nFAIL", []
|
||||
|
||||
if self.environment_type == "web":
|
||||
actions.extend(
|
||||
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
|
||||
)
|
||||
else:
|
||||
pass
|
||||
# TODO: Add PyautoguiAction when enable computer environment
|
||||
# actions.append(
|
||||
# PyautoguiAction(code=
|
||||
# self.parsing_response_to_pyautogui_code(
|
||||
# parsed_response, obs_image_height, obs_image_width, self.input_swap
|
||||
# )
|
||||
# )
|
||||
# )
|
||||
actions.extend(
|
||||
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
|
||||
)
|
||||
|
||||
self.actions.append(actions)
|
||||
|
||||
@@ -252,13 +224,52 @@ class GroundingAgentUitars:
|
||||
|
||||
return prediction or "", actions
|
||||
|
||||
def _format_messages_for_api(self, instruction: str, env_state: EnvState):
|
||||
def get_instruction(self, environment_type: EnvironmentType) -> str:
    """
    Build the prompt template for the given environment type.

    The template is an environment specific preamble followed by a body shared
    by both environments. The `{action_space}`, `{language}` and `{instruction}`
    placeholders are returned unformatted.

    Raises ValueError when the environment type is neither BROWSER nor COMPUTER.
    """
    computer_preamble = """
    You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
    """
    browser_preamble = """
    You are a GUI agent. You are given a task and a screenshot of the web browser tab you operate. You need to perform the next action to complete the task.
    You control a single tab in a Chromium browser. You cannot access the OS, filesystem, the application window or the addressbar.
    """

    shared_body = """
    Try fulfill the user instruction to the best of your ability, especially when the instruction is given multiple times. Do not ignore the instruction.

    ## Output Format
    ```
    Thought: ...
    Action: ...
    ```

    ## Action Space
    {action_space}

    ## Note
    - Use {language} in `Thought` part.
    - Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.

    ## User Instruction
    {instruction}
    """

    # Choose the preamble first so the template is assembled in one place below.
    if environment_type == EnvironmentType.BROWSER:
        preamble = browser_preamble
    elif environment_type == EnvironmentType.COMPUTER:
        preamble = computer_preamble
    else:
        raise ValueError(f"Unsupported environment type: {environment_type}")

    # Dedent after joining. Both pieces carry the same source indentation.
    template = dedent(preamble + shared_body)
    return template.lstrip()
|
||||
|
||||
def _format_messages_for_api(self, instruction: str, current_state: EnvState):
|
||||
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
|
||||
self.thoughts
|
||||
), "The number of observations and actions should be the same."
|
||||
|
||||
self.history_images.append(base64.b64decode(env_state.screenshot))
|
||||
self.observations.append({"screenshot": env_state.screenshot, "accessibility_tree": None})
|
||||
self.history_images.append(base64.b64decode(current_state.screenshot))
|
||||
self.observations.append({"screenshot": current_state.screenshot, "accessibility_tree": None})
|
||||
|
||||
user_prompt = self.prompt_template.format(
|
||||
instruction=instruction, action_space=self.prompt_action_space, language=self.language
|
||||
|
||||
@@ -3,7 +3,8 @@ import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, cast
|
||||
from textwrap import dedent
|
||||
from typing import List, Literal, Optional, cast
|
||||
|
||||
from anthropic.types.beta import BetaContentBlock
|
||||
|
||||
@@ -14,7 +15,11 @@ from khoj.processor.operator.operator_agent_base import (
|
||||
AgentMessage,
|
||||
OperatorAgent,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
EnvState,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.utils.helpers import get_anthropic_async_client, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -32,51 +37,12 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
action_results: List[dict] = []
|
||||
self._commit_trace() # Commit trace before next action
|
||||
|
||||
system_prompt = f"""<SYSTEM_CAPABILITY>
|
||||
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
|
||||
* You operate a Chromium browser using Playwright via the 'computer' tool.
|
||||
* You cannot access the OS or filesystem.
|
||||
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more.
|
||||
* You can use the additional back() and goto() helper functions to ease navigating the browser. If you see nothing, try goto duckduckgo.com
|
||||
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current URL is {current_state.url}.
|
||||
</SYSTEM_CAPABILITY>
|
||||
system_prompt = self.get_instructions(self.environment_type, current_state)
|
||||
tools = self.get_tools(self.environment_type, current_state)
|
||||
|
||||
<IMPORTANT>
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
* Do not loop on wait, screenshot for too many turns without taking any action.
|
||||
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
|
||||
</IMPORTANT>
|
||||
"""
|
||||
if is_none_or_empty(self.messages):
|
||||
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": self.model_default_tool("computer"),
|
||||
"name": "computer",
|
||||
"display_width_px": 1024,
|
||||
"display_height_px": 768,
|
||||
}, # TODO: Get from env
|
||||
{
|
||||
"name": "back",
|
||||
"description": "Go back to the previous page.",
|
||||
"input_schema": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "goto",
|
||||
"description": "Go to a specific URL.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"url": {"type": "string", "description": "Fully qualified URL to navigate to."}},
|
||||
"required": ["url"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
thinking: dict[str, str | int] = {"type": "disabled"}
|
||||
if is_reasoning_model(self.vision_model.name):
|
||||
thinking = {"type": "enabled", "budget_tokens": 1024}
|
||||
@@ -400,3 +366,80 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
return ["computer-use-2025-01-24"]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_instructions(self, environment_type: EnvironmentType, current_state: EnvState) -> str:
    """
    Return system instructions for the Anthropic operator.

    Args:
        environment_type: Environment the agent operates. Selects the prompt variant.
        current_state: Latest environment state. Only its `url` is read, and only by the browser prompt.

    Returns:
        The dedented system prompt with leading whitespace stripped.

    Raises:
        ValueError: If environment_type is neither BROWSER nor COMPUTER.
    """
    # Take the day of month from the date object instead of the `%-d` directive.
    # `%-d` is a platform specific strftime extension and is rejected on Windows,
    # which is one of the operating systems the computer environment can run on.
    today = datetime.today()
    current_date = f"{today.strftime('%A, %B')} {today.day}, {today.year}"

    if environment_type == EnvironmentType.BROWSER:
        return dedent(
            f"""
            <SYSTEM_CAPABILITY>
            * You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
            * You operate a Chromium browser using Playwright via the 'computer' tool.
            * You cannot access the OS or filesystem.
            * You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more.
            * You can use the additional back() and goto() helper functions to ease navigating the browser. If you see nothing, try goto duckduckgo.com
            * When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
            * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
            * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
            * The current date is {current_date}.
            * The current URL is {current_state.url}.
            </SYSTEM_CAPABILITY>

            <IMPORTANT>
            * You are allowed upto {self.max_iterations} iterations to complete the task.
            * Do not loop on wait, screenshot for too many turns without taking any action.
            * After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
            </IMPORTANT>
            """
        ).lstrip()
    elif environment_type == EnvironmentType.COMPUTER:
        return dedent(
            f"""
            <SYSTEM_CAPABILITY>
            * You are Khoj, a smart computer operating assistant. You help the users accomplish tasks using a computer.
            * You can interact with the computer to perform tasks like clicking, typing, scrolling, and more.
            * When viewing a document or webpage it can be helpful to zoom out or scroll down to ensure you see everything before deciding something isn't available.
            * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
            * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
            * Do not loop on wait, screenshot for too many turns without taking any action.
            * You are allowed upto {self.max_iterations} iterations to complete the task.
            </SYSTEM_CAPABILITY>

            <CONTEXT>
            * The current date is {current_date}.
            </CONTEXT>
            """
        ).lstrip()
    else:
        raise ValueError(f"Unsupported environment type for Anthropic operator: {environment_type}")
|
||||
|
||||
def get_tools(self, environment: EnvironmentType, current_state: EnvState) -> list[dict]:
    """
    Return the tools available for the Anthropic operator.

    Args:
        environment: Environment the agent operates.
        current_state: Latest environment state. Its width and height are reported as the display size of the computer tool.

    Returns:
        The computer tool, followed by the back and goto navigation tools when operating a browser.
    """
    tools = [
        {
            "type": self.model_default_tool("computer"),
            "name": "computer",
            "display_width_px": current_state.width,
            "display_height_px": current_state.height,
        }
    ]

    # Compare against the enum member. EnvironmentType is a plain Enum, so a
    # member never equals the string "browser" and the navigation tools were
    # never offered. The raw value is still accepted for callers passing a string.
    if environment in (EnvironmentType.BROWSER, EnvironmentType.BROWSER.value):
        tools += [
            {
                "name": "back",
                "description": "Go back to the previous page.",
                "input_schema": {"type": "object", "properties": {}},
            },
            {
                "name": "goto",
                "description": "Go to a specific URL.",
                "input_schema": {
                    "type": "object",
                    "properties": {"url": {"type": "string", "description": "Fully qualified URL to navigate to."}},
                    "required": ["url"],
                },
            },
        ]

    return tools
|
||||
|
||||
@@ -7,7 +7,11 @@ from pydantic import BaseModel
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import commit_conversation_trace
|
||||
from khoj.processor.operator.operator_actions import OperatorAction
|
||||
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
EnvState,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,9 +29,12 @@ class AgentMessage(BaseModel):
|
||||
|
||||
|
||||
class OperatorAgent(ABC):
|
||||
def __init__(self, query: str, vision_model: ChatModel, max_iterations: int, tracer: dict):
|
||||
def __init__(
|
||||
self, query: str, vision_model: ChatModel, environment_type: EnvironmentType, max_iterations: int, tracer: dict
|
||||
):
|
||||
self.query = query
|
||||
self.vision_model = vision_model
|
||||
self.environment_type = environment_type
|
||||
self.max_iterations = max_iterations
|
||||
self.tracer = tracer
|
||||
self.messages: List[AgentMessage] = []
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat import ChatCompletion
|
||||
from textwrap import dedent
|
||||
from typing import List
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import construct_structured_message
|
||||
@@ -15,7 +14,11 @@ from khoj.processor.operator.operator_agent_base import (
|
||||
AgentMessage,
|
||||
OperatorAgent,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
EnvState,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||
from khoj.utils.helpers import get_openai_async_client, is_none_or_empty
|
||||
|
||||
@@ -27,7 +30,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
"""
|
||||
An OperatorAgent that uses two LLMs:
|
||||
1. Reasoning LLM: Determines the next high-level action based on the objective and current visual reasoning trajectory.
|
||||
2. Grounding LLM: Converts the high-level action into specific, executable browser actions.
|
||||
2. Grounding LLM: Converts the high-level action into specific actions executable on the environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -35,10 +38,13 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
query: str,
|
||||
reasoning_model: ChatModel,
|
||||
grounding_model: ChatModel,
|
||||
environment_type: EnvironmentType,
|
||||
max_iterations: int,
|
||||
tracer: dict,
|
||||
):
|
||||
super().__init__(query, reasoning_model, max_iterations, tracer) # Use reasoning model for primary tracking
|
||||
super().__init__(
|
||||
query, reasoning_model, environment_type, max_iterations, tracer
|
||||
) # Use reasoning model for primary tracking
|
||||
self.reasoning_model = reasoning_model
|
||||
self.grounding_model = grounding_model
|
||||
# Initialize openai api compatible client for grounding model
|
||||
@@ -49,10 +55,12 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
self.grounding_agent: GroundingAgent | GroundingAgentUitars = None
|
||||
if "ui-tars-1.5" in grounding_model.name:
|
||||
self.grounding_agent = GroundingAgentUitars(
|
||||
grounding_model.name, grounding_client, max_iterations, environment_type="web", tracer=tracer
|
||||
grounding_model.name, self.environment_type, grounding_client, max_iterations, tracer=tracer
|
||||
)
|
||||
else:
|
||||
self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, max_iterations, tracer=tracer)
|
||||
self.grounding_agent = GroundingAgent(
|
||||
grounding_model.name, self.environment_type, grounding_client, max_iterations, tracer=tracer
|
||||
)
|
||||
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
"""
|
||||
@@ -84,48 +92,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
"""
|
||||
Uses the reasoning LLM to determine the next high-level action based on the operation trajectory.
|
||||
"""
|
||||
reasoning_system_prompt = f"""
|
||||
# Introduction
|
||||
* You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
|
||||
* You are given the user's query and screenshots of the browser's state transitions.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current URL is {current_state.url}.
|
||||
|
||||
# Your Task
|
||||
* First look at the screenshots carefully to notice all pertinent information.
|
||||
* Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
|
||||
* Make sure you scroll down to see everything before deciding something isn't available.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* Use your creativity to find alternate ways to make progress if you get stuck at any point.
|
||||
|
||||
# Tool AI Capabilities
|
||||
* The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
|
||||
* It can interact with the web browser with these actions: click, right click, double click, type, scroll, drag, wait, goto url and go back to previous page.
|
||||
* It cannot access the OS, filesystem or application window. It just controls a single Chromium browser tab via Playwright.
|
||||
|
||||
# IMPORTANT
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
* To navigate to a specific URL, put "GOTO <URL>" (without quotes) on the last line of your response.
|
||||
* To navigate back to the previous page, end your response with "BACK" (without quotes).
|
||||
* Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).
|
||||
|
||||
# Examples
|
||||
## Example 1
|
||||
GOTO https://example.com
|
||||
## Example 2
|
||||
click the blue login button located at the top right corner
|
||||
## Example 3
|
||||
scroll down the page
|
||||
## Example 4
|
||||
type the username example@email.com into the input field labeled Username
|
||||
## Example 5
|
||||
DONE
|
||||
|
||||
# Instructions
|
||||
Now describe a single high-level action to take next to progress towards the user's goal in detail.
|
||||
Focus on the visual action and provide all necessary context.
|
||||
""".strip()
|
||||
|
||||
reasoning_system_prompt = self.get_instruction(self.environment_type, current_state)
|
||||
if is_none_or_empty(self.messages):
|
||||
query_text = f"**Main Objective**: {self.query}"
|
||||
query_screenshot = [f"data:image/webp;base64,{current_state.screenshot}"]
|
||||
@@ -330,6 +297,96 @@ Focus on the visual action and provide all necessary context.
|
||||
]
|
||||
return formatted_messages
|
||||
|
||||
def get_instruction(self, environment_type: EnvironmentType, env_state: EnvState) -> str:
    """
    Get the system instruction for the reasoning agent.

    Args:
        environment_type: Environment the agent operates. Selects the prompt variant.
        env_state: Latest environment state. Only its `url` is read, and only by the browser prompt.

    Returns:
        The dedented system prompt with surrounding whitespace stripped.

    Raises:
        ValueError: If environment_type is neither BROWSER nor COMPUTER.
    """
    # Take the day of month from the date object instead of the `%-d` directive.
    # `%-d` is a platform specific strftime extension and is rejected on Windows,
    # which is one of the operating systems the computer environment can run on.
    today = datetime.today()
    current_date = f"{today.strftime('%A, %B')} {today.day}, {today.year}"

    if environment_type == EnvironmentType.BROWSER:
        return dedent(
            f"""
            # Introduction
            * You are Khoj, a smart and resourceful web browsing assistant. You help the user accomplish their task using a web browser.
            * You are given the user's query and screenshots of the browser's state transitions.
            * The current date is {current_date}.
            * The current URL is {env_state.url}.

            # Your Task
            * First look at the screenshots carefully to notice all pertinent information.
            * Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
            * Make sure you scroll down to see everything before deciding something isn't available.
            * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
            * Use your creativity to find alternate ways to make progress if you get stuck at any point.

            # Tool AI Capabilities
            * The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
            * It can interact with the web browser with these actions: click, right click, double click, type, scroll, drag, wait, goto url and go back to previous page.
            * It cannot access the OS, filesystem or application window. It just controls a single Chromium browser tab via Playwright.

            # IMPORTANT
            * You are allowed upto {self.max_iterations} iterations to complete the task.
            * To navigate to a specific URL, put "GOTO <URL>" (without quotes) on the last line of your response.
            * To navigate back to the previous page, end your response with "BACK" (without quotes).
            * Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).

            # Examples
            ## Example 1
            GOTO https://example.com
            ## Example 2
            click the blue login button located at the top right corner
            ## Example 3
            scroll down the page
            ## Example 4
            type the username example@email.com into the input field labeled Username
            ## Example 5
            DONE

            # Instructions
            Now describe a single high-level action to take next to progress towards the user's goal in detail.
            Focus on the visual action and provide all necessary context.
            """
        ).strip()

    elif environment_type == EnvironmentType.COMPUTER:
        # The action list previously ended with "drag, wait to previous page.",
        # a leftover of the browser prompt's goto/back actions that do not
        # apply to the computer environment.
        return dedent(
            f"""
            # Introduction
            * You are Khoj, a smart and resourceful computer assistant. You help the user accomplish their task using a computer.
            * You are given the user's query and screenshots of the computer's state transitions.
            * The current date is {current_date}.

            # Your Task
            * First look at the screenshots carefully to notice all pertinent information.
            * Then instruct a tool AI to perform the next action that will help you progress towards the user's goal.
            * Make sure you scroll down to see everything before deciding something isn't available.
            * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
            * Use your creativity to find alternate ways to make progress if you get stuck at any point.

            # Tool AI Capabilities
            * The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform the next action on the page.
            * It can interact with the computer with these actions: click, right click, double click, type, scroll, drag and wait.

            # IMPORTANT
            * You are allowed upto {self.max_iterations} iterations to complete the task.
            * Once you've verified that the main objective has been achieved, end your response with "DONE" (without quotes).

            # Examples
            ## Example 1
            type https://example.com into the address bar and press Enter
            ## Example 2
            click the blue login button located at the top right corner
            ## Example 3
            scroll down the page
            ## Example 4
            type the username example@email.com into the input field labeled Username
            ## Example 5
            DONE

            # Instructions
            Now describe a single high-level action to take next to progress towards the user's goal in detail.
            Focus on the visual action and provide all necessary context.
            """
        ).strip()
    else:
        raise ValueError(f"Expected environment type: Computer or Browser. Got {environment_type}.")
|
||||
|
||||
def reset(self):
|
||||
"""Reset the agent state."""
|
||||
super().reset()
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
import platform
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from textwrap import dedent
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from openai.types.responses import Response, ResponseOutputItem
|
||||
@@ -12,7 +14,11 @@ from khoj.processor.operator.operator_agent_base import (
|
||||
AgentMessage,
|
||||
OperatorAgent,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
|
||||
from khoj.processor.operator.operator_environment_base import (
|
||||
EnvironmentType,
|
||||
EnvState,
|
||||
EnvStepResult,
|
||||
)
|
||||
from khoj.utils.helpers import get_openai_async_client, is_none_or_empty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -29,55 +35,8 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
actions: List[OperatorAction] = []
|
||||
action_results: List[dict] = []
|
||||
self._commit_trace() # Commit trace before next action
|
||||
system_prompt = f"""<SYSTEM_CAPABILITY>
|
||||
* You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
|
||||
* You operate a single Chromium browser page using Playwright.
|
||||
* You cannot access the OS or filesystem.
|
||||
* You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
|
||||
* You can use the additional back() and goto() functions to navigate the browser.
|
||||
* Always use the goto() function to navigate to a specific URL. If you see nothing, try goto duckduckgo.com
|
||||
* When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
|
||||
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
|
||||
* The current URL is {current_state.url}.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<IMPORTANT>
|
||||
* You are allowed upto {self.max_iterations} iterations to complete the task.
|
||||
* After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
|
||||
</IMPORTANT>
|
||||
"""
|
||||
tools = [
|
||||
{
|
||||
"type": "computer_use_preview",
|
||||
"display_width": 1024, # TODO: Get from env
|
||||
"display_height": 768, # TODO: Get from env
|
||||
"environment": "browser",
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "back",
|
||||
"description": "Go back to the previous page.",
|
||||
"parameters": {},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "goto",
|
||||
"description": "Go to a specific URL.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "Fully qualified URL to navigate to.",
|
||||
},
|
||||
},
|
||||
"additionalProperties": False,
|
||||
"required": ["url"],
|
||||
},
|
||||
},
|
||||
]
|
||||
system_prompt = self.get_instructions(self.environment_type, current_state)
|
||||
tools = self.get_tools(self.environment_type, current_state)
|
||||
|
||||
if is_none_or_empty(self.messages):
|
||||
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||
@@ -347,3 +306,93 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
}
|
||||
|
||||
return render_payload
|
||||
|
||||
def get_instructions(self, environment_type: EnvironmentType, current_state: EnvState) -> str:
    """
    Return system instructions for the OpenAI operator.

    Args:
        environment_type: Environment the agent operates. Selects the prompt variant.
        current_state: Latest environment state. Only its `url` is read, and only by the browser prompt.

    Returns:
        The dedented system prompt with leading whitespace stripped.

    Raises:
        ValueError: If environment_type is neither BROWSER nor COMPUTER.
    """
    # Take the day of month from the date object instead of the `%-d` directive.
    # `%-d` is a platform specific strftime extension and is rejected on Windows,
    # which is one of the operating systems the computer environment can run on.
    today = datetime.today()
    current_date = f"{today.strftime('%A, %B')} {today.day}, {today.year}"

    if environment_type == EnvironmentType.BROWSER:
        return dedent(
            f"""
            <SYSTEM_CAPABILITY>
            * You are Khoj, a smart web browser operating assistant. You help the users accomplish tasks using a web browser.
            * You operate a single Chromium browser page using Playwright.
            * You cannot access the OS or filesystem.
            * You can interact with the web browser to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
            * You can use the additional back() and goto() functions to navigate the browser.
            * Always use the goto() function to navigate to a specific URL. If you see nothing, try goto duckduckgo.com
            * When viewing a webpage it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
            * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
            * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
            * The current date is {current_date}.
            * The current URL is {current_state.url}.
            </SYSTEM_CAPABILITY>

            <IMPORTANT>
            * You are allowed upto {self.max_iterations} iterations to complete the task.
            * After initialization if the browser is blank, enter a website URL using the goto() function instead of waiting
            </IMPORTANT>
            """
        ).lstrip()
    elif environment_type == EnvironmentType.COMPUTER:
        return dedent(
            f"""
            <SYSTEM_CAPABILITY>
            * You are Khoj, a smart computer operating assistant. You help the users accomplish their tasks using a computer.
            * You can interact with the computer to perform tasks like clicking, typing, scrolling, and more using the computer_use_preview tool.
            * When viewing a document or webpage it can be helpful to zoom out or scroll down to ensure you see everything before deciding something isn't available.
            * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
            * Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail.
            * You are allowed upto {self.max_iterations} iterations to complete the task.
            </SYSTEM_CAPABILITY>

            <CONTEXT>
            * The current date is {current_date}.
            </CONTEXT>
            """
        ).lstrip()
    else:
        raise ValueError(f"Unsupported environment type: {environment_type}")
|
||||
|
||||
def get_tools(self, environment_type: EnvironmentType, current_state: EnvState) -> list[dict]:
    """Build the tool specifications for the given environment.

    Args:
        environment_type: Kind of environment being operated.
        current_state: Current environment state; its ``width`` and
            ``height`` are used as the display dimensions of the tool.

    Returns:
        A list with the ``computer_use_preview`` tool spec. For the browser
        environment the ``back`` and ``goto`` function tools are appended.
    """
    if environment_type == EnvironmentType.COMPUTER:
        # Map the host platform name to the environment label used in the tool spec.
        # Anything that is neither macOS nor Windows is labelled as linux.
        host_system = platform.system()
        if host_system == "Darwin":
            environment_os = "mac"
        elif host_system == "Windows":
            environment_os = "windows"
        else:
            environment_os = "linux"
    else:
        environment_os = "browser"

    computer_tool = {
        "type": "computer_use_preview",
        "display_width": current_state.width,
        "display_height": current_state.height,
        "environment": environment_os,
    }

    # Only the browser environment gets the extra navigation functions.
    if environment_type != EnvironmentType.BROWSER:
        return [computer_tool]

    back_tool = {
        "type": "function",
        "name": "back",
        "description": "Go back to the previous page.",
        "parameters": {},
    }
    goto_tool = {
        "type": "function",
        "name": "goto",
        "description": "Go to a specific URL.",
        "parameters": {
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "Fully qualified URL to navigate to.",
                },
            },
            "additionalProperties": False,
            "required": ["url"],
        },
    }
    return [computer_tool, back_tool, goto_tool]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -6,9 +7,18 @@ from pydantic import BaseModel
|
||||
from khoj.processor.operator.operator_actions import OperatorAction
|
||||
|
||||
|
||||
class EnvironmentType(Enum):
    """Kinds of environment an operator agent can be asked to operate."""

    # Member order and string values are kept exactly as declared originally.
    COMPUTER = "computer"
    BROWSER = "browser"
|
||||
|
||||
|
||||
class EnvState(BaseModel):
    """Snapshot of the state of an operated environment."""

    # Dimensions of the environment display; required.
    # NOTE(review): presumably pixels - confirm against the environment implementations.
    height: int
    width: int
    # Screenshot of the environment, if one was captured.
    # NOTE(review): encoding of the string is not visible here - confirm.
    screenshot: Optional[str] = None
    # Current URL, if any. Declared once, as optional; the earlier duplicate
    # `url: str` declaration was stale and overridden by this one.
    url: Optional[str] = None
|
||||
|
||||
|
||||
class EnvStepResult(BaseModel):
|
||||
|
||||
@@ -124,10 +124,10 @@ class BrowserEnvironment(Environment):
|
||||
|
||||
async def get_state(self) -> EnvState:
    """Return the current state of the browser environment.

    Returns:
        An ``EnvState`` carrying ``self.height`` and ``self.width``. When
        there is no page, or the page is closed, the URL is ``about:blank``
        and there is no screenshot. Otherwise the page URL and a fresh
        screenshot from ``self._get_screenshot()`` are included.
    """
    # Always pass height and width: the earlier returns that omitted them
    # shadowed these ones and could not build a state with dimensions.
    if not self.page or self.page.is_closed():
        return EnvState(url="about:blank", screenshot=None, height=self.height, width=self.width)
    url = self.page.url
    screenshot = await self._get_screenshot()
    return EnvState(url=url, screenshot=screenshot, height=self.height, width=self.width)
|
||||
|
||||
async def step(self, action: OperatorAction) -> EnvStepResult:
|
||||
if not self.page or self.page.is_closed():
|
||||
|
||||
@@ -31,7 +31,7 @@ from khoj.processor.conversation.utils import (
|
||||
save_to_conversation_log,
|
||||
)
|
||||
from khoj.processor.image.generate import text_to_image
|
||||
from khoj.processor.operator.operate_browser import operate_browser
|
||||
from khoj.processor.operator import operate_environment
|
||||
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
||||
from khoj.processor.tools.online_search import (
|
||||
deduplicate_organic_results,
|
||||
@@ -1292,7 +1292,7 @@ async def chat(
|
||||
)
|
||||
if ConversationCommand.Operator in conversation_commands:
|
||||
try:
|
||||
async for result in operate_browser(
|
||||
async for result in operate_environment(
|
||||
defiltered_query,
|
||||
user,
|
||||
meta_log,
|
||||
|
||||
@@ -18,7 +18,7 @@ from khoj.processor.conversation.utils import (
|
||||
construct_tool_chat_history,
|
||||
load_complex_json,
|
||||
)
|
||||
from khoj.processor.operator.operate_browser import operate_browser
|
||||
from khoj.processor.operator import operate_environment
|
||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.api import extract_references_and_questions
|
||||
@@ -417,12 +417,12 @@ async def execute_information_collection(
|
||||
|
||||
elif this_iteration.tool == ConversationCommand.Operator:
|
||||
try:
|
||||
async for result in operate_browser(
|
||||
async for result in operate_environment(
|
||||
this_iteration.query,
|
||||
user,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Operator),
|
||||
location,
|
||||
send_status_func,
|
||||
send_status_func=send_status_func,
|
||||
query_images=query_images,
|
||||
agent=agent,
|
||||
query_files=query_files,
|
||||
|
||||
Reference in New Issue
Block a user