Set operator query on init. Pass summarize prompt to summarize func

The initial user query isn't updated during an operator run, so set it
when initializing the operator agent instead of passing it on every
call to act.

Pass the summarize prompt directly to the summarize function and let it
construct the summarize message used to query the vision model.
Previously the prompt was passed to the add_action_results func, as the
earlier implementation did not use a separate summarize func.

Also rename chat_model to vision_model for a more pertinent var name.

These changes make the code cleaner and the implementation more readable.
This commit is contained in:
Debanjum
2025-05-08 09:25:43 -06:00
parent 38bcba2f4b
commit e17c06b798
5 changed files with 72 additions and 95 deletions

View File

@@ -47,10 +47,10 @@ async def operate_browser(
# Initialize Agent
max_iterations = 40 # TODO: Configurable?
operator_agent: OperatorAgent
if chat_model.name.startswith("gpt-"):
operator_agent = OpenAIOperatorAgent(chat_model, max_iterations, tracer)
elif chat_model.name.startswith("claude-"):
operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer)
if reasoning_model.name.startswith("gpt-"):
operator_agent = OpenAIOperatorAgent(query, reasoning_model, max_iterations, tracer)
elif reasoning_model.name.startswith("claude-"):
operator_agent = AnthropicOperatorAgent(query, reasoning_model, max_iterations, tracer)
else:
grounding_model_name = "ui-tars-1.5-7b"
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
@@ -60,7 +60,7 @@ async def operate_browser(
or grounding_model.model_type != ChatModel.ModelType.OPENAI
):
raise ValueError("No supported visual grounding model for binary operator agent found.")
operator_agent = BinaryOperatorAgent(reasoning_model, grounding_model, max_iterations, tracer)
operator_agent = BinaryOperatorAgent(query, reasoning_model, grounding_model, max_iterations, tracer)
# Initialize Environment
if send_status_func:
@@ -87,7 +87,7 @@ async def operate_browser(
browser_state = await environment.get_state()
# 2. Agent decides action(s)
agent_result = await operator_agent.act(query, browser_state)
agent_result = await operator_agent.act(browser_state)
# Render status update
rendered_response = agent_result.rendered_response
@@ -118,8 +118,8 @@ async def operate_browser(
break
if task_completed or trigger_iteration_limit:
# Summarize results of operator run on last iteration
operator_agent.add_action_results(env_steps, agent_result, summarize_prompt)
summary_message = await operator_agent.summarize(query, browser_state)
operator_agent.add_action_results(env_steps, agent_result)
summary_message = await operator_agent.summarize(summarize_prompt, browser_state)
logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
break

View File

@@ -20,9 +20,9 @@ logger = logging.getLogger(__name__)
# --- Anthropic Operator Agent ---
class AnthropicOperatorAgent(OperatorAgent):
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
async def act(self, current_state: EnvState) -> AgentActResult:
client = get_anthropic_async_client(
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
)
tool_version = "2025-01-24"
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
@@ -51,7 +51,7 @@ class AnthropicOperatorAgent(OperatorAgent):
</IMPORTANT>
"""
if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=query)]
self.messages = [AgentMessage(role="user", content=self.query)]
tools = [
{
@@ -77,13 +77,13 @@ class AnthropicOperatorAgent(OperatorAgent):
]
thinking = {"type": "disabled"}
if self.chat_model.name.startswith("claude-3-7"):
if self.vision_model.name.startswith("claude-3-7"):
thinking = {"type": "enabled", "budget_tokens": 1024}
messages_for_api = self._format_message_for_api(self.messages)
response = await client.beta.messages.create(
messages=messages_for_api,
model=self.chat_model.name,
model=self.vision_model.name,
system=system_prompt,
tools=tools,
betas=betas,
@@ -187,8 +187,8 @@ class AnthropicOperatorAgent(OperatorAgent):
{
"type": "tool_result",
"tool_use_id": tool_use_id,
"content": None, # Updated by environment step
"is_error": False, # Updated by environment step
"content": None, # Updated after environment step
"is_error": False, # Updated after environment step
}
)
@@ -206,13 +206,9 @@ class AnthropicOperatorAgent(OperatorAgent):
rendered_response=rendered_response,
)
def add_action_results(
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
):
if not agent_action.action_results and not summarize_prompt:
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult):
if not agent_action.action_results:
return
elif not agent_action.action_results:
agent_action.action_results = []
# Update action results with results of applying suggested actions on the environment
for idx, env_step in enumerate(env_steps):
@@ -236,10 +232,6 @@ class AnthropicOperatorAgent(OperatorAgent):
if env_step.error:
action_result["is_error"] = True
# If summarize prompt provided, append as text within the tool results user message
if summarize_prompt:
agent_action.action_results.append({"type": "text", "text": summarize_prompt})
# Append tool results to the message history
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]

View File

@@ -25,26 +25,26 @@ class AgentMessage(BaseModel):
class OperatorAgent(ABC):
def __init__(self, chat_model: ChatModel, max_iterations: int, tracer: dict):
self.chat_model = chat_model
def __init__(self, query: str, vision_model: ChatModel, max_iterations: int, tracer: dict):
self.query = query
self.vision_model = vision_model
self.max_iterations = max_iterations
self.tracer = tracer
self.messages: List[AgentMessage] = []
@abstractmethod
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
async def act(self, current_state: EnvState) -> AgentActResult:
pass
@abstractmethod
def add_action_results(
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
) -> None:
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
"""Track results of agent actions on the environment."""
pass
async def summarize(self, query: str, current_state: EnvState) -> str:
async def summarize(self, summarize_prompt: str, current_state: EnvState) -> str:
"""Summarize the agent's actions and results."""
await self.act(query, current_state)
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
await self.act(current_state)
if not self.messages:
return "No actions to summarize."
return self.compile_response(self.messages[-1].content)
@@ -63,12 +63,12 @@ class OperatorAgent(ABC):
def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0):
self.tracer["usage"] = get_chat_usage_metrics(
self.chat_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage")
self.vision_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage")
)
logger.debug(f"Operator usage by {self.chat_model.model_type}: {self.tracer['usage']}")
logger.debug(f"Operator usage by {self.vision_model.model_type}: {self.tracer['usage']}")
def _commit_trace(self):
self.tracer["chat_model"] = self.chat_model.name
self.tracer["chat_model"] = self.vision_model.name
if is_promptrace_enabled() and len(self.messages) > 1:
compiled_messages = [
AgentMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages

View File

@@ -3,7 +3,6 @@ import logging
from datetime import datetime
from typing import List, Optional
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from khoj.database.models import ChatModel
@@ -18,7 +17,6 @@ from khoj.processor.operator.operator_environment_base import EnvState, EnvStepR
from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import (
convert_image_to_png,
get_chat_usage_metrics,
get_openai_async_client,
is_none_or_empty,
)
@@ -29,34 +27,35 @@ logger = logging.getLogger(__name__)
# --- Binary Operator Agent ---
class BinaryOperatorAgent(OperatorAgent):
"""
An OperatorAgent that uses two LLMs (OpenAI compatible):
1. Vision LLM: Determines the next high-level action based on the visual state.
An OperatorAgent that uses two LLMs:
1. Reasoning LLM: Determines the next high-level action based on the objective and current visual reasoning trajectory.
2. Grounding LLM: Converts the high-level action into specific, executable browser actions.
"""
def __init__(
self,
vision_chat_model: ChatModel,
grounding_chat_model: ChatModel, # Assuming a second model is provided/configured
query: str,
reasoning_model: ChatModel,
grounding_model: ChatModel,
max_iterations: int,
tracer: dict,
):
super().__init__(vision_chat_model, max_iterations, tracer) # Use vision model for primary tracking
self.vision_chat_model = vision_chat_model
self.grounding_chat_model = grounding_chat_model
# Initialize OpenAI clients
self.grounding_client: AsyncOpenAI = get_openai_async_client(
grounding_chat_model.ai_model_api.api_key, grounding_chat_model.ai_model_api.api_base_url
super().__init__(query, reasoning_model, max_iterations, tracer) # Use reasoning model for primary tracking
self.reasoning_model = reasoning_model
self.grounding_model = grounding_model
# Initialize openai api compatible client for grounding model
self.grounding_client = get_openai_async_client(
grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url
)
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
async def act(self, current_state: EnvState) -> AgentActResult:
"""
Uses a two-step LLM process to determine and structure the next action.
"""
self._commit_trace() # Commit trace before next action
# --- Step 1: Reasoning LLM determines high-level action ---
reasoner_response = await self.act_reason(query, current_state)
reasoner_response = await self.act_reason(current_state)
natural_language_action = reasoner_response["message"]
if reasoner_response["type"] == "error":
logger.error(natural_language_action)
@@ -75,7 +74,7 @@ class BinaryOperatorAgent(OperatorAgent):
# --- Step 2: Grounding LLM converts NL action to structured action ---
return await self.act_ground(natural_language_action, current_state)
async def act_reason(self, query: str, current_state: EnvState) -> dict[str, str]:
async def act_reason(self, current_state: EnvState) -> dict[str, str]:
"""
Uses the reasoning LLM to determine the next high-level action based on the operation trajectory.
"""
@@ -118,12 +117,12 @@ Focus on the visual action and provide all necessary context.
""".strip()
if is_none_or_empty(self.messages):
query_text = f"**Main Objective**: {query}"
query_text = f"**Main Objective**: {self.query}"
query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"]
first_message_content = construct_structured_message(
message=query_text,
images=query_screenshot,
model_type=self.vision_chat_model.model_type,
model_type=self.reasoning_model.model_type,
vision_enabled=True,
)
current_message = AgentMessage(role="user", content=first_message_content)
@@ -140,7 +139,7 @@ Focus on the visual action and provide all necessary context.
query_images=query_screenshot,
system_message=reasoning_system_prompt,
conversation_log=visual_reasoner_history,
agent_chat_model=self.vision_chat_model,
agent_chat_model=self.reasoning_model,
tracer=self.tracer,
)
self.messages.append(current_message)
@@ -371,7 +370,7 @@ back() # Use this to go back to the previous page.
try:
grounding_response: ChatCompletion = await self.grounding_client.chat.completions.create(
model=self.grounding_chat_model.name,
model=self.grounding_model.name,
messages=grounding_messages_for_api,
tools=grounding_tools,
tool_choice="required",
@@ -465,14 +464,12 @@ back() # Use this to go back to the previous page.
rendered_response=rendered_response,
)
def add_action_results(
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
) -> None:
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
"""
Adds the results of executed actions back into the message history,
formatted for the next OpenAI vision LLM call.
"""
if not agent_action.action_results and not summarize_prompt:
if not agent_action.action_results:
return
tool_outputs = []
@@ -493,44 +490,38 @@ back() # Use this to go back to the previous page.
tool_output_content = construct_structured_message(
message=tool_outputs_str,
images=[formatted_screenshot],
model_type=self.vision_chat_model.model_type,
model_type=self.reasoning_model.model_type,
vision_enabled=True,
)
self.messages.append(AgentMessage(role="environment", content=tool_output_content))
# Append summarize prompt if provided
if summarize_prompt:
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
async def summarize(self, query: str, env_state: EnvState) -> str:
# Construct vision LLM input following OpenAI format
async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str:
conversation_history = self._format_message_for_api(self.messages)
try:
summary = await send_message_to_model_wrapper(
query=query,
query=summarize_prompt,
conversation_log=conversation_history,
agent_chat_model=self.vision_chat_model,
agent_chat_model=self.reasoning_model,
tracer=self.tracer,
)
# Return last action message if no summary
# Set summary to last action message
if not summary:
return self.compile_response(self.messages[-1].content) # Compile the last action message
# Append summary messages to history
trigger_summary = AgentMessage(role="user", content=query)
summary_message = AgentMessage(role="assistant", content=summary)
self.messages.extend([trigger_summary, summary_message])
return summary
raise ValueError("Summary is empty.")
except Exception as e:
logger.error(f"Error calling Vision LLM for summary: {e}")
return f"Error generating summary: {e}"
logger.error(f"Error calling Reasoning LLM for summary: {e}")
summary = "\n".join([self._get_message_text(msg) for msg in self.messages])
# Append summary messages to history
trigger_summary = AgentMessage(role="user", content=summarize_prompt)
summary_message = AgentMessage(role="assistant", content=summary)
self.messages.extend([trigger_summary, summary_message])
return summary
def compile_response(self, response_content: Union[str, List, dict]) -> str:
"""Compile response content into a string, handling OpenAI message structures."""
if isinstance(response_content, str):
return response_content # Simple text (e.g., initial user query, vision response)
return response_content
if isinstance(response_content, dict) and response_content.get("role") == "assistant":
# Grounding LLM response message (might contain tool calls)
@@ -544,7 +535,7 @@ back() # Use this to go back to the previous page.
compiled.append(
f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}"
)
return "\n- ".join(filter(None, compiled)) or "[Assistant Message]"
return "\n- ".join(filter(None, compiled))
if isinstance(response_content, list): # Tool results list
compiled = ["**Tool Results**:"]

View File

@@ -20,9 +20,9 @@ logger = logging.getLogger(__name__)
# --- Anthropic Operator Agent ---
class OpenAIOperatorAgent(OperatorAgent):
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
async def act(self, current_state: EnvState) -> AgentActResult:
client = get_openai_async_client(
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
)
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
safety_check_message = None
@@ -80,7 +80,7 @@ class OpenAIOperatorAgent(OperatorAgent):
]
if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=query)]
self.messages = [AgentMessage(role="user", content=self.query)]
messages_for_api = self._format_message_for_api(self.messages)
response: Response = await client.responses.create(
@@ -168,7 +168,7 @@ class OpenAIOperatorAgent(OperatorAgent):
action_results.append(
{
"type": f"{block.type}_output",
"output": content, # Updated by environment step
"output": content, # Updated after environment step
"call_id": last_call_id,
}
)
@@ -181,10 +181,8 @@ class OpenAIOperatorAgent(OperatorAgent):
rendered_response=rendered_response,
)
def add_action_results(
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
) -> None:
if not agent_action.action_results and not summarize_prompt:
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
if not agent_action.action_results:
return
# Update action results with results of applying suggested actions on the environment
@@ -209,11 +207,7 @@ class OpenAIOperatorAgent(OperatorAgent):
# Add text data
action_result["output"] = result_content
if agent_action.action_results:
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
# Append summarize prompt as a user message after tool results
if summarize_prompt:
self.messages += [AgentMessage(role="user", content=summarize_prompt)]
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
"""Format the message for OpenAI API."""