Resolve mypy typing errors in operator code

This commit is contained in:
Debanjum
2025-05-09 19:51:57 -06:00
parent 33689feb91
commit 95f211d03c
11 changed files with 92 additions and 82 deletions

View File

@@ -144,7 +144,7 @@ async def converse_anthropic(
user_query, user_query,
online_results: Optional[Dict[str, Dict]] = None, online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[str]] = None,
conversation_log={}, conversation_log={},
model: Optional[str] = "claude-3-7-sonnet-latest", model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None, api_key: Optional[str] = None,

View File

@@ -166,7 +166,7 @@ async def converse_gemini(
user_query, user_query,
online_results: Optional[Dict[str, Dict]] = None, online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[str]] = None,
conversation_log={}, conversation_log={},
model: Optional[str] = "gemini-2.0-flash", model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None, api_key: Optional[str] = None,

View File

@@ -169,7 +169,7 @@ async def converse_openai(
user_query, user_query,
online_results: Optional[Dict[str, Dict]] = None, online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, Dict]] = None, operator_results: Optional[List[str]] = None,
conversation_log={}, conversation_log={},
model: str = "gpt-4o-mini", model: str = "gpt-4o-mini",
api_key: Optional[str] = None, api_key: Optional[str] = None,

View File

@@ -210,7 +210,7 @@ class GroundingAgent:
self.tracer["usage"] = get_chat_usage_metrics( self.tracer["usage"] = get_chat_usage_metrics(
self.model.name, self.model.name,
input_tokens=grounding_response.usage.prompt_tokens, input_tokens=grounding_response.usage.prompt_tokens,
completion_tokens=grounding_response.usage.completion_tokens, output_tokens=grounding_response.usage.completion_tokens,
usage=self.tracer.get("usage"), usage=self.tracer.get("usage"),
) )
except Exception as e: except Exception as e:

View File

@@ -10,7 +10,7 @@ import logging
import math import math
import re import re
from io import BytesIO from io import BytesIO
from typing import List from typing import Any, List
import numpy as np import numpy as np
from openai import AzureOpenAI, OpenAI from openai import AzureOpenAI, OpenAI
@@ -112,11 +112,11 @@ class GroundingAgentUitars:
self.min_pixels = self.runtime_conf["min_pixels"] self.min_pixels = self.runtime_conf["min_pixels"]
self.callusr_tolerance = self.runtime_conf["callusr_tolerance"] self.callusr_tolerance = self.runtime_conf["callusr_tolerance"]
self.thoughts = [] self.thoughts: list[str] = []
self.actions = [] self.actions: list[list[OperatorAction]] = []
self.observations = [] self.observations: list[dict] = []
self.history_images = [] self.history_images: list[bytes] = []
self.history_responses = [] self.history_responses: list[str] = []
self.prompt_template = self.UITARS_USR_PROMPT_THOUGHT self.prompt_template = self.UITARS_USR_PROMPT_THOUGHT
self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE
@@ -159,7 +159,7 @@ class GroundingAgentUitars:
# top_k=top_k, # top_k=top_k,
top_p=self.top_p, top_p=self.top_p,
) )
prediction: str = response.choices[0].message.content.strip() prediction = response.choices[0].message.content.strip()
self.tracer["usage"] = get_chat_usage_metrics( self.tracer["usage"] = get_chat_usage_metrics(
self.model_name, self.model_name,
input_tokens=response.usage.prompt_tokens, input_tokens=response.usage.prompt_tokens,
@@ -235,11 +235,15 @@ class GroundingAgentUitars:
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap) self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
) )
else: else:
actions.append( pass
self.parsing_response_to_pyautogui_code( # TODO: Add PyautoguiAction when enable computer environment
parsed_response, obs_image_height, obs_image_width, self.input_swap # actions.append(
) # PyautoguiAction(code=
) # self.parsing_response_to_pyautogui_code(
# parsed_response, obs_image_height, obs_image_width, self.input_swap
# )
# )
# )
self.actions.append(actions) self.actions.append(actions)
@@ -268,7 +272,8 @@ class GroundingAgentUitars:
if len(self.history_images) > self.history_n: if len(self.history_images) > self.history_n:
self.history_images = self.history_images[-self.history_n :] self.history_images = self.history_images[-self.history_n :]
messages, images = [], [] messages: list[dict] = []
images: list[Any] = []
if isinstance(self.history_images, bytes): if isinstance(self.history_images, bytes):
self.history_images = [self.history_images] self.history_images = [self.history_images]
elif isinstance(self.history_images, np.ndarray): elif isinstance(self.history_images, np.ndarray):
@@ -414,11 +419,11 @@ class GroundingAgentUitars:
"""Returns the closest integer to 'number' that is divisible by 'factor'.""" """Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor return round(number / factor) * factor
def ceil_by_factor(self, number: int, factor: int) -> int: def ceil_by_factor(self, number: float, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor return math.ceil(number / factor) * factor
def floor_by_factor(self, number: int, factor: int) -> int: def floor_by_factor(self, number: float, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor return math.floor(number / factor) * factor

View File

@@ -2,7 +2,7 @@ import json
import logging import logging
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from typing import Any, List, Optional from typing import Any, List, Optional, cast
from anthropic.types.beta import BetaContentBlock from anthropic.types.beta import BetaContentBlock
@@ -76,7 +76,7 @@ class AnthropicOperatorAgent(OperatorAgent):
}, },
] ]
thinking = {"type": "disabled"} thinking: dict[str, str | int] = {"type": "disabled"}
if self.vision_model.name.startswith("claude-3-7"): if self.vision_model.name.startswith("claude-3-7"):
thinking = {"type": "enabled", "budget_tokens": 1024} thinking = {"type": "enabled", "budget_tokens": 1024}
@@ -94,7 +94,7 @@ class AnthropicOperatorAgent(OperatorAgent):
logger.debug(f"Anthropic response: {response.model_dump_json()}") logger.debug(f"Anthropic response: {response.model_dump_json()}")
self.messages.append(AgentMessage(role="assistant", content=response.content)) self.messages.append(AgentMessage(role="assistant", content=response.content))
rendered_response = await self._render_response(response.content, current_state.screenshot) rendered_response = self._render_response(response.content, current_state.screenshot)
for block in response.content: for block in response.content:
if block.type == "tool_use": if block.type == "tool_use":
@@ -140,7 +140,7 @@ class AnthropicOperatorAgent(OperatorAgent):
elif tool_name == "left_mouse_up": elif tool_name == "left_mouse_up":
action_to_run = MouseUpAction(button="left") action_to_run = MouseUpAction(button="left")
elif tool_name == "type": elif tool_name == "type":
text = tool_input.get("text") text: str = tool_input.get("text")
if text: if text:
action_to_run = TypeAction(text=text) action_to_run = TypeAction(text=text)
elif tool_name == "scroll": elif tool_name == "scroll":
@@ -152,7 +152,7 @@ class AnthropicOperatorAgent(OperatorAgent):
if direction: if direction:
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y) action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y)
elif tool_name == "key": elif tool_name == "key":
text: str = tool_input.get("text") text = tool_input.get("text")
if text: if text:
action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style
elif tool_name == "hold_key": elif tool_name == "hold_key":
@@ -214,7 +214,7 @@ class AnthropicOperatorAgent(OperatorAgent):
for idx, env_step in enumerate(env_steps): for idx, env_step in enumerate(env_steps):
action_result = agent_action.action_results[idx] action_result = agent_action.action_results[idx]
result_content = env_step.error or env_step.output or "[Action completed]" result_content = env_step.error or env_step.output or "[Action completed]"
if env_step.type == "image": if env_step.type == "image" and isinstance(result_content, dict):
# Add screenshot data in anthropic message format # Add screenshot data in anthropic message format
action_result["content"] = [ action_result["content"] = [
{ {
@@ -262,12 +262,20 @@ class AnthropicOperatorAgent(OperatorAgent):
) )
return formatted_messages return formatted_messages
def compile_response(self, response_content: list[BetaContentBlock | Any]) -> str: def compile_response(self, response_content: list[BetaContentBlock | dict] | str) -> str:
"""Compile Anthropic response into a single string.""" """Compile Anthropic response into a single string."""
if is_none_or_empty(response_content) or not all(hasattr(item, "type") for item in response_content): if isinstance(response_content, str):
return response_content return response_content
elif is_none_or_empty(response_content):
return ""
# action results are a list of dictionaries, # action results are a list of dictionaries,
# beta content blocks are objects with a type attribute
elif isinstance(response_content[0], dict):
return json.dumps(response_content)
compiled_response = [""] compiled_response = [""]
for block in deepcopy(response_content): for block in deepcopy(response_content):
block = cast(BetaContentBlock, block) # Ensure block is of type BetaContentBlock
if block.type == "text": if block.type == "text":
compiled_response.append(block.text) compiled_response.append(block.text)
elif block.type == "tool_use": elif block.type == "tool_use":
@@ -291,8 +299,7 @@ class AnthropicOperatorAgent(OperatorAgent):
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
@staticmethod def _render_response(self, response_content: list[BetaContentBlock], screenshot: str | None) -> dict:
async def _render_response(response_content: list[BetaContentBlock], screenshot: Optional[str] = None) -> dict:
"""Render Anthropic response, potentially including actual screenshots.""" """Render Anthropic response, potentially including actual screenshots."""
render_texts = [] render_texts = []
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
@@ -315,11 +322,11 @@ class AnthropicOperatorAgent(OperatorAgent):
elif "action" in block_input: elif "action" in block_input:
action = block_input["action"] action = block_input["action"]
if action == "type": if action == "type":
text = block_input.get("text") text: str = block_input.get("text")
if text: if text:
render_texts += [f'Type "{text}"'] render_texts += [f'Type "{text}"']
elif action == "key": elif action == "key":
text: str = block_input.get("text") text = block_input.get("text")
if text: if text:
render_texts += [f"Press {text}"] render_texts += [f"Press {text}"]
elif action == "hold_key": elif action == "hold_key":

View File

@@ -50,11 +50,11 @@ class OperatorAgent(ABC):
return self.compile_response(self.messages[-1].content) return self.compile_response(self.messages[-1].content)
@abstractmethod @abstractmethod
def compile_response(self, response: List) -> str: def compile_response(self, response: List | str) -> str:
pass pass
@abstractmethod @abstractmethod
def _render_response(self, response: List, screenshot: Optional[str]) -> Optional[str]: def _render_response(self, response: List, screenshot: Optional[str]) -> dict:
pass pass
@abstractmethod @abstractmethod

View File

@@ -49,12 +49,14 @@ class BinaryOperatorAgent(OperatorAgent):
grounding_client = get_openai_async_client( grounding_client = get_openai_async_client(
grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url
) )
self.grounding_agent: GroundingAgent | GroundingAgentUitars = None
if "ui-tars-1.5" in grounding_model.name: if "ui-tars-1.5" in grounding_model.name:
self.grounding_agent = GroundingAgentUitars( self.grounding_agent = GroundingAgentUitars(
grounding_model.name, grounding_client, environment_type="web", tracer=tracer grounding_model.name, grounding_client, max_iterations, environment_type="web", tracer=tracer
) )
else: else:
self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, tracer=tracer) self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, max_iterations, tracer=tracer)
async def act(self, current_state: EnvState) -> AgentActResult: async def act(self, current_state: EnvState) -> AgentActResult:
""" """
@@ -143,7 +145,7 @@ Focus on the visual action and provide all necessary context.
query_screenshot = self._get_message_images(current_message) query_screenshot = self._get_message_images(current_message)
# Construct input for visual reasoner history # Construct input for visual reasoner history
visual_reasoner_history = self._format_message_for_api(self.messages) visual_reasoner_history = {"chat": self._format_message_for_api(self.messages)}
try: try:
natural_language_action = await send_message_to_model_wrapper( natural_language_action = await send_message_to_model_wrapper(
query=query_text, query=query_text,
@@ -153,6 +155,10 @@ Focus on the visual action and provide all necessary context.
agent_chat_model=self.reasoning_model, agent_chat_model=self.reasoning_model,
tracer=self.tracer, tracer=self.tracer,
) )
if not isinstance(natural_language_action, str) or not natural_language_action.strip():
raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}")
self.messages.append(current_message) self.messages.append(current_message)
self.messages.append(AgentMessage(role="assistant", content=natural_language_action)) self.messages.append(AgentMessage(role="assistant", content=natural_language_action))
@@ -254,7 +260,7 @@ Focus on the visual action and provide all necessary context.
self.messages.append(AgentMessage(role="environment", content=action_results_content)) self.messages.append(AgentMessage(role="environment", content=action_results_content))
async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str: async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str:
conversation_history = self._format_message_for_api(self.messages) conversation_history = {"chat": self._format_message_for_api(self.messages)}
try: try:
summary = await send_message_to_model_wrapper( summary = await send_message_to_model_wrapper(
query=summarize_prompt, query=summarize_prompt,
@@ -276,25 +282,11 @@ Focus on the visual action and provide all necessary context.
return summary return summary
def compile_response(self, response_content: Union[str, List, dict]) -> str: def compile_response(self, response_content: str | List) -> str:
"""Compile response content into a string, handling OpenAI message structures.""" """Compile response content into a string, handling OpenAI message structures."""
if isinstance(response_content, str): if isinstance(response_content, str):
return response_content return response_content
if isinstance(response_content, dict) and response_content.get("role") == "assistant":
# Grounding LLM response message (might contain tool calls)
text_content = response_content.get("content")
tool_calls = response_content.get("tool_calls")
compiled = []
if text_content:
compiled.append(text_content)
if tool_calls:
for tc in tool_calls:
compiled.append(
f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}"
)
return "\n- ".join(filter(None, compiled))
if isinstance(response_content, list): # Tool results list if isinstance(response_content, list): # Tool results list
compiled = ["**Tool Results**:"] compiled = ["**Tool Results**:"]
for item in response_content: for item in response_content:
@@ -336,7 +328,7 @@ Focus on the visual action and provide all necessary context.
} }
for message in messages for message in messages
] ]
return {"chat": formatted_messages} return formatted_messages
def reset(self): def reset(self):
"""Reset the agent state.""" """Reset the agent state."""

View File

@@ -2,7 +2,7 @@ import json
import logging import logging
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional, cast
from openai.types.responses import Response, ResponseOutputItem from openai.types.responses import Response, ResponseOutputItem
@@ -95,7 +95,7 @@ class OpenAIOperatorAgent(OperatorAgent):
logger.debug(f"Openai response: {response.model_dump_json()}") logger.debug(f"Openai response: {response.model_dump_json()}")
self.messages += [AgentMessage(role="environment", content=response.output)] self.messages += [AgentMessage(role="environment", content=response.output)]
rendered_response = await self._render_response(response.output, current_state.screenshot) rendered_response = self._render_response(response.output, current_state.screenshot)
last_call_id = None last_call_id = None
content = None content = None
@@ -193,7 +193,7 @@ class OpenAIOperatorAgent(OperatorAgent):
for idx, env_step in enumerate(env_steps): for idx, env_step in enumerate(env_steps):
action_result = agent_action.action_results[idx] action_result = agent_action.action_results[idx]
result_content = env_step.error or env_step.output or "[Action completed]" result_content = env_step.error or env_step.output or "[Action completed]"
if env_step.type == "image": if env_step.type == "image" and isinstance(result_content, dict):
# Add screenshot data in openai message format # Add screenshot data in openai message format
action_result["output"] = { action_result["output"] = {
"type": "input_image", "type": "input_image",
@@ -215,10 +215,13 @@ class OpenAIOperatorAgent(OperatorAgent):
def _format_message_for_api(self, messages: list[AgentMessage]) -> list: def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
"""Format the message for OpenAI API.""" """Format the message for OpenAI API."""
formatted_messages = [] formatted_messages: list = []
for message in messages: for message in messages:
if message.role == "environment": if message.role == "environment":
formatted_messages.extend(message.content) if isinstance(message.content, list):
formatted_messages.extend(message.content)
else:
logger.warning(f"Expected message content list from environment, got {type(message.content)}")
else: else:
formatted_messages.append( formatted_messages.append(
{ {
@@ -228,13 +231,14 @@ class OpenAIOperatorAgent(OperatorAgent):
) )
return formatted_messages return formatted_messages
@staticmethod def compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
def compile_response(response_content: str | list[dict | ResponseOutputItem]) -> str:
"""Compile the response from model into a single string.""" """Compile the response from model into a single string."""
# Handle case where response content is a string. # Handle case where response content is a string.
# This is the case when response content is a user query # This is the case when response content is a user query
if is_none_or_empty(response_content) or isinstance(response_content, str): if isinstance(response_content, str):
return response_content return response_content
elif is_none_or_empty(response_content):
return ""
# Handle case where response_content is a dictionary and not ResponseOutputItem # Handle case where response_content is a dictionary and not ResponseOutputItem
# This is the case when response_content contains action results # This is the case when response_content contains action results
if not hasattr(response_content[0], "type"): if not hasattr(response_content[0], "type"):
@@ -242,6 +246,8 @@ class OpenAIOperatorAgent(OperatorAgent):
compiled_response = [""] compiled_response = [""]
for block in deepcopy(response_content): for block in deepcopy(response_content):
block = cast(ResponseOutputItem, block) # Ensure block is of type ResponseOutputItem
# Handle different block types
if block.type == "message": if block.type == "message":
# Extract text content if available # Extract text content if available
for content in block.content: for content in block.content:
@@ -254,30 +260,29 @@ class OpenAIOperatorAgent(OperatorAgent):
text_content += content.model_dump_json() text_content += content.model_dump_json()
compiled_response.append(text_content) compiled_response.append(text_content)
elif block.type == "function_call": elif block.type == "function_call":
block_input = {"action": block.name} block_function_input = {"action": block.name}
if block.name == "goto": if block.name == "goto":
try: try:
args = json.loads(block.arguments) args = json.loads(block.arguments)
block_input["url"] = args.get("url", "[Missing URL]") block_function_input["url"] = args.get("url", "[Missing URL]")
except json.JSONDecodeError: except json.JSONDecodeError:
block_input["arguments"] = block.arguments # Show raw args on error block_function_input["arguments"] = block.arguments # Show raw args on error
compiled_response.append(f"**Action**: {json.dumps(block_input)}") compiled_response.append(f"**Action**: {json.dumps(block_function_input)}")
elif block.type == "computer_call": elif block.type == "computer_call":
block_input = block.action block_computer_input = block.action
# If it's a screenshot action # If it's a screenshot action
if block_input.type == "screenshot": if block_computer_input.type == "screenshot":
# Use a placeholder for screenshot data # Use a placeholder for screenshot data
block_input_render = block_input.model_dump() block_input_render = block_computer_input.model_dump()
block_input_render["image"] = "[placeholder for screenshot]" block_input_render["image"] = "[placeholder for screenshot]"
compiled_response.append(f"**Action**: {json.dumps(block_input_render)}") compiled_response.append(f"**Action**: {json.dumps(block_input_render)}")
else: else:
compiled_response.append(f"**Action**: {block_input.model_dump_json()}") compiled_response.append(f"**Action**: {block_computer_input.model_dump_json()}")
elif block.type == "reasoning" and block.summary: elif block.type == "reasoning" and block.summary:
compiled_response.append(f"**Thought**: {block.summary}") compiled_response.append(f"**Thought**: {block.summary}")
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
@staticmethod def _render_response(self, response_content: list[ResponseOutputItem], screenshot: str | None) -> dict:
async def _render_response(response_content: list[ResponseOutputItem], screenshot: Optional[str] = None) -> dict:
"""Render OpenAI response for display, potentially including screenshots.""" """Render OpenAI response for display, potentially including screenshots."""
render_texts = [] render_texts = []
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
@@ -285,7 +290,6 @@ class OpenAIOperatorAgent(OperatorAgent):
text_content = block.text if hasattr(block, "text") else block.model_dump_json() text_content = block.text if hasattr(block, "text") else block.model_dump_json()
render_texts += [text_content] render_texts += [text_content]
elif block.type == "function_call": elif block.type == "function_call":
block_input = {"action": block.name}
if block.name == "goto": if block.name == "goto":
args = json.loads(block.arguments) args = json.loads(block.arguments)
render_texts = [f'Open URL: {args.get("url", "[Missing URL]")}'] render_texts = [f'Open URL: {args.get("url", "[Missing URL]")}']

View File

@@ -3,7 +3,7 @@ import base64
import io import io
import logging import logging
import os import os
from typing import Optional, Set from typing import Optional, Set, Union
from khoj.processor.operator.operator_actions import OperatorAction, Point from khoj.processor.operator.operator_actions import OperatorAction, Point
from khoj.processor.operator.operator_environment_base import ( from khoj.processor.operator.operator_environment_base import (
@@ -124,7 +124,7 @@ class BrowserEnvironment(Environment):
async def get_state(self) -> EnvState: async def get_state(self) -> EnvState:
if not self.page or self.page.is_closed(): if not self.page or self.page.is_closed():
return "about:blank", None return EnvState(url="about:blank", screenshot=None)
url = self.page.url url = self.page.url
screenshot = await self._get_screenshot() screenshot = await self._get_screenshot()
return EnvState(url=url, screenshot=screenshot) return EnvState(url=url, screenshot=screenshot)
@@ -134,7 +134,9 @@ class BrowserEnvironment(Environment):
return EnvStepResult(error="Browser page is not available or closed.") return EnvStepResult(error="Browser page is not available or closed.")
before_state = await self.get_state() before_state = await self.get_state()
output, error, step_type = None, None, "text" output: Optional[Union[str, dict]] = None
error: Optional[str] = None
step_type: str = "text"
try: try:
match action.type: match action.type:
case "click": case "click":
@@ -180,16 +182,16 @@ class BrowserEnvironment(Environment):
logger.debug(f"Action: {action.type} by ({scroll_x},{scroll_y}) at ({action.x},{action.y})") logger.debug(f"Action: {action.type} by ({scroll_x},{scroll_y}) at ({action.x},{action.y})")
# Otherwise use direction/amount (from Anthropic style) # Otherwise use direction/amount (from Anthropic style)
elif action.scroll_direction: elif action.scroll_direction:
dx, dy = 0, 0 dx, dy = 0.0, 0.0
amount = action.scroll_amount or 1 amount = action.scroll_amount or 1
if action.scroll_direction == "up": if action.scroll_direction == "up":
dy = -100 * amount dy = -100.0 * amount
elif action.scroll_direction == "down": elif action.scroll_direction == "down":
dy = 100 * amount dy = 100.0 * amount
elif action.scroll_direction == "left": elif action.scroll_direction == "left":
dx = -100 * amount dx = -100.0 * amount
elif action.scroll_direction == "right": elif action.scroll_direction == "right":
dx = 100 * amount dx = 100.0 * amount
if action.x is not None and action.y is not None: if action.x is not None and action.y is not None:
await self.page.mouse.move(action.x, action.y) await self.page.mouse.move(action.x, action.y)

View File

@@ -1354,7 +1354,7 @@ async def agenerate_chat_response(
compiled_references: List[Dict] = [], compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {}, online_results: Dict[str, Dict] = {},
code_results: Dict[str, Dict] = {}, code_results: Dict[str, Dict] = {},
operator_results: Dict[str, Dict] = {}, operator_results: List[str] = [],
inferred_queries: List[str] = [], inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None, user: KhojUser = None,
@@ -1411,7 +1411,7 @@ async def agenerate_chat_response(
compiled_references = [] compiled_references = []
online_results = {} online_results = {}
code_results = {} code_results = {}
operator_results = {} operator_results = []
deepthought = True deepthought = True
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed) chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)