mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Resolve mypy typing errors in operator code
This commit is contained in:
@@ -144,7 +144,7 @@ async def converse_anthropic(
|
|||||||
user_query,
|
user_query,
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
code_results: Optional[Dict[str, Dict]] = None,
|
code_results: Optional[Dict[str, Dict]] = None,
|
||||||
operator_results: Optional[Dict[str, Dict]] = None,
|
operator_results: Optional[List[str]] = None,
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: Optional[str] = "claude-3-7-sonnet-latest",
|
model: Optional[str] = "claude-3-7-sonnet-latest",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ async def converse_gemini(
|
|||||||
user_query,
|
user_query,
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
code_results: Optional[Dict[str, Dict]] = None,
|
code_results: Optional[Dict[str, Dict]] = None,
|
||||||
operator_results: Optional[Dict[str, Dict]] = None,
|
operator_results: Optional[List[str]] = None,
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: Optional[str] = "gemini-2.0-flash",
|
model: Optional[str] = "gemini-2.0-flash",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ async def converse_openai(
|
|||||||
user_query,
|
user_query,
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
code_results: Optional[Dict[str, Dict]] = None,
|
code_results: Optional[Dict[str, Dict]] = None,
|
||||||
operator_results: Optional[Dict[str, Dict]] = None,
|
operator_results: Optional[List[str]] = None,
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: str = "gpt-4o-mini",
|
model: str = "gpt-4o-mini",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ class GroundingAgent:
|
|||||||
self.tracer["usage"] = get_chat_usage_metrics(
|
self.tracer["usage"] = get_chat_usage_metrics(
|
||||||
self.model.name,
|
self.model.name,
|
||||||
input_tokens=grounding_response.usage.prompt_tokens,
|
input_tokens=grounding_response.usage.prompt_tokens,
|
||||||
completion_tokens=grounding_response.usage.completion_tokens,
|
output_tokens=grounding_response.usage.completion_tokens,
|
||||||
usage=self.tracer.get("usage"),
|
usage=self.tracer.get("usage"),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import List
|
from typing import Any, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from openai import AzureOpenAI, OpenAI
|
from openai import AzureOpenAI, OpenAI
|
||||||
@@ -112,11 +112,11 @@ class GroundingAgentUitars:
|
|||||||
self.min_pixels = self.runtime_conf["min_pixels"]
|
self.min_pixels = self.runtime_conf["min_pixels"]
|
||||||
self.callusr_tolerance = self.runtime_conf["callusr_tolerance"]
|
self.callusr_tolerance = self.runtime_conf["callusr_tolerance"]
|
||||||
|
|
||||||
self.thoughts = []
|
self.thoughts: list[str] = []
|
||||||
self.actions = []
|
self.actions: list[list[OperatorAction]] = []
|
||||||
self.observations = []
|
self.observations: list[dict] = []
|
||||||
self.history_images = []
|
self.history_images: list[bytes] = []
|
||||||
self.history_responses = []
|
self.history_responses: list[str] = []
|
||||||
|
|
||||||
self.prompt_template = self.UITARS_USR_PROMPT_THOUGHT
|
self.prompt_template = self.UITARS_USR_PROMPT_THOUGHT
|
||||||
self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE
|
self.prompt_action_space = self.UITARS_NORMAL_ACTION_SPACE
|
||||||
@@ -159,7 +159,7 @@ class GroundingAgentUitars:
|
|||||||
# top_k=top_k,
|
# top_k=top_k,
|
||||||
top_p=self.top_p,
|
top_p=self.top_p,
|
||||||
)
|
)
|
||||||
prediction: str = response.choices[0].message.content.strip()
|
prediction = response.choices[0].message.content.strip()
|
||||||
self.tracer["usage"] = get_chat_usage_metrics(
|
self.tracer["usage"] = get_chat_usage_metrics(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
input_tokens=response.usage.prompt_tokens,
|
input_tokens=response.usage.prompt_tokens,
|
||||||
@@ -235,11 +235,15 @@ class GroundingAgentUitars:
|
|||||||
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
|
self.parsing_response_to_action(parsed_response, obs_image_height, obs_image_width, self.input_swap)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
actions.append(
|
pass
|
||||||
self.parsing_response_to_pyautogui_code(
|
# TODO: Add PyautoguiAction when enabling the computer environment
|
||||||
parsed_response, obs_image_height, obs_image_width, self.input_swap
|
# actions.append(
|
||||||
)
|
# PyautoguiAction(code=
|
||||||
)
|
# self.parsing_response_to_pyautogui_code(
|
||||||
|
# parsed_response, obs_image_height, obs_image_width, self.input_swap
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
self.actions.append(actions)
|
self.actions.append(actions)
|
||||||
|
|
||||||
@@ -268,7 +272,8 @@ class GroundingAgentUitars:
|
|||||||
if len(self.history_images) > self.history_n:
|
if len(self.history_images) > self.history_n:
|
||||||
self.history_images = self.history_images[-self.history_n :]
|
self.history_images = self.history_images[-self.history_n :]
|
||||||
|
|
||||||
messages, images = [], []
|
messages: list[dict] = []
|
||||||
|
images: list[Any] = []
|
||||||
if isinstance(self.history_images, bytes):
|
if isinstance(self.history_images, bytes):
|
||||||
self.history_images = [self.history_images]
|
self.history_images = [self.history_images]
|
||||||
elif isinstance(self.history_images, np.ndarray):
|
elif isinstance(self.history_images, np.ndarray):
|
||||||
@@ -414,11 +419,11 @@ class GroundingAgentUitars:
|
|||||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||||
return round(number / factor) * factor
|
return round(number / factor) * factor
|
||||||
|
|
||||||
def ceil_by_factor(self, number: int, factor: int) -> int:
|
def ceil_by_factor(self, number: float, factor: int) -> int:
|
||||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||||
return math.ceil(number / factor) * factor
|
return math.ceil(number / factor) * factor
|
||||||
|
|
||||||
def floor_by_factor(self, number: int, factor: int) -> int:
|
def floor_by_factor(self, number: float, factor: int) -> int:
|
||||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||||
return math.floor(number / factor) * factor
|
return math.floor(number / factor) * factor
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional, cast
|
||||||
|
|
||||||
from anthropic.types.beta import BetaContentBlock
|
from anthropic.types.beta import BetaContentBlock
|
||||||
|
|
||||||
@@ -76,7 +76,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
thinking = {"type": "disabled"}
|
thinking: dict[str, str | int] = {"type": "disabled"}
|
||||||
if self.vision_model.name.startswith("claude-3-7"):
|
if self.vision_model.name.startswith("claude-3-7"):
|
||||||
thinking = {"type": "enabled", "budget_tokens": 1024}
|
thinking = {"type": "enabled", "budget_tokens": 1024}
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
|
|
||||||
logger.debug(f"Anthropic response: {response.model_dump_json()}")
|
logger.debug(f"Anthropic response: {response.model_dump_json()}")
|
||||||
self.messages.append(AgentMessage(role="assistant", content=response.content))
|
self.messages.append(AgentMessage(role="assistant", content=response.content))
|
||||||
rendered_response = await self._render_response(response.content, current_state.screenshot)
|
rendered_response = self._render_response(response.content, current_state.screenshot)
|
||||||
|
|
||||||
for block in response.content:
|
for block in response.content:
|
||||||
if block.type == "tool_use":
|
if block.type == "tool_use":
|
||||||
@@ -140,7 +140,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
elif tool_name == "left_mouse_up":
|
elif tool_name == "left_mouse_up":
|
||||||
action_to_run = MouseUpAction(button="left")
|
action_to_run = MouseUpAction(button="left")
|
||||||
elif tool_name == "type":
|
elif tool_name == "type":
|
||||||
text = tool_input.get("text")
|
text: str = tool_input.get("text")
|
||||||
if text:
|
if text:
|
||||||
action_to_run = TypeAction(text=text)
|
action_to_run = TypeAction(text=text)
|
||||||
elif tool_name == "scroll":
|
elif tool_name == "scroll":
|
||||||
@@ -152,7 +152,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
if direction:
|
if direction:
|
||||||
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y)
|
action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y)
|
||||||
elif tool_name == "key":
|
elif tool_name == "key":
|
||||||
text: str = tool_input.get("text")
|
text = tool_input.get("text")
|
||||||
if text:
|
if text:
|
||||||
action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style
|
action_to_run = KeypressAction(keys=text.split("+")) # Split xdotool style
|
||||||
elif tool_name == "hold_key":
|
elif tool_name == "hold_key":
|
||||||
@@ -214,7 +214,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
for idx, env_step in enumerate(env_steps):
|
for idx, env_step in enumerate(env_steps):
|
||||||
action_result = agent_action.action_results[idx]
|
action_result = agent_action.action_results[idx]
|
||||||
result_content = env_step.error or env_step.output or "[Action completed]"
|
result_content = env_step.error or env_step.output or "[Action completed]"
|
||||||
if env_step.type == "image":
|
if env_step.type == "image" and isinstance(result_content, dict):
|
||||||
# Add screenshot data in anthropic message format
|
# Add screenshot data in anthropic message format
|
||||||
action_result["content"] = [
|
action_result["content"] = [
|
||||||
{
|
{
|
||||||
@@ -262,12 +262,20 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
)
|
)
|
||||||
return formatted_messages
|
return formatted_messages
|
||||||
|
|
||||||
def compile_response(self, response_content: list[BetaContentBlock | Any]) -> str:
|
def compile_response(self, response_content: list[BetaContentBlock | dict] | str) -> str:
|
||||||
"""Compile Anthropic response into a single string."""
|
"""Compile Anthropic response into a single string."""
|
||||||
if is_none_or_empty(response_content) or not all(hasattr(item, "type") for item in response_content):
|
if isinstance(response_content, str):
|
||||||
return response_content
|
return response_content
|
||||||
|
elif is_none_or_empty(response_content):
|
||||||
|
return ""
|
||||||
|
# action results are a list of dictionaries,
|
||||||
|
# beta content blocks are objects with a type attribute
|
||||||
|
elif isinstance(response_content[0], dict):
|
||||||
|
return json.dumps(response_content)
|
||||||
|
|
||||||
compiled_response = [""]
|
compiled_response = [""]
|
||||||
for block in deepcopy(response_content):
|
for block in deepcopy(response_content):
|
||||||
|
block = cast(BetaContentBlock, block) # Ensure block is of type BetaContentBlock
|
||||||
if block.type == "text":
|
if block.type == "text":
|
||||||
compiled_response.append(block.text)
|
compiled_response.append(block.text)
|
||||||
elif block.type == "tool_use":
|
elif block.type == "tool_use":
|
||||||
@@ -291,8 +299,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
|
|
||||||
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
|
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
|
||||||
|
|
||||||
@staticmethod
|
def _render_response(self, response_content: list[BetaContentBlock], screenshot: str | None) -> dict:
|
||||||
async def _render_response(response_content: list[BetaContentBlock], screenshot: Optional[str] = None) -> dict:
|
|
||||||
"""Render Anthropic response, potentially including actual screenshots."""
|
"""Render Anthropic response, potentially including actual screenshots."""
|
||||||
render_texts = []
|
render_texts = []
|
||||||
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
|
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
|
||||||
@@ -315,11 +322,11 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
elif "action" in block_input:
|
elif "action" in block_input:
|
||||||
action = block_input["action"]
|
action = block_input["action"]
|
||||||
if action == "type":
|
if action == "type":
|
||||||
text = block_input.get("text")
|
text: str = block_input.get("text")
|
||||||
if text:
|
if text:
|
||||||
render_texts += [f'Type "{text}"']
|
render_texts += [f'Type "{text}"']
|
||||||
elif action == "key":
|
elif action == "key":
|
||||||
text: str = block_input.get("text")
|
text = block_input.get("text")
|
||||||
if text:
|
if text:
|
||||||
render_texts += [f"Press {text}"]
|
render_texts += [f"Press {text}"]
|
||||||
elif action == "hold_key":
|
elif action == "hold_key":
|
||||||
|
|||||||
@@ -50,11 +50,11 @@ class OperatorAgent(ABC):
|
|||||||
return self.compile_response(self.messages[-1].content)
|
return self.compile_response(self.messages[-1].content)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compile_response(self, response: List) -> str:
|
def compile_response(self, response: List | str) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _render_response(self, response: List, screenshot: Optional[str]) -> Optional[str]:
|
def _render_response(self, response: List, screenshot: Optional[str]) -> dict:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -49,12 +49,14 @@ class BinaryOperatorAgent(OperatorAgent):
|
|||||||
grounding_client = get_openai_async_client(
|
grounding_client = get_openai_async_client(
|
||||||
grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url
|
grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.grounding_agent: GroundingAgent | GroundingAgentUitars = None
|
||||||
if "ui-tars-1.5" in grounding_model.name:
|
if "ui-tars-1.5" in grounding_model.name:
|
||||||
self.grounding_agent = GroundingAgentUitars(
|
self.grounding_agent = GroundingAgentUitars(
|
||||||
grounding_model.name, grounding_client, environment_type="web", tracer=tracer
|
grounding_model.name, grounding_client, max_iterations, environment_type="web", tracer=tracer
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, tracer=tracer)
|
self.grounding_agent = GroundingAgent(grounding_model.name, grounding_client, max_iterations, tracer=tracer)
|
||||||
|
|
||||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||||
"""
|
"""
|
||||||
@@ -143,7 +145,7 @@ Focus on the visual action and provide all necessary context.
|
|||||||
query_screenshot = self._get_message_images(current_message)
|
query_screenshot = self._get_message_images(current_message)
|
||||||
|
|
||||||
# Construct input for visual reasoner history
|
# Construct input for visual reasoner history
|
||||||
visual_reasoner_history = self._format_message_for_api(self.messages)
|
visual_reasoner_history = {"chat": self._format_message_for_api(self.messages)}
|
||||||
try:
|
try:
|
||||||
natural_language_action = await send_message_to_model_wrapper(
|
natural_language_action = await send_message_to_model_wrapper(
|
||||||
query=query_text,
|
query=query_text,
|
||||||
@@ -153,6 +155,10 @@ Focus on the visual action and provide all necessary context.
|
|||||||
agent_chat_model=self.reasoning_model,
|
agent_chat_model=self.reasoning_model,
|
||||||
tracer=self.tracer,
|
tracer=self.tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not isinstance(natural_language_action, str) or not natural_language_action.strip():
|
||||||
|
raise ValueError(f"Natural language action is empty or not a string. Got {natural_language_action}")
|
||||||
|
|
||||||
self.messages.append(current_message)
|
self.messages.append(current_message)
|
||||||
self.messages.append(AgentMessage(role="assistant", content=natural_language_action))
|
self.messages.append(AgentMessage(role="assistant", content=natural_language_action))
|
||||||
|
|
||||||
@@ -254,7 +260,7 @@ Focus on the visual action and provide all necessary context.
|
|||||||
self.messages.append(AgentMessage(role="environment", content=action_results_content))
|
self.messages.append(AgentMessage(role="environment", content=action_results_content))
|
||||||
|
|
||||||
async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str:
|
async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str:
|
||||||
conversation_history = self._format_message_for_api(self.messages)
|
conversation_history = {"chat": self._format_message_for_api(self.messages)}
|
||||||
try:
|
try:
|
||||||
summary = await send_message_to_model_wrapper(
|
summary = await send_message_to_model_wrapper(
|
||||||
query=summarize_prompt,
|
query=summarize_prompt,
|
||||||
@@ -276,25 +282,11 @@ Focus on the visual action and provide all necessary context.
|
|||||||
|
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
def compile_response(self, response_content: Union[str, List, dict]) -> str:
|
def compile_response(self, response_content: str | List) -> str:
|
||||||
"""Compile response content into a string, handling OpenAI message structures."""
|
"""Compile response content into a string, handling OpenAI message structures."""
|
||||||
if isinstance(response_content, str):
|
if isinstance(response_content, str):
|
||||||
return response_content
|
return response_content
|
||||||
|
|
||||||
if isinstance(response_content, dict) and response_content.get("role") == "assistant":
|
|
||||||
# Grounding LLM response message (might contain tool calls)
|
|
||||||
text_content = response_content.get("content")
|
|
||||||
tool_calls = response_content.get("tool_calls")
|
|
||||||
compiled = []
|
|
||||||
if text_content:
|
|
||||||
compiled.append(text_content)
|
|
||||||
if tool_calls:
|
|
||||||
for tc in tool_calls:
|
|
||||||
compiled.append(
|
|
||||||
f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}"
|
|
||||||
)
|
|
||||||
return "\n- ".join(filter(None, compiled))
|
|
||||||
|
|
||||||
if isinstance(response_content, list): # Tool results list
|
if isinstance(response_content, list): # Tool results list
|
||||||
compiled = ["**Tool Results**:"]
|
compiled = ["**Tool Results**:"]
|
||||||
for item in response_content:
|
for item in response_content:
|
||||||
@@ -336,7 +328,7 @@ Focus on the visual action and provide all necessary context.
|
|||||||
}
|
}
|
||||||
for message in messages
|
for message in messages
|
||||||
]
|
]
|
||||||
return {"chat": formatted_messages}
|
return formatted_messages
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset the agent state."""
|
"""Reset the agent state."""
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional, cast
|
||||||
|
|
||||||
from openai.types.responses import Response, ResponseOutputItem
|
from openai.types.responses import Response, ResponseOutputItem
|
||||||
|
|
||||||
@@ -95,7 +95,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
|
|
||||||
logger.debug(f"Openai response: {response.model_dump_json()}")
|
logger.debug(f"Openai response: {response.model_dump_json()}")
|
||||||
self.messages += [AgentMessage(role="environment", content=response.output)]
|
self.messages += [AgentMessage(role="environment", content=response.output)]
|
||||||
rendered_response = await self._render_response(response.output, current_state.screenshot)
|
rendered_response = self._render_response(response.output, current_state.screenshot)
|
||||||
|
|
||||||
last_call_id = None
|
last_call_id = None
|
||||||
content = None
|
content = None
|
||||||
@@ -193,7 +193,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
for idx, env_step in enumerate(env_steps):
|
for idx, env_step in enumerate(env_steps):
|
||||||
action_result = agent_action.action_results[idx]
|
action_result = agent_action.action_results[idx]
|
||||||
result_content = env_step.error or env_step.output or "[Action completed]"
|
result_content = env_step.error or env_step.output or "[Action completed]"
|
||||||
if env_step.type == "image":
|
if env_step.type == "image" and isinstance(result_content, dict):
|
||||||
# Add screenshot data in openai message format
|
# Add screenshot data in openai message format
|
||||||
action_result["output"] = {
|
action_result["output"] = {
|
||||||
"type": "input_image",
|
"type": "input_image",
|
||||||
@@ -215,10 +215,13 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
|
|
||||||
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
|
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
|
||||||
"""Format the message for OpenAI API."""
|
"""Format the message for OpenAI API."""
|
||||||
formatted_messages = []
|
formatted_messages: list = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.role == "environment":
|
if message.role == "environment":
|
||||||
formatted_messages.extend(message.content)
|
if isinstance(message.content, list):
|
||||||
|
formatted_messages.extend(message.content)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Expected message content list from environment, got {type(message.content)}")
|
||||||
else:
|
else:
|
||||||
formatted_messages.append(
|
formatted_messages.append(
|
||||||
{
|
{
|
||||||
@@ -228,13 +231,14 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
)
|
)
|
||||||
return formatted_messages
|
return formatted_messages
|
||||||
|
|
||||||
@staticmethod
|
def compile_response(self, response_content: str | list[dict | ResponseOutputItem]) -> str:
|
||||||
def compile_response(response_content: str | list[dict | ResponseOutputItem]) -> str:
|
|
||||||
"""Compile the response from model into a single string."""
|
"""Compile the response from model into a single string."""
|
||||||
# Handle case where response content is a string.
|
# Handle case where response content is a string.
|
||||||
# This is the case when response content is a user query
|
# This is the case when response content is a user query
|
||||||
if is_none_or_empty(response_content) or isinstance(response_content, str):
|
if isinstance(response_content, str):
|
||||||
return response_content
|
return response_content
|
||||||
|
elif is_none_or_empty(response_content):
|
||||||
|
return ""
|
||||||
# Handle case where response_content is a dictionary and not ResponseOutputItem
|
# Handle case where response_content is a dictionary and not ResponseOutputItem
|
||||||
# This is the case when response_content contains action results
|
# This is the case when response_content contains action results
|
||||||
if not hasattr(response_content[0], "type"):
|
if not hasattr(response_content[0], "type"):
|
||||||
@@ -242,6 +246,8 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
|
|
||||||
compiled_response = [""]
|
compiled_response = [""]
|
||||||
for block in deepcopy(response_content):
|
for block in deepcopy(response_content):
|
||||||
|
block = cast(ResponseOutputItem, block) # Ensure block is of type ResponseOutputItem
|
||||||
|
# Handle different block types
|
||||||
if block.type == "message":
|
if block.type == "message":
|
||||||
# Extract text content if available
|
# Extract text content if available
|
||||||
for content in block.content:
|
for content in block.content:
|
||||||
@@ -254,30 +260,29 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
text_content += content.model_dump_json()
|
text_content += content.model_dump_json()
|
||||||
compiled_response.append(text_content)
|
compiled_response.append(text_content)
|
||||||
elif block.type == "function_call":
|
elif block.type == "function_call":
|
||||||
block_input = {"action": block.name}
|
block_function_input = {"action": block.name}
|
||||||
if block.name == "goto":
|
if block.name == "goto":
|
||||||
try:
|
try:
|
||||||
args = json.loads(block.arguments)
|
args = json.loads(block.arguments)
|
||||||
block_input["url"] = args.get("url", "[Missing URL]")
|
block_function_input["url"] = args.get("url", "[Missing URL]")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
block_input["arguments"] = block.arguments # Show raw args on error
|
block_function_input["arguments"] = block.arguments # Show raw args on error
|
||||||
compiled_response.append(f"**Action**: {json.dumps(block_input)}")
|
compiled_response.append(f"**Action**: {json.dumps(block_function_input)}")
|
||||||
elif block.type == "computer_call":
|
elif block.type == "computer_call":
|
||||||
block_input = block.action
|
block_computer_input = block.action
|
||||||
# If it's a screenshot action
|
# If it's a screenshot action
|
||||||
if block_input.type == "screenshot":
|
if block_computer_input.type == "screenshot":
|
||||||
# Use a placeholder for screenshot data
|
# Use a placeholder for screenshot data
|
||||||
block_input_render = block_input.model_dump()
|
block_input_render = block_computer_input.model_dump()
|
||||||
block_input_render["image"] = "[placeholder for screenshot]"
|
block_input_render["image"] = "[placeholder for screenshot]"
|
||||||
compiled_response.append(f"**Action**: {json.dumps(block_input_render)}")
|
compiled_response.append(f"**Action**: {json.dumps(block_input_render)}")
|
||||||
else:
|
else:
|
||||||
compiled_response.append(f"**Action**: {block_input.model_dump_json()}")
|
compiled_response.append(f"**Action**: {block_computer_input.model_dump_json()}")
|
||||||
elif block.type == "reasoning" and block.summary:
|
elif block.type == "reasoning" and block.summary:
|
||||||
compiled_response.append(f"**Thought**: {block.summary}")
|
compiled_response.append(f"**Thought**: {block.summary}")
|
||||||
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
|
return "\n- ".join(filter(None, compiled_response)) # Filter out empty strings
|
||||||
|
|
||||||
@staticmethod
|
def _render_response(self, response_content: list[ResponseOutputItem], screenshot: str | None) -> dict:
|
||||||
async def _render_response(response_content: list[ResponseOutputItem], screenshot: Optional[str] = None) -> dict:
|
|
||||||
"""Render OpenAI response for display, potentially including screenshots."""
|
"""Render OpenAI response for display, potentially including screenshots."""
|
||||||
render_texts = []
|
render_texts = []
|
||||||
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
|
for block in deepcopy(response_content): # Use deepcopy to avoid modifying original
|
||||||
@@ -285,7 +290,6 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
text_content = block.text if hasattr(block, "text") else block.model_dump_json()
|
text_content = block.text if hasattr(block, "text") else block.model_dump_json()
|
||||||
render_texts += [text_content]
|
render_texts += [text_content]
|
||||||
elif block.type == "function_call":
|
elif block.type == "function_call":
|
||||||
block_input = {"action": block.name}
|
|
||||||
if block.name == "goto":
|
if block.name == "goto":
|
||||||
args = json.loads(block.arguments)
|
args = json.loads(block.arguments)
|
||||||
render_texts = [f'Open URL: {args.get("url", "[Missing URL]")}']
|
render_texts = [f'Open URL: {args.get("url", "[Missing URL]")}']
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import base64
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Set
|
from typing import Optional, Set, Union
|
||||||
|
|
||||||
from khoj.processor.operator.operator_actions import OperatorAction, Point
|
from khoj.processor.operator.operator_actions import OperatorAction, Point
|
||||||
from khoj.processor.operator.operator_environment_base import (
|
from khoj.processor.operator.operator_environment_base import (
|
||||||
@@ -124,7 +124,7 @@ class BrowserEnvironment(Environment):
|
|||||||
|
|
||||||
async def get_state(self) -> EnvState:
|
async def get_state(self) -> EnvState:
|
||||||
if not self.page or self.page.is_closed():
|
if not self.page or self.page.is_closed():
|
||||||
return "about:blank", None
|
return EnvState(url="about:blank", screenshot=None)
|
||||||
url = self.page.url
|
url = self.page.url
|
||||||
screenshot = await self._get_screenshot()
|
screenshot = await self._get_screenshot()
|
||||||
return EnvState(url=url, screenshot=screenshot)
|
return EnvState(url=url, screenshot=screenshot)
|
||||||
@@ -134,7 +134,9 @@ class BrowserEnvironment(Environment):
|
|||||||
return EnvStepResult(error="Browser page is not available or closed.")
|
return EnvStepResult(error="Browser page is not available or closed.")
|
||||||
|
|
||||||
before_state = await self.get_state()
|
before_state = await self.get_state()
|
||||||
output, error, step_type = None, None, "text"
|
output: Optional[Union[str, dict]] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
step_type: str = "text"
|
||||||
try:
|
try:
|
||||||
match action.type:
|
match action.type:
|
||||||
case "click":
|
case "click":
|
||||||
@@ -180,16 +182,16 @@ class BrowserEnvironment(Environment):
|
|||||||
logger.debug(f"Action: {action.type} by ({scroll_x},{scroll_y}) at ({action.x},{action.y})")
|
logger.debug(f"Action: {action.type} by ({scroll_x},{scroll_y}) at ({action.x},{action.y})")
|
||||||
# Otherwise use direction/amount (from Anthropic style)
|
# Otherwise use direction/amount (from Anthropic style)
|
||||||
elif action.scroll_direction:
|
elif action.scroll_direction:
|
||||||
dx, dy = 0, 0
|
dx, dy = 0.0, 0.0
|
||||||
amount = action.scroll_amount or 1
|
amount = action.scroll_amount or 1
|
||||||
if action.scroll_direction == "up":
|
if action.scroll_direction == "up":
|
||||||
dy = -100 * amount
|
dy = -100.0 * amount
|
||||||
elif action.scroll_direction == "down":
|
elif action.scroll_direction == "down":
|
||||||
dy = 100 * amount
|
dy = 100.0 * amount
|
||||||
elif action.scroll_direction == "left":
|
elif action.scroll_direction == "left":
|
||||||
dx = -100 * amount
|
dx = -100.0 * amount
|
||||||
elif action.scroll_direction == "right":
|
elif action.scroll_direction == "right":
|
||||||
dx = 100 * amount
|
dx = 100.0 * amount
|
||||||
|
|
||||||
if action.x is not None and action.y is not None:
|
if action.x is not None and action.y is not None:
|
||||||
await self.page.mouse.move(action.x, action.y)
|
await self.page.mouse.move(action.x, action.y)
|
||||||
|
|||||||
@@ -1354,7 +1354,7 @@ async def agenerate_chat_response(
|
|||||||
compiled_references: List[Dict] = [],
|
compiled_references: List[Dict] = [],
|
||||||
online_results: Dict[str, Dict] = {},
|
online_results: Dict[str, Dict] = {},
|
||||||
code_results: Dict[str, Dict] = {},
|
code_results: Dict[str, Dict] = {},
|
||||||
operator_results: Dict[str, Dict] = {},
|
operator_results: List[str] = [],
|
||||||
inferred_queries: List[str] = [],
|
inferred_queries: List[str] = [],
|
||||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
@@ -1411,7 +1411,7 @@ async def agenerate_chat_response(
|
|||||||
compiled_references = []
|
compiled_references = []
|
||||||
online_results = {}
|
online_results = {}
|
||||||
code_results = {}
|
code_results = {}
|
||||||
operator_results = {}
|
operator_results = []
|
||||||
deepthought = True
|
deepthought = True
|
||||||
|
|
||||||
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
|
chat_model = await ConversationAdapters.aget_valid_chat_model(user, conversation, is_subscribed)
|
||||||
|
|||||||
Reference in New Issue
Block a user