Resolve mypy typing errors in operator code

This commit is contained in:
Debanjum
2025-05-09 19:51:57 -06:00
parent 33689feb91
commit 95f211d03c
11 changed files with 92 additions and 82 deletions

View File

@@ -144,7 +144,7 @@ async def converse_anthropic(
user_query,
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[str]] = None,
conversation_log={},
model: Optional[str] = "claude-3-7-sonnet-latest",
api_key: Optional[str] = None,

View File

@@ -166,7 +166,7 @@ async def converse_gemini(
user_query,
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[str]] = None,
conversation_log={},
model: Optional[str] = "gemini-2.0-flash",
api_key: Optional[str] = None,

View File

@@ -169,7 +169,7 @@ async def converse_openai(
user_query,
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[Dict[str, Dict]] = None,
operator_results: Optional[List[str]] = None,
conversation_log={},
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,

View File

@@ -210,7 +210,7 @@ class GroundingAgent:
self.tracer["usage"] = get_chat_usage_metrics(
self.model.name,
input_tokens=grounding_response.usage.prompt_tokens,
completion_tokens=grounding_response.usage.completion_tokens,
output_tokens=grounding_response.usage.completion_tokens,
usage=self.tracer.get("usage"),
)
except Exception as e:

View File

@@ -10,7 +10,7 @@ import logging
import math
import re
from io import BytesIO
from typing import List
from typing import Any, List
import numpy as np
from openai import AzureOpenAI, OpenAI
@@ -112,11 +112,11 @@ class GroundingAgentUitars:
self.min_pixels = self.runtime_conf["min_pixels"]
self.callusr_tolerance = self.runtime_conf["callusr_tolerance"]
self.thoughts = []
self.actions = []
self.observations = []
self.history_images = []
self.history_responses = []
self.thoughts: list[str] = []
self.actions: list[list[OperatorAction]] = []
self.observations: list[dict] = []
self.history_images: list[bytes] = []
self.history_responses: list[str] = []
self.prompt_template = self.UITARS_USR_PROMPT_THOUGHT
self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE
@@ -159,7 +159,7 @@ class GroundingAgentUitars:
# top_k=top_k,
top_p=self.top_p,
)
prediction: str = response.choices[0].message.content.strip()
prediction = response.choices[0].message.content.strip()
self.tracer["usage"] = get_chat_usage_metrics(
self.model_name,
input_tokens=response.usage.prompt_tokens,
@@ -235,11 +235,15 @@ class GroundingAgentUitars:
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
)
else:
actions.append(
self.parsing_response_to_pyautogui_code(
parsed_response, obs_image_height, obs_image_width, self.input_swap
)
)
pass
# TODO: Add PyautoguiAction when enable computer environment
# actions.append(
# PyautoguiAction(code=
# self.parsing_response_to_pyautogui_code(
# parsed_response, obs_image_height, obs_image_width, self.input_swap
# )
# )
# )
self.actions.append(actions)
@@ -268,7 +272,8 @@ class GroundingAgentUitars:
if len(self.history_images) > self.history_n:
self.history_images = self.history_images[-self.history_n :]
messages, images = [], []
messages: list[dict] = []
images: list[Any] = []
if isinstance(self.history_images, bytes):
self.history_images = [self.history_images]
elif isinstance(self.history_images, np.ndarray):
@@ -414,11 +419,11 @@ class GroundingAgentUitars:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(self, number: int, factor: int) -> int:
def ceil_by_factor(self, number: float, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(self, number: int, factor: int) -> int:
def floor_by_factor(self, number: float, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor

View File

@@ -2,7 +2,7 @@ import json
import logging
from copy import deepcopy
from datetime import datetime
from typing import Any, List, Optional
from typing import Any, List, Optional, cast
from anthropic.types.beta import BetaContentBlock
@@ -76,7 +76,7 @@ class AnthropicOperatorAgent(OperatorAgent):
},
]
thinking = {"type": "disabled"}
thinking: dict[str, str | int] = {"type": "disabled"}
if self.vision_model.name.startswith("claude-3-7"):
thinking = {"type": "enabled", "budget_tokens": 1024}
@@ -94,7 +94,7 @@ class AnthropicOperatorAgent(OperatorAgent):
logger.debug(f"Anthropic response: {response.model_dump_json()}")
self.messages.append(AgentMessage(role="assistant", content=response.content))
rendered_response = await self._render_response(response.content, current_state.screenshot)
rendered_response = self._render_response(response.content, current_state.screenshot)
for block in response.content:
if block.type == "tool_use":
@@ -140,7 +140,7 @@ class AnthropicOperatorAgent(OperatorAgent):
elif tool_name == "left_mouse_up":
action_to_run = MouseUpAction(button="left")
elif tool_name == "type":
text = tool_input.get("text")
text: str = tool_input.get("text")
if text:
action_to_run = TypeAction(text=text)
elif tool_name == "scroll":
@@ -152,7 +152,7 @@ class AnthropicOperatorAgent(OperatorAgent):
if direction:
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y)
elif tool_name == "key":
text: str = tool_input.get("text")
text = tool_input.get("text")
if text:
action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style
elif tool_name == "hold_key":
@@ -214,7 +214,7 @@ class AnthropicOperatorAgent(OperatorAgent):
for idx, env_step in enumerate(env_steps):
action_result = agent_action.action_results[idx]
result_content = env_step.error or env_step.output or "[Action completed]"
if env_step.type == "image":
if env_step.type == "image" and isinstance(result_content, dict):
# Add screenshot data in anthropic message format
action_result["content"] = [
{
@@ -262,12 +262,20 @@ class AnthropicOperatorAgent(OperatorAgent):
)
return formatted_messages
def compile_response(self, response_content: list[BetaContentBlock | Any]) -> str:
def compile_response(self, response_content: list[BetaContentBlock | dict] | str) -> str:
"""Compile Anthropic response into a single string."""
if is_none_or_empty(response_content) or not all(hasattr(item, "type") for item in response_content):
if isinstance(response_content, str):
return response_content
elif is_none_or_empty(response_content):
return ""
# action results are a list dictionaries,
# beta content blocks are objects with a type attribute
elif isinstance(response_content[0], dict):
return json.dumps(response_content)
compiled_response = [""]
for block in deepcopy(response_content):
block = cast(BetaContentBlock, block) # Ensure block is of type BetaContentBlock
if block.type == "text":
compiled_response.append(block.text)
elif block.type == "tool_use":
@@ -291,8 +299,7 @@ class AnthropicOperatorAgent(OperatorAgent):
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
@staticmethod
async def _render_response(response_content: list[BetaContentBlock], screenshot: Optional[str] = None) -> dict:
def _render_response(self, response_content: list[BetaContentBlock], screenshot: str | None) -> dict:
"""Render Anthropic response, potentially including actual screenshots."""
render_texts = []
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
@@ -315,11 +322,11 @@ class AnthropicOperatorAgent(OperatorAgent):
elif "action" in block_input:
action = block_input["action"]
if action == "type":
text = block_input.get("text")
text: str = block_input.get("text")
if text:
render_texts += [f'Type "{text}"']
elif action == "key":
text: str = block_input.get("text")
text = block_input.get("text")
if text:
render_texts += [f"Press {text}"]
elif action == "hold_key":

View File

@@ -50,11 +50,11 @@ class OperatorAgent(ABC):
return self.compile_response(self.messages[-1].content)
@abstractmethod
def compile_response(self, response: List) -> str:
def compile_response(self, response: List | str) -> str:
pass
@abstractmethod
def _render_response(self, response: List, screenshot: Optional[str]) -> Optional[str]:
def _render_response(self, response: List, screenshot: Optional[str]) -> dict:
pass
@abstractmethod

View File

@@ -49,12 +49,14 @@ class BinaryOperatorAgent(OperatorAgent):
grounding_client = get_openai_async_client(
grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url
)
self.grounding_agent: GroundingAgent | GroundingAgentUitars = None
if "ui-tars-1.5" in grounding_model.name:
self.grounding_agent = GroundingAgentUitars(
grounding_model.name, grounding_client, environment_type="web", tracer=tracer
grounding_model.name, grounding_client, max_iterations, environment_type="web", tracer=tracer
)
else:
self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, tracer=tracer)
self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, max_iterations, tracer=tracer)
async def act(self, current_state: EnvState) -> AgentActResult:
"""
@@ -143,7 +145,7 @@ Focus on the visual action and provide all necessary context.
query_screenshot = self._get_message_images(current_message)
# Construct input for visual reasoner history
visual_reasoner_history = self._format_message_for_api(self.messages)
visual_reasoner_history = {"chat": self._format_message_for_api(self.messages)}
try:
natural_language_action = await send_message_to_model_wrapper(
query=query_text,
@@ -153,6 +155,10 @@ Focus on the visual action and provide all necessary context.
agent_chat_model=self.reasoning_model,
tracer=self.tracer,
)
if not isinstance(natural_language_action, str) or not natural_language_action.strip():
raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}")
self.messages.append(current_message)
self.messages.append(AgentMessage(role="assistant", content=natural_language_action))
@@ -254,7 +260,7 @@ Focus on the visual action and provide all necessary context.
self.messages.append(AgentMessage(role="environment", content=action_results_content))
async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str:
conversation_history = self._format_message_for_api(self.messages)
conversation_history = {"chat": self._format_message_for_api(self.messages)}
try:
summary = await send_message_to_model_wrapper(
query=summarize_prompt,
@@ -276,25 +282,11 @@ Focus on the visual action and provide all necessary context.
return summary
def compile_response(self, response_content: Union[str, List, dict]) -> str:
def compile_response(self, response_content: str | List) -> str:
"""Compile response content into a string, handling OpenAI message structures."""
if isinstance(response_content, str):
return response_content
if isinstance(response_content, dict) and response_content.get("role") == "assistant":
# Grounding LLM response message (might contain tool calls)
text_content = response_content.get("content")
tool_calls = response_content.get("tool_calls")
compiled = []
if text_content:
compiled.append(text_content)
if tool_calls:
for tc in tool_calls:
compiled.append(
f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}"
)
return "\n- ".join(filter(None, compiled))
if isinstance(response_content, list): # Tool results list
compiled = ["**Tool Results**:"]
for item in response_content:
@@ -336,7 +328,7 @@ Focus on the visual action and provide all necessary context.
}
for message in messages
]
return {"chat": formatted_messages}
return formatted_messages
def reset(self):
"""Reset the agent state."""

View File

@@ -2,7 +2,7 @@ import json
import logging
from copy import deepcopy
from datetime import datetime
from typing import List, Optional
from typing import List, Optional, cast
from openai.types.responses import Response, ResponseOutputItem
@@ -95,7 +95,7 @@ class OpenAIOperatorAgent(OperatorAgent):
logger.debug(f"Openai response: {response.model_dump_json()}")
self.messages += [AgentMessage(role="environment", content=response.output)]
rendered_response = await self._render_response(response.output, current_state.screenshot)
rendered_response = self._render_response(response.output, current_state.screenshot)
last_call_id = None
content = None
@@ -193,7 +193,7 @@ class OpenAIOperatorAgent(OperatorAgent):
for idx, env_step in enumerate(env_steps):
action_result = agent_action.action_results[idx]
result_content = env_step.error or env_step.output or "[Action completed]"
if env_step.type == "image":
if env_step.type == "image" and isinstance(result_content, dict):
# Add screenshot data in openai message format
action_result["output"] = {
"type": "input_image",
@@ -215,10 +215,13 @@ class OpenAIOperatorAgent(OperatorAgent):
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
"""Format the message for OpenAI API."""
formatted_messages = []
formatted_messages: list = []
for message in messages:
if message.role == "environment":
formatted_messages.extend(message.content)
if isinstance(message.content, list):
formatted_messages.extend(message.content)
else:
logger.warning(f"Expected message content list from environment, got {type(message.content)}")
else:
formatted_messages.append(
{
@@ -228,13 +231,14 @@ class OpenAIOperatorAgent(OperatorAgent):
)
return formatted_messages
@staticmethod
def compile_response(response_content: str | list[dict | ResponseOutputItem]) -> str:
def compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
"""Compile the response from model into a single string."""
# Handle case where response content is a string.
# This is the case when response content is a user query
if is_none_or_empty(response_content) or isinstance(response_content, str):
if isinstance(response_content, str):
return response_content
elif is_none_or_empty(response_content):
return ""
# Handle case where response_content is a dictionary and not ResponseOutputItem
# This is the case when response_content contains action results
if not hasattr(response_content[0], "type"):
@@ -242,6 +246,8 @@ class OpenAIOperatorAgent(OperatorAgent):
compiled_response = [""]
for block in deepcopy(response_content):
block = cast(ResponseOutputItem, block) # Ensure block is of type ResponseOutputItem
# Handle different block types
if block.type == "message":
# Extract text content if available
for content in block.content:
@@ -254,30 +260,29 @@ class OpenAIOperatorAgent(OperatorAgent):
text_content += content.model_dump_json()
compiled_response.append(text_content)
elif block.type == "function_call":
block_input = {"action": block.name}
block_function_input = {"action": block.name}
if block.name == "goto":
try:
args = json.loads(block.arguments)
block_input["url"] = args.get("url", "[Missing URL]")
block_function_input["url"] = args.get("url", "[Missing URL]")
except json.JSONDecodeError:
block_input["arguments"] = block.arguments # Show raw args on error
compiled_response.append(f"**Action**: {json.dumps(block_input)}")
block_function_input["arguments"] = block.arguments # Show raw args on error
compiled_response.append(f"**Action**: {json.dumps(block_function_input)}")
elif block.type == "computer_call":
block_input = block.action
block_computer_input = block.action
# If it's a screenshot action
if block_input.type == "screenshot":
if block_computer_input.type == "screenshot":
# Use a placeholder for screenshot data
block_input_render = block_input.model_dump()
block_input_render = block_computer_input.model_dump()
block_input_render["image"] = "[placeholder for screenshot]"
compiled_response.append(f"**Action**: {json.dumps(block_input_render)}")
else:
compiled_response.append(f"**Action**: {block_input.model_dump_json()}")
compiled_response.append(f"**Action**: {block_computer_input.model_dump_json()}")
elif block.type == "reasoning" and block.summary:
compiled_response.append(f"**Thought**: {block.summary}")
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
@staticmethod
async def _render_response(response_content: list[ResponseOutputItem], screenshot: Optional[str] = None) -> dict:
def _render_response(self, response_content: list[ResponseOutputItem], screenshot: str | None) -> dict:
"""Render OpenAI response for display, potentially including screenshots."""
render_texts = []
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
@@ -285,7 +290,6 @@ class OpenAIOperatorAgent(OperatorAgent):
text_content = block.text if hasattr(block, "text") else block.model_dump_json()
render_texts += [text_content]
elif block.type == "function_call":
block_input = {"action": block.name}
if block.name == "goto":
args = json.loads(block.arguments)
render_texts = [f'Open URL: {args.get("url", "[Missing URL]")}']

View File

@@ -3,7 +3,7 @@ import base64
import io
import logging
import os
from typing import Optional, Set
from typing import Optional, Set, Union
from khoj.processor.operator.operator_actions import OperatorAction, Point
from khoj.processor.operator.operator_environment_base import (
@@ -124,7 +124,7 @@ class BrowserEnvironment(Environment):
async def get_state(self) -> EnvState:
if not self.page or self.page.is_closed():
return "about:blank", None
return EnvState(url="about:blank", screenshot=None)
url = self.page.url
screenshot = await self._get_screenshot()
return EnvState(url=url, screenshot=screenshot)
@@ -134,7 +134,9 @@ class BrowserEnvironment(Environment):
return EnvStepResult(error="Browser page is not available or closed.")
before_state = await self.get_state()
output, error, step_type = None, None, "text"
output: Optional[Union[str, dict]] = None
error: Optional[str] = None
step_type: str = "text"
try:
match action.type:
case "click":
@@ -180,16 +182,16 @@ class BrowserEnvironment(Environment):
logger.debug(f"Action: {action.type} by ({scroll_x},{scroll_y}) at ({action.x},{action.y})")
# Otherwise use direction/amount (from Anthropic style)
elif action.scroll_direction:
dx, dy = 0, 0
dx, dy = 0.0, 0.0
amount = action.scroll_amount or 1
if action.scroll_direction == "up":
dy = -100 * amount
dy = -100.0 * amount
elif action.scroll_direction == "down":
dy = 100 * amount
dy = 100.0 * amount
elif action.scroll_direction == "left":
dx = -100 * amount
dx = -100.0 * amount
elif action.scroll_direction == "right":
dx = 100 * amount
dx = 100.0 * amount
if action.x is not None and action.y is not None:
await self.page.mouse.move(action.x, action.y)

View File

@@ -1354,7 +1354,7 @@ async def agenerate_chat_response(
compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {},
code_results: Dict[str, Dict] = {},
operator_results: Dict[str, Dict] = {},
operator_results: List[str] = [],
inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None,
@@ -1411,7 +1411,7 @@ async def agenerate_chat_response(
compiled_references = []
online_results = {}
code_results = {}
operator_results = {}
operator_results = []
deepthought = True
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)