From 95f211d03cacd679a394a3620e08f8544be58834 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 9 May 2025 19:51:57 -0600 Subject: [PATCH] Resolve mypy typing errors in operator code --- .../conversation/anthropic/anthropic_chat.py | 2 +- .../conversation/google/gemini_chat.py | 2 +- src/khoj/processor/conversation/openai/gpt.py | 2 +- .../processor/operator/grounding_agent.py | 2 +- .../operator/grounding_agent_uitars.py | 35 +++++++++------- .../operator/operator_agent_anthropic.py | 31 ++++++++------ .../processor/operator/operator_agent_base.py | 4 +- .../operator/operator_agent_binary.py | 32 ++++++-------- .../operator/operator_agent_openai.py | 42 ++++++++++--------- .../operator/operator_environment_browser.py | 18 ++++---- src/khoj/routers/helpers.py | 4 +- 11 files changed, 92 insertions(+), 82 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index ffc79756..b287b754 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -144,7 +144,7 @@ async def converse_anthropic( user_query, online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[Dict[str, Dict]] = None, + operator_results: Optional[List[str]] = None, conversation_log={}, model: Optional[str] = "claude-3-7-sonnet-latest", api_key: Optional[str] = None, diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 621d3eb5..54e9297e 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -166,7 +166,7 @@ async def converse_gemini( user_query, online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[Dict[str, Dict]] = None, + operator_results: Optional[List[str]] = None, conversation_log={}, model: Optional[str] = "gemini-2.0-flash", api_key: Optional[str] = None, diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index c3d38096..05c84f01 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -169,7 +169,7 @@ async def converse_openai( user_query, online_results: Optional[Dict[str, Dict]] = None, code_results: Optional[Dict[str, Dict]] = None, - operator_results: Optional[Dict[str, Dict]] = None, + operator_results: Optional[List[str]] = None, conversation_log={}, model: str = "gpt-4o-mini", api_key: Optional[str] = None, diff --git a/src/khoj/processor/operator/grounding_agent.py b/src/khoj/processor/operator/grounding_agent.py index a090c9df..65cb5c14 100644 --- a/src/khoj/processor/operator/grounding_agent.py +++ b/src/khoj/processor/operator/grounding_agent.py @@ -210,7 +210,7 @@ class GroundingAgent: self.tracer["usage"] = get_chat_usage_metrics( self.model.name, input_tokens=grounding_response.usage.prompt_tokens, - completion_tokens=grounding_response.usage.completion_tokens, + output_tokens=grounding_response.usage.completion_tokens, usage=self.tracer.get("usage"), ) except Exception as e: diff --git a/src/khoj/processor/operator/grounding_agent_uitars.py b/src/khoj/processor/operator/grounding_agent_uitars.py index e8a4978d..4ba2578b 100644 --- a/src/khoj/processor/operator/grounding_agent_uitars.py +++ b/src/khoj/processor/operator/grounding_agent_uitars.py @@ -10,7 +10,7 @@ import logging import math import re from io import BytesIO -from typing import List +from typing import Any, List import numpy as np from openai import AzureOpenAI, OpenAI @@ -112,11 +112,11 @@ class GroundingAgentUitars: self.min_pixels = self.runtime_conf["min_pixels"] self.callusr_tolerance = self.runtime_conf["callusr_tolerance"] - self.thoughts = [] - self.actions = [] - self.observations = [] - self.history_images = [] - self.history_responses = [] + self.thoughts: list[str] = [] + self.actions: list[list[OperatorAction]] = [] + self.observations: list[dict] = [] + self.history_images: list[bytes] = [] + self.history_responses: list[str] = [] self.prompt_template = self.UITARS_USR_PROMPT_THOUGHT self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE @@ -159,7 +159,7 @@ class GroundingAgentUitars: # top_k=top_k, top_p=self.top_p, ) - prediction: str = response.choices[0].message.content.strip() + prediction = response.choices[0].message.content.strip() self.tracer["usage"] = get_chat_usage_metrics( self.model_name, input_tokens=response.usage.prompt_tokens, @@ -235,11 +235,15 @@ class GroundingAgentUitars: self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap) ) else: - actions.append( - self.parsing_response_to_pyautogui_code( - parsed_response, obs_image_height, obs_image_width, self.input_swap - ) - ) + pass + # TODO: Add PyautoguiAction when enable computer environment + # actions.append( + # PyautoguiAction(code= + # self.parsing_response_to_pyautogui_code( + # parsed_response, obs_image_height, obs_image_width, self.input_swap + # ) + # ) + # ) self.actions.append(actions) @@ -268,7 +272,8 @@ class GroundingAgentUitars: if len(self.history_images) > self.history_n: self.history_images = self.history_images[-self.history_n :] - messages, images = [], [] + messages: list[dict] = [] + images: list[Any] = [] if isinstance(self.history_images, bytes): self.history_images = [self.history_images] elif isinstance(self.history_images, np.ndarray): @@ -414,11 +419,11 @@ class GroundingAgentUitars: """Returns the closest integer to 'number' that is divisible by 'factor'.""" return round(number / factor) * factor - def ceil_by_factor(self, number: int, factor: int) -> int: + def ceil_by_factor(self, number: float, factor: int) -> int: """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" return math.ceil(number / factor) * factor - def floor_by_factor(self, number: int, factor: int) -> int: + def floor_by_factor(self, number: float, factor: int) -> int: """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" return math.floor(number / factor) * factor diff --git a/src/khoj/processor/operator/operator_agent_anthropic.py b/src/khoj/processor/operator/operator_agent_anthropic.py index 3c3b96ba..7374e6ad 100644 --- a/src/khoj/processor/operator/operator_agent_anthropic.py +++ b/src/khoj/processor/operator/operator_agent_anthropic.py @@ -2,7 +2,7 @@ import json import logging from copy import deepcopy from datetime import datetime -from typing import Any, List, Optional +from typing import Any, List, Optional, cast from anthropic.types.beta import BetaContentBlock @@ -76,7 +76,7 @@ class AnthropicOperatorAgent(OperatorAgent): }, ] - thinking = {"type": "disabled"} + thinking: dict[str, str | int] = {"type": "disabled"} if self.vision_model.name.startswith("claude-3-7"): thinking = {"type": "enabled", "budget_tokens": 1024} @@ -94,7 +94,7 @@ class AnthropicOperatorAgent(OperatorAgent): logger.debug(f"Anthropic response: {response.model_dump_json()}") self.messages.append(AgentMessage(role="assistant", content=response.content)) - rendered_response = await self._render_response(response.content, current_state.screenshot) + rendered_response = self._render_response(response.content, current_state.screenshot) for block in response.content: if block.type == "tool_use": @@ -140,7 +140,7 @@ class AnthropicOperatorAgent(OperatorAgent): elif tool_name == "left_mouse_up": action_to_run = MouseUpAction(button="left") elif tool_name == "type": - text = tool_input.get("text") + text: str = tool_input.get("text") if text: action_to_run = TypeAction(text=text) elif tool_name == "scroll": @@ -152,7 +152,7 @@ class AnthropicOperatorAgent(OperatorAgent): if direction: action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y) elif tool_name == "key": - text: str = tool_input.get("text") + text = tool_input.get("text") if text: action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style elif tool_name == "hold_key": @@ -214,7 +214,7 @@ class AnthropicOperatorAgent(OperatorAgent): for idx, env_step in enumerate(env_steps): action_result = agent_action.action_results[idx] result_content = env_step.error or env_step.output or "[Action completed]" - if env_step.type == "image": + if env_step.type == "image" and isinstance(result_content, dict): # Add screenshot data in anthropic message format action_result["content"] = [ { @@ -262,12 +262,20 @@ class AnthropicOperatorAgent(OperatorAgent): ) return formatted_messages - def compile_response(self, response_content: list[BetaContentBlock | Any]) -> str: + def compile_response(self, response_content: list[BetaContentBlock | dict] | str) -> str: """Compile Anthropic response into a single string.""" - if is_none_or_empty(response_content) or not all(hasattr(item, "type") for item in response_content): + if isinstance(response_content, str): return response_content + elif is_none_or_empty(response_content): + return "" + # action results are a list dictionaries, + # beta content blocks are objects with a type attribute + elif isinstance(response_content[0], dict): + return json.dumps(response_content) + compiled_response = [""] for block in deepcopy(response_content): + block = cast(BetaContentBlock, block) # Ensure block is of type BetaContentBlock if block.type == "text": compiled_response.append(block.text) elif block.type == "tool_use": @@ -291,8 +299,7 @@ class AnthropicOperatorAgent(OperatorAgent): return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings - @staticmethod - async def _render_response(response_content: list[BetaContentBlock], screenshot: Optional[str] = None) -> dict: + def _render_response(self, response_content: list[BetaContentBlock], screenshot: str | None) -> dict: """Render Anthropic response, potentially including actual screenshots.""" render_texts = [] for block in deepcopy(response_content): # Use deepcopy to avoid modifying original @@ -315,11 +322,11 @@ class AnthropicOperatorAgent(OperatorAgent): elif "action" in block_input: action = block_input["action"] if action == "type": - text = block_input.get("text") + text: str = block_input.get("text") if text: render_texts += [f'Type "{text}"'] elif action == "key": - text: str = block_input.get("text") + text = block_input.get("text") if text: render_texts += [f"Press {text}"] elif action == "hold_key": diff --git a/src/khoj/processor/operator/operator_agent_base.py b/src/khoj/processor/operator/operator_agent_base.py index 0b6e0d6b..430a52d6 100644 --- a/src/khoj/processor/operator/operator_agent_base.py +++ b/src/khoj/processor/operator/operator_agent_base.py @@ -50,11 +50,11 @@ class OperatorAgent(ABC): return self.compile_response(self.messages[-1].content) @abstractmethod - def compile_response(self, response: List) -> str: + def compile_response(self, response: List | str) -> str: pass @abstractmethod - def _render_response(self, response: List, screenshot: Optional[str]) -> Optional[str]: + def _render_response(self, response: List, screenshot: Optional[str]) -> dict: pass @abstractmethod diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py index 58c1d2ea..0b274f57 100644 --- a/src/khoj/processor/operator/operator_agent_binary.py +++ b/src/khoj/processor/operator/operator_agent_binary.py @@ -49,12 +49,14 @@ class BinaryOperatorAgent(OperatorAgent): grounding_client = get_openai_async_client( grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url ) + + self.grounding_agent: GroundingAgent | GroundingAgentUitars = None if "ui-tars-1.5" in grounding_model.name: self.grounding_agent = GroundingAgentUitars( - grounding_model.name, grounding_client, environment_type="web", tracer=tracer + grounding_model.name, grounding_client, max_iterations, environment_type="web", tracer=tracer ) else: - self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, tracer=tracer) + self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, max_iterations, tracer=tracer) async def act(self, current_state: EnvState) -> AgentActResult: """ @@ -143,7 +145,7 @@ Focus on the visual action and provide all necessary context. query_screenshot = self._get_message_images(current_message) # Construct input for visual reasoner history - visual_reasoner_history = self._format_message_for_api(self.messages) + visual_reasoner_history = {"chat": self._format_message_for_api(self.messages)} try: natural_language_action = await send_message_to_model_wrapper( query=query_text, @@ -153,6 +155,10 @@ Focus on the visual action and provide all necessary context. agent_chat_model=self.reasoning_model, tracer=self.tracer, ) + + if not isinstance(natural_language_action, str) or not natural_language_action.strip(): + raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}") + self.messages.append(current_message) self.messages.append(AgentMessage(role="assistant", content=natural_language_action)) @@ -254,7 +260,7 @@ Focus on the visual action and provide all necessary context. self.messages.append(AgentMessage(role="environment", content=action_results_content)) async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str: - conversation_history = self._format_message_for_api(self.messages) + conversation_history = {"chat": self._format_message_for_api(self.messages)} try: summary = await send_message_to_model_wrapper( query=summarize_prompt, @@ -276,25 +282,11 @@ Focus on the visual action and provide all necessary context. return summary - def compile_response(self, response_content: Union[str, List, dict]) -> str: + def compile_response(self, response_content: str | List) -> str: """Compile response content into a string, handling OpenAI message structures.""" if isinstance(response_content, str): return response_content - if isinstance(response_content, dict) and response_content.get("role") == "assistant": - # Grounding LLM response message (might contain tool calls) - text_content = response_content.get("content") - tool_calls = response_content.get("tool_calls") - compiled = [] - if text_content: - compiled.append(text_content) - if tool_calls: - for tc in tool_calls: - compiled.append( - f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}" - ) - return "\n- ".join(filter(None, compiled)) - if isinstance(response_content, list): # Tool results list compiled = ["**Tool Results**:"] for item in response_content: @@ -336,7 +328,7 @@ Focus on the visual action and provide all necessary context. } for message in messages ] - return {"chat": formatted_messages} + return formatted_messages def reset(self): """Reset the agent state.""" diff --git a/src/khoj/processor/operator/operator_agent_openai.py b/src/khoj/processor/operator/operator_agent_openai.py index 47696f8c..ad557c56 100644 --- a/src/khoj/processor/operator/operator_agent_openai.py +++ b/src/khoj/processor/operator/operator_agent_openai.py @@ -2,7 +2,7 @@ import json import logging from copy import deepcopy from datetime import datetime -from typing import List, Optional +from typing import List, Optional, cast from openai.types.responses import Response, ResponseOutputItem @@ -95,7 +95,7 @@ class OpenAIOperatorAgent(OperatorAgent): logger.debug(f"Openai response: {response.model_dump_json()}") self.messages += [AgentMessage(role="environment", content=response.output)] - rendered_response = await self._render_response(response.output, current_state.screenshot) + rendered_response = self._render_response(response.output, current_state.screenshot) last_call_id = None content = None @@ -193,7 +193,7 @@ class OpenAIOperatorAgent(OperatorAgent): for idx, env_step in enumerate(env_steps): action_result = agent_action.action_results[idx] result_content = env_step.error or env_step.output or "[Action completed]" - if env_step.type == "image": + if env_step.type == "image" and isinstance(result_content, dict): # Add screenshot data in openai message format action_result["output"] = { "type": "input_image", @@ -215,10 +215,13 @@ class OpenAIOperatorAgent(OperatorAgent): def _format_message_for_api(self, messages: list[AgentMessage]) -> list: """Format the message for OpenAI API.""" - formatted_messages = [] + formatted_messages: list = [] for message in messages: if message.role == "environment": - formatted_messages.extend(message.content) + if isinstance(message.content, list): + formatted_messages.extend(message.content) + else: + logger.warning(f"Expected message content list from environment, got {type(message.content)}") else: formatted_messages.append( { @@ -228,13 +231,14 @@ class OpenAIOperatorAgent(OperatorAgent): ) return formatted_messages - @staticmethod - def compile_response(response_content: str | list[dict | ResponseOutputItem]) -> str: + def compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str: """Compile the response from model into a single string.""" # Handle case where response content is a string. # This is the case when response content is a user query - if is_none_or_empty(response_content) or isinstance(response_content, str): + if isinstance(response_content, str): return response_content + elif is_none_or_empty(response_content): + return "" # Handle case where response_content is a dictionary and not ResponseOutputItem # This is the case when response_content contains action results if not hasattr(response_content[0], "type"): @@ -242,6 +246,8 @@ class OpenAIOperatorAgent(OperatorAgent): compiled_response = [""] for block in deepcopy(response_content): + block = cast(ResponseOutputItem, block) # Ensure block is of type ResponseOutputItem + # Handle different block types if block.type == "message": # Extract text content if available for content in block.content: @@ -254,30 +260,29 @@ class OpenAIOperatorAgent(OperatorAgent): text_content += content.model_dump_json() compiled_response.append(text_content) elif block.type == "function_call": - block_input = {"action": block.name} + block_function_input = {"action": block.name} if block.name == "goto": try: args = json.loads(block.arguments) - block_input["url"] = args.get("url", "[Missing URL]") + block_function_input["url"] = args.get("url", "[Missing URL]") except json.JSONDecodeError: - block_input["arguments"] = block.arguments # Show raw args on error - compiled_response.append(f"**Action**: {json.dumps(block_input)}") + block_function_input["arguments"] = block.arguments # Show raw args on error + compiled_response.append(f"**Action**: {json.dumps(block_function_input)}") elif block.type == "computer_call": - block_input = block.action + block_computer_input = block.action # If it's a screenshot action - if block_input.type == "screenshot": + if block_computer_input.type == "screenshot": # Use a placeholder for screenshot data - block_input_render = block_input.model_dump() + block_input_render = block_computer_input.model_dump() block_input_render["image"] = "[placeholder for screenshot]" compiled_response.append(f"**Action**: {json.dumps(block_input_render)}") else: - compiled_response.append(f"**Action**: {block_input.model_dump_json()}") + compiled_response.append(f"**Action**: {block_computer_input.model_dump_json()}") elif block.type == "reasoning" and block.summary: compiled_response.append(f"**Thought**: {block.summary}") return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings - @staticmethod - async def _render_response(response_content: list[ResponseOutputItem], screenshot: Optional[str] = None) -> dict: + def _render_response(self, response_content: list[ResponseOutputItem], screenshot: str | None) -> dict: """Render OpenAI response for display, potentially including screenshots.""" render_texts = [] for block in deepcopy(response_content): # Use deepcopy to avoid modifying original @@ -285,7 +290,6 @@ class OpenAIOperatorAgent(OperatorAgent): text_content = block.text if hasattr(block, "text") else block.model_dump_json() render_texts += [text_content] elif block.type == "function_call": - block_input = {"action": block.name} if block.name == "goto": args = json.loads(block.arguments) render_texts = [f'Open URL: {args.get("url", "[Missing URL]")}'] diff --git a/src/khoj/processor/operator/operator_environment_browser.py b/src/khoj/processor/operator/operator_environment_browser.py index 0be2ce14..d6416bea 100644 --- a/src/khoj/processor/operator/operator_environment_browser.py +++ b/src/khoj/processor/operator/operator_environment_browser.py @@ -3,7 +3,7 @@ import base64 import io import logging import os -from typing import Optional, Set +from typing import Optional, Set, Union from khoj.processor.operator.operator_actions import OperatorAction, Point from khoj.processor.operator.operator_environment_base import ( @@ -124,7 +124,7 @@ class BrowserEnvironment(Environment): async def get_state(self) -> EnvState: if not self.page or self.page.is_closed(): - return "about:blank", None + return EnvState(url="about:blank", screenshot=None) url = self.page.url screenshot = await self._get_screenshot() return EnvState(url=url, screenshot=screenshot) @@ -134,7 +134,9 @@ class BrowserEnvironment(Environment): return EnvStepResult(error="Browser page is not available or closed.") before_state = await self.get_state() - output, error, step_type = None, None, "text" + output: Optional[Union[str, dict]] = None + error: Optional[str] = None + step_type: str = "text" try: match action.type: case "click": @@ -180,16 +182,16 @@ class BrowserEnvironment(Environment): logger.debug(f"Action: {action.type} by ({scroll_x},{scroll_y}) at ({action.x},{action.y})") # Otherwise use direction/amount (from Anthropic style) elif action.scroll_direction: - dx, dy = 0, 0 + dx, dy = 0.0, 0.0 amount = action.scroll_amount or 1 if action.scroll_direction == "up": - dy = -100 * amount + dy = -100.0 * amount elif action.scroll_direction == "down": - dy = 100 * amount + dy = 100.0 * amount elif action.scroll_direction == "left": - dx = -100 * amount + dx = -100.0 * amount elif action.scroll_direction == "right": - dx = 100 * amount + dx = 100.0 * amount if action.x is not None and action.y is not None: await self.page.mouse.move(action.x, action.y) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index ad1a8433..25dac496 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1354,7 +1354,7 @@ async def agenerate_chat_response( compiled_references: List[Dict] = [], online_results: Dict[str, Dict] = {}, code_results: Dict[str, Dict] = {}, - operator_results: Dict[str, Dict] = {}, + operator_results: List[str] = [], inferred_queries: List[str] = [], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], user: KhojUser = None, @@ -1411,7 +1411,7 @@ async def agenerate_chat_response( compiled_references = [] online_results = {} code_results = {} - operator_results = {} + operator_results = [] deepthought = True chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)