mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Set operator query on init. Pass summarize prompt to summarize func
The initial user query isn't updated during an operator run. So set it when initializing the operator agent. Instead of passing it on every call to act. Pass summarize prompt directly to the summarize function. Let it construct the summarize message to query vision model with. Previously it was being passed to the add_action_results func as previous implementation that did not use a separate summarize func. Also rename chat_model to vision_model for a more pertinent var name. These changes make the code cleaner and implementation more readable.
This commit is contained in:
@@ -47,10 +47,10 @@ async def operate_browser(
|
|||||||
# Initialize Agent
|
# Initialize Agent
|
||||||
max_iterations = 40 # TODO: Configurable?
|
max_iterations = 40 # TODO: Configurable?
|
||||||
operator_agent: OperatorAgent
|
operator_agent: OperatorAgent
|
||||||
if chat_model.name.startswith("gpt-"):
|
if reasoning_model.name.startswith("gpt-"):
|
||||||
operator_agent = OpenAIOperatorAgent(chat_model, max_iterations, tracer)
|
operator_agent = OpenAIOperatorAgent(query, reasoning_model, max_iterations, tracer)
|
||||||
elif chat_model.name.startswith("claude-"):
|
elif reasoning_model.name.startswith("claude-"):
|
||||||
operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer)
|
operator_agent = AnthropicOperatorAgent(query, reasoning_model, max_iterations, tracer)
|
||||||
else:
|
else:
|
||||||
grounding_model_name = "ui-tars-1.5-7b"
|
grounding_model_name = "ui-tars-1.5-7b"
|
||||||
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
|
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
|
||||||
@@ -60,7 +60,7 @@ async def operate_browser(
|
|||||||
or grounding_model.model_type != ChatModel.ModelType.OPENAI
|
or grounding_model.model_type != ChatModel.ModelType.OPENAI
|
||||||
):
|
):
|
||||||
raise ValueError("No supported visual grounding model for binary operator agent found.")
|
raise ValueError("No supported visual grounding model for binary operator agent found.")
|
||||||
operator_agent = BinaryOperatorAgent(reasoning_model, grounding_model, max_iterations, tracer)
|
operator_agent = BinaryOperatorAgent(query, reasoning_model, grounding_model, max_iterations, tracer)
|
||||||
|
|
||||||
# Initialize Environment
|
# Initialize Environment
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
@@ -87,7 +87,7 @@ async def operate_browser(
|
|||||||
browser_state = await environment.get_state()
|
browser_state = await environment.get_state()
|
||||||
|
|
||||||
# 2. Agent decides action(s)
|
# 2. Agent decides action(s)
|
||||||
agent_result = await operator_agent.act(query, browser_state)
|
agent_result = await operator_agent.act(browser_state)
|
||||||
|
|
||||||
# Render status update
|
# Render status update
|
||||||
rendered_response = agent_result.rendered_response
|
rendered_response = agent_result.rendered_response
|
||||||
@@ -118,8 +118,8 @@ async def operate_browser(
|
|||||||
break
|
break
|
||||||
if task_completed or trigger_iteration_limit:
|
if task_completed or trigger_iteration_limit:
|
||||||
# Summarize results of operator run on last iteration
|
# Summarize results of operator run on last iteration
|
||||||
operator_agent.add_action_results(env_steps, agent_result, summarize_prompt)
|
operator_agent.add_action_results(env_steps, agent_result)
|
||||||
summary_message = await operator_agent.summarize(query, browser_state)
|
summary_message = await operator_agent.summarize(summarize_prompt, browser_state)
|
||||||
logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
|
logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# --- Anthropic Operator Agent ---
|
# --- Anthropic Operator Agent ---
|
||||||
class AnthropicOperatorAgent(OperatorAgent):
|
class AnthropicOperatorAgent(OperatorAgent):
|
||||||
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||||
client = get_anthropic_async_client(
|
client = get_anthropic_async_client(
|
||||||
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
|
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
|
||||||
)
|
)
|
||||||
tool_version = "2025-01-24"
|
tool_version = "2025-01-24"
|
||||||
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
|
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
|
||||||
@@ -51,7 +51,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
</IMPORTANT>
|
</IMPORTANT>
|
||||||
"""
|
"""
|
||||||
if is_none_or_empty(self.messages):
|
if is_none_or_empty(self.messages):
|
||||||
self.messages = [AgentMessage(role="user", content=query)]
|
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
@@ -77,13 +77,13 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
]
|
]
|
||||||
|
|
||||||
thinking = {"type": "disabled"}
|
thinking = {"type": "disabled"}
|
||||||
if self.chat_model.name.startswith("claude-3-7"):
|
if self.vision_model.name.startswith("claude-3-7"):
|
||||||
thinking = {"type": "enabled", "budget_tokens": 1024}
|
thinking = {"type": "enabled", "budget_tokens": 1024}
|
||||||
|
|
||||||
messages_for_api = self._format_message_for_api(self.messages)
|
messages_for_api = self._format_message_for_api(self.messages)
|
||||||
response = await client.beta.messages.create(
|
response = await client.beta.messages.create(
|
||||||
messages=messages_for_api,
|
messages=messages_for_api,
|
||||||
model=self.chat_model.name,
|
model=self.vision_model.name,
|
||||||
system=system_prompt,
|
system=system_prompt,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
betas=betas,
|
betas=betas,
|
||||||
@@ -187,8 +187,8 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
{
|
{
|
||||||
"type": "tool_result",
|
"type": "tool_result",
|
||||||
"tool_use_id": tool_use_id,
|
"tool_use_id": tool_use_id,
|
||||||
"content": None, # Updated by environment step
|
"content": None, # Updated after environment step
|
||||||
"is_error": False, # Updated by environment step
|
"is_error": False, # Updated after environment step
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -206,13 +206,9 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
rendered_response=rendered_response,
|
rendered_response=rendered_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_action_results(
|
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult):
|
||||||
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
if not agent_action.action_results:
|
||||||
):
|
|
||||||
if not agent_action.action_results and not summarize_prompt:
|
|
||||||
return
|
return
|
||||||
elif not agent_action.action_results:
|
|
||||||
agent_action.action_results = []
|
|
||||||
|
|
||||||
# Update action results with results of applying suggested actions on the environment
|
# Update action results with results of applying suggested actions on the environment
|
||||||
for idx, env_step in enumerate(env_steps):
|
for idx, env_step in enumerate(env_steps):
|
||||||
@@ -236,10 +232,6 @@ class AnthropicOperatorAgent(OperatorAgent):
|
|||||||
if env_step.error:
|
if env_step.error:
|
||||||
action_result["is_error"] = True
|
action_result["is_error"] = True
|
||||||
|
|
||||||
# If summarize prompt provided, append as text within the tool results user message
|
|
||||||
if summarize_prompt:
|
|
||||||
agent_action.action_results.append({"type": "text", "text": summarize_prompt})
|
|
||||||
|
|
||||||
# Append tool results to the message history
|
# Append tool results to the message history
|
||||||
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
||||||
|
|
||||||
|
|||||||
@@ -25,26 +25,26 @@ class AgentMessage(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class OperatorAgent(ABC):
|
class OperatorAgent(ABC):
|
||||||
def __init__(self, chat_model: ChatModel, max_iterations: int, tracer: dict):
|
def __init__(self, query: str, vision_model: ChatModel, max_iterations: int, tracer: dict):
|
||||||
self.chat_model = chat_model
|
self.query = query
|
||||||
|
self.vision_model = vision_model
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
self.tracer = tracer
|
self.tracer = tracer
|
||||||
self.messages: List[AgentMessage] = []
|
self.messages: List[AgentMessage] = []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_action_results(
|
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
|
||||||
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
|
||||||
) -> None:
|
|
||||||
"""Track results of agent actions on the environment."""
|
"""Track results of agent actions on the environment."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def summarize(self, query: str, current_state: EnvState) -> str:
|
async def summarize(self, summarize_prompt: str, current_state: EnvState) -> str:
|
||||||
"""Summarize the agent's actions and results."""
|
"""Summarize the agent's actions and results."""
|
||||||
await self.act(query, current_state)
|
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
|
||||||
|
await self.act(current_state)
|
||||||
if not self.messages:
|
if not self.messages:
|
||||||
return "No actions to summarize."
|
return "No actions to summarize."
|
||||||
return self.compile_response(self.messages[-1].content)
|
return self.compile_response(self.messages[-1].content)
|
||||||
@@ -63,12 +63,12 @@ class OperatorAgent(ABC):
|
|||||||
|
|
||||||
def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0):
|
def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0):
|
||||||
self.tracer["usage"] = get_chat_usage_metrics(
|
self.tracer["usage"] = get_chat_usage_metrics(
|
||||||
self.chat_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage")
|
self.vision_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage")
|
||||||
)
|
)
|
||||||
logger.debug(f"Operator usage by {self.chat_model.model_type}: {self.tracer['usage']}")
|
logger.debug(f"Operator usage by {self.vision_model.model_type}: {self.tracer['usage']}")
|
||||||
|
|
||||||
def _commit_trace(self):
|
def _commit_trace(self):
|
||||||
self.tracer["chat_model"] = self.chat_model.name
|
self.tracer["chat_model"] = self.vision_model.name
|
||||||
if is_promptrace_enabled() and len(self.messages) > 1:
|
if is_promptrace_enabled() and len(self.messages) > 1:
|
||||||
compiled_messages = [
|
compiled_messages = [
|
||||||
AgentMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages
|
AgentMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import logging
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
from openai.types.chat import ChatCompletion
|
from openai.types.chat import ChatCompletion
|
||||||
|
|
||||||
from khoj.database.models import ChatModel
|
from khoj.database.models import ChatModel
|
||||||
@@ -18,7 +17,6 @@ from khoj.processor.operator.operator_environment_base import EnvState, EnvStepR
|
|||||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
convert_image_to_png,
|
convert_image_to_png,
|
||||||
get_chat_usage_metrics,
|
|
||||||
get_openai_async_client,
|
get_openai_async_client,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
)
|
)
|
||||||
@@ -29,34 +27,35 @@ logger = logging.getLogger(__name__)
|
|||||||
# --- Binary Operator Agent ---
|
# --- Binary Operator Agent ---
|
||||||
class BinaryOperatorAgent(OperatorAgent):
|
class BinaryOperatorAgent(OperatorAgent):
|
||||||
"""
|
"""
|
||||||
An OperatorAgent that uses two LLMs (OpenAI compatible):
|
An OperatorAgent that uses two LLMs:
|
||||||
1. Vision LLM: Determines the next high-level action based on the visual state.
|
1. Reasoning LLM: Determines the next high-level action based on the objective and current visual reasoning trajectory.
|
||||||
2. Grounding LLM: Converts the high-level action into specific, executable browser actions.
|
2. Grounding LLM: Converts the high-level action into specific, executable browser actions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vision_chat_model: ChatModel,
|
query: str,
|
||||||
grounding_chat_model: ChatModel, # Assuming a second model is provided/configured
|
reasoning_model: ChatModel,
|
||||||
|
grounding_model: ChatModel,
|
||||||
max_iterations: int,
|
max_iterations: int,
|
||||||
tracer: dict,
|
tracer: dict,
|
||||||
):
|
):
|
||||||
super().__init__(vision_chat_model, max_iterations, tracer) # Use vision model for primary tracking
|
super().__init__(query, reasoning_model, max_iterations, tracer) # Use reasoning model for primary tracking
|
||||||
self.vision_chat_model = vision_chat_model
|
self.reasoning_model = reasoning_model
|
||||||
self.grounding_chat_model = grounding_chat_model
|
self.grounding_model = grounding_model
|
||||||
# Initialize OpenAI clients
|
# Initialize openai api compatible client for grounding model
|
||||||
self.grounding_client: AsyncOpenAI = get_openai_async_client(
|
self.grounding_client = get_openai_async_client(
|
||||||
grounding_chat_model.ai_model_api.api_key, grounding_chat_model.ai_model_api.api_base_url
|
grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url
|
||||||
)
|
)
|
||||||
|
|
||||||
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||||
"""
|
"""
|
||||||
Uses a two-step LLM process to determine and structure the next action.
|
Uses a two-step LLM process to determine and structure the next action.
|
||||||
"""
|
"""
|
||||||
self._commit_trace() # Commit trace before next action
|
self._commit_trace() # Commit trace before next action
|
||||||
|
|
||||||
# --- Step 1: Reasoning LLM determines high-level action ---
|
# --- Step 1: Reasoning LLM determines high-level action ---
|
||||||
reasoner_response = await self.act_reason(query, current_state)
|
reasoner_response = await self.act_reason(current_state)
|
||||||
natural_language_action = reasoner_response["message"]
|
natural_language_action = reasoner_response["message"]
|
||||||
if reasoner_response["type"] == "error":
|
if reasoner_response["type"] == "error":
|
||||||
logger.error(natural_language_action)
|
logger.error(natural_language_action)
|
||||||
@@ -75,7 +74,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
|||||||
# --- Step 2: Grounding LLM converts NL action to structured action ---
|
# --- Step 2: Grounding LLM converts NL action to structured action ---
|
||||||
return await self.act_ground(natural_language_action, current_state)
|
return await self.act_ground(natural_language_action, current_state)
|
||||||
|
|
||||||
async def act_reason(self, query: str, current_state: EnvState) -> dict[str, str]:
|
async def act_reason(self, current_state: EnvState) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Uses the reasoning LLM to determine the next high-level action based on the operation trajectory.
|
Uses the reasoning LLM to determine the next high-level action based on the operation trajectory.
|
||||||
"""
|
"""
|
||||||
@@ -118,12 +117,12 @@ Focus on the visual action and provide all necessary context.
|
|||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
if is_none_or_empty(self.messages):
|
if is_none_or_empty(self.messages):
|
||||||
query_text = f"**Main Objective**: {query}"
|
query_text = f"**Main Objective**: {self.query}"
|
||||||
query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"]
|
query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"]
|
||||||
first_message_content = construct_structured_message(
|
first_message_content = construct_structured_message(
|
||||||
message=query_text,
|
message=query_text,
|
||||||
images=query_screenshot,
|
images=query_screenshot,
|
||||||
model_type=self.vision_chat_model.model_type,
|
model_type=self.reasoning_model.model_type,
|
||||||
vision_enabled=True,
|
vision_enabled=True,
|
||||||
)
|
)
|
||||||
current_message = AgentMessage(role="user", content=first_message_content)
|
current_message = AgentMessage(role="user", content=first_message_content)
|
||||||
@@ -140,7 +139,7 @@ Focus on the visual action and provide all necessary context.
|
|||||||
query_images=query_screenshot,
|
query_images=query_screenshot,
|
||||||
system_message=reasoning_system_prompt,
|
system_message=reasoning_system_prompt,
|
||||||
conversation_log=visual_reasoner_history,
|
conversation_log=visual_reasoner_history,
|
||||||
agent_chat_model=self.vision_chat_model,
|
agent_chat_model=self.reasoning_model,
|
||||||
tracer=self.tracer,
|
tracer=self.tracer,
|
||||||
)
|
)
|
||||||
self.messages.append(current_message)
|
self.messages.append(current_message)
|
||||||
@@ -371,7 +370,7 @@ back() # Use this to go back to the previous page.
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
grounding_response: ChatCompletion = await self.grounding_client.chat.completions.create(
|
grounding_response: ChatCompletion = await self.grounding_client.chat.completions.create(
|
||||||
model=self.grounding_chat_model.name,
|
model=self.grounding_model.name,
|
||||||
messages=grounding_messages_for_api,
|
messages=grounding_messages_for_api,
|
||||||
tools=grounding_tools,
|
tools=grounding_tools,
|
||||||
tool_choice="required",
|
tool_choice="required",
|
||||||
@@ -465,14 +464,12 @@ back() # Use this to go back to the previous page.
|
|||||||
rendered_response=rendered_response,
|
rendered_response=rendered_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_action_results(
|
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
|
||||||
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Adds the results of executed actions back into the message history,
|
Adds the results of executed actions back into the message history,
|
||||||
formatted for the next OpenAI vision LLM call.
|
formatted for the next OpenAI vision LLM call.
|
||||||
"""
|
"""
|
||||||
if not agent_action.action_results and not summarize_prompt:
|
if not agent_action.action_results:
|
||||||
return
|
return
|
||||||
|
|
||||||
tool_outputs = []
|
tool_outputs = []
|
||||||
@@ -493,44 +490,38 @@ back() # Use this to go back to the previous page.
|
|||||||
tool_output_content = construct_structured_message(
|
tool_output_content = construct_structured_message(
|
||||||
message=tool_outputs_str,
|
message=tool_outputs_str,
|
||||||
images=[formatted_screenshot],
|
images=[formatted_screenshot],
|
||||||
model_type=self.vision_chat_model.model_type,
|
model_type=self.reasoning_model.model_type,
|
||||||
vision_enabled=True,
|
vision_enabled=True,
|
||||||
)
|
)
|
||||||
self.messages.append(AgentMessage(role="environment", content=tool_output_content))
|
self.messages.append(AgentMessage(role="environment", content=tool_output_content))
|
||||||
|
|
||||||
# Append summarize prompt if provided
|
async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str:
|
||||||
if summarize_prompt:
|
|
||||||
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
|
|
||||||
|
|
||||||
async def summarize(self, query: str, env_state: EnvState) -> str:
|
|
||||||
# Construct vision LLM input following OpenAI format
|
|
||||||
conversation_history = self._format_message_for_api(self.messages)
|
conversation_history = self._format_message_for_api(self.messages)
|
||||||
try:
|
try:
|
||||||
summary = await send_message_to_model_wrapper(
|
summary = await send_message_to_model_wrapper(
|
||||||
query=query,
|
query=summarize_prompt,
|
||||||
conversation_log=conversation_history,
|
conversation_log=conversation_history,
|
||||||
agent_chat_model=self.vision_chat_model,
|
agent_chat_model=self.reasoning_model,
|
||||||
tracer=self.tracer,
|
tracer=self.tracer,
|
||||||
)
|
)
|
||||||
|
# Set summary to last action message
|
||||||
# Return last action message if no summary
|
|
||||||
if not summary:
|
if not summary:
|
||||||
return self.compile_response(self.messages[-1].content) # Compile the last action message
|
raise ValueError("Summary is empty.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error calling Reasoning LLM for summary: {e}")
|
||||||
|
summary = "\n".join([self._get_message_text(msg) for msg in self.messages])
|
||||||
|
|
||||||
# Append summary messages to history
|
# Append summary messages to history
|
||||||
trigger_summary = AgentMessage(role="user", content=query)
|
trigger_summary = AgentMessage(role="user", content=summarize_prompt)
|
||||||
summary_message = AgentMessage(role="assistant", content=summary)
|
summary_message = AgentMessage(role="assistant", content=summary)
|
||||||
self.messages.extend([trigger_summary, summary_message])
|
self.messages.extend([trigger_summary, summary_message])
|
||||||
|
|
||||||
return summary
|
return summary
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error calling Vision LLM for summary: {e}")
|
|
||||||
return f"Error generating summary: {e}"
|
|
||||||
|
|
||||||
def compile_response(self, response_content: Union[str, List, dict]) -> str:
|
def compile_response(self, response_content: Union[str, List, dict]) -> str:
|
||||||
"""Compile response content into a string, handling OpenAI message structures."""
|
"""Compile response content into a string, handling OpenAI message structures."""
|
||||||
if isinstance(response_content, str):
|
if isinstance(response_content, str):
|
||||||
return response_content # Simple text (e.g., initial user query, vision response)
|
return response_content
|
||||||
|
|
||||||
if isinstance(response_content, dict) and response_content.get("role") == "assistant":
|
if isinstance(response_content, dict) and response_content.get("role") == "assistant":
|
||||||
# Grounding LLM response message (might contain tool calls)
|
# Grounding LLM response message (might contain tool calls)
|
||||||
@@ -544,7 +535,7 @@ back() # Use this to go back to the previous page.
|
|||||||
compiled.append(
|
compiled.append(
|
||||||
f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}"
|
f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}"
|
||||||
)
|
)
|
||||||
return "\n- ".join(filter(None, compiled)) or "[Assistant Message]"
|
return "\n- ".join(filter(None, compiled))
|
||||||
|
|
||||||
if isinstance(response_content, list): # Tool results list
|
if isinstance(response_content, list): # Tool results list
|
||||||
compiled = ["**Tool Results**:"]
|
compiled = ["**Tool Results**:"]
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# --- Anthropic Operator Agent ---
|
# --- Anthropic Operator Agent ---
|
||||||
class OpenAIOperatorAgent(OperatorAgent):
|
class OpenAIOperatorAgent(OperatorAgent):
|
||||||
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||||
client = get_openai_async_client(
|
client = get_openai_async_client(
|
||||||
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
|
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
|
||||||
)
|
)
|
||||||
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
|
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
|
||||||
safety_check_message = None
|
safety_check_message = None
|
||||||
@@ -80,7 +80,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_none_or_empty(self.messages):
|
if is_none_or_empty(self.messages):
|
||||||
self.messages = [AgentMessage(role="user", content=query)]
|
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||||
|
|
||||||
messages_for_api = self._format_message_for_api(self.messages)
|
messages_for_api = self._format_message_for_api(self.messages)
|
||||||
response: Response = await client.responses.create(
|
response: Response = await client.responses.create(
|
||||||
@@ -168,7 +168,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
action_results.append(
|
action_results.append(
|
||||||
{
|
{
|
||||||
"type": f"{block.type}_output",
|
"type": f"{block.type}_output",
|
||||||
"output": content, # Updated by environment step
|
"output": content, # Updated after environment step
|
||||||
"call_id": last_call_id,
|
"call_id": last_call_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -181,10 +181,8 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
rendered_response=rendered_response,
|
rendered_response=rendered_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_action_results(
|
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
|
||||||
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
if not agent_action.action_results:
|
||||||
) -> None:
|
|
||||||
if not agent_action.action_results and not summarize_prompt:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Update action results with results of applying suggested actions on the environment
|
# Update action results with results of applying suggested actions on the environment
|
||||||
@@ -209,11 +207,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
|||||||
# Add text data
|
# Add text data
|
||||||
action_result["output"] = result_content
|
action_result["output"] = result_content
|
||||||
|
|
||||||
if agent_action.action_results:
|
|
||||||
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
||||||
# Append summarize prompt as a user message after tool results
|
|
||||||
if summarize_prompt:
|
|
||||||
self.messages += [AgentMessage(role="user", content=summarize_prompt)]
|
|
||||||
|
|
||||||
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
|
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
|
||||||
"""Format the message for OpenAI API."""
|
"""Format the message for OpenAI API."""
|
||||||
|
|||||||
Reference in New Issue
Block a user