Set operator query on init. Pass summarize prompt to summarize func

The initial user query isn't updated during an operator run, so set it
when initializing the operator agent instead of passing it on every
call to act.

Pass the summarize prompt directly to the summarize function and let it
construct the summarize message to query the vision model with.
Previously the prompt was passed to the add_action_results func, because
the earlier implementation did not use a separate summarize func.

Also rename chat_model to vision_model for a more pertinent var name.

These changes make the code cleaner and the implementation more readable.
This commit is contained in:
Debanjum
2025-05-08 09:25:43 -06:00
parent 38bcba2f4b
commit e17c06b798
5 changed files with 72 additions and 95 deletions

View File

@@ -47,10 +47,10 @@ async def operate_browser(
# Initialize Agent # Initialize Agent
max_iterations = 40 # TODO: Configurable? max_iterations = 40 # TODO: Configurable?
operator_agent: OperatorAgent operator_agent: OperatorAgent
if chat_model.name.startswith("gpt-"): if reasoning_model.name.startswith("gpt-"):
operator_agent = OpenAIOperatorAgent(chat_model, max_iterations, tracer) operator_agent = OpenAIOperatorAgent(query, reasoning_model, max_iterations, tracer)
elif chat_model.name.startswith("claude-"): elif reasoning_model.name.startswith("claude-"):
operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer) operator_agent = AnthropicOperatorAgent(query, reasoning_model, max_iterations, tracer)
else: else:
grounding_model_name = "ui-tars-1.5-7b" grounding_model_name = "ui-tars-1.5-7b"
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name) grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
@@ -60,7 +60,7 @@ async def operate_browser(
or grounding_model.model_type != ChatModel.ModelType.OPENAI or grounding_model.model_type != ChatModel.ModelType.OPENAI
): ):
raise ValueError("No supported visual grounding model for binary operator agent found.") raise ValueError("No supported visual grounding model for binary operator agent found.")
operator_agent = BinaryOperatorAgent(reasoning_model, grounding_model, max_iterations, tracer) operator_agent = BinaryOperatorAgent(query, reasoning_model, grounding_model, max_iterations, tracer)
# Initialize Environment # Initialize Environment
if send_status_func: if send_status_func:
@@ -87,7 +87,7 @@ async def operate_browser(
browser_state = await environment.get_state() browser_state = await environment.get_state()
# 2. Agent decides action(s) # 2. Agent decides action(s)
agent_result = await operator_agent.act(query, browser_state) agent_result = await operator_agent.act(browser_state)
# Render status update # Render status update
rendered_response = agent_result.rendered_response rendered_response = agent_result.rendered_response
@@ -118,8 +118,8 @@ async def operate_browser(
break break
if task_completed or trigger_iteration_limit: if task_completed or trigger_iteration_limit:
# Summarize results of operator run on last iteration # Summarize results of operator run on last iteration
operator_agent.add_action_results(env_steps, agent_result, summarize_prompt) operator_agent.add_action_results(env_steps, agent_result)
summary_message = await operator_agent.summarize(query, browser_state) summary_message = await operator_agent.summarize(summarize_prompt, browser_state)
logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}") logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
break break

View File

@@ -20,9 +20,9 @@ logger = logging.getLogger(__name__)
# --- Anthropic Operator Agent --- # --- Anthropic Operator Agent ---
class AnthropicOperatorAgent(OperatorAgent): class AnthropicOperatorAgent(OperatorAgent):
async def act(self, query: str, current_state: EnvState) -> AgentActResult: async def act(self, current_state: EnvState) -> AgentActResult:
client = get_anthropic_async_client( client = get_anthropic_async_client(
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
) )
tool_version = "2025-01-24" tool_version = "2025-01-24"
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"] betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
@@ -51,7 +51,7 @@ class AnthropicOperatorAgent(OperatorAgent):
</IMPORTANT> </IMPORTANT>
""" """
if is_none_or_empty(self.messages): if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=query)] self.messages = [AgentMessage(role="user", content=self.query)]
tools = [ tools = [
{ {
@@ -77,13 +77,13 @@ class AnthropicOperatorAgent(OperatorAgent):
] ]
thinking = {"type": "disabled"} thinking = {"type": "disabled"}
if self.chat_model.name.startswith("claude-3-7"): if self.vision_model.name.startswith("claude-3-7"):
thinking = {"type": "enabled", "budget_tokens": 1024} thinking = {"type": "enabled", "budget_tokens": 1024}
messages_for_api = self._format_message_for_api(self.messages) messages_for_api = self._format_message_for_api(self.messages)
response = await client.beta.messages.create( response = await client.beta.messages.create(
messages=messages_for_api, messages=messages_for_api,
model=self.chat_model.name, model=self.vision_model.name,
system=system_prompt, system=system_prompt,
tools=tools, tools=tools,
betas=betas, betas=betas,
@@ -187,8 +187,8 @@ class AnthropicOperatorAgent(OperatorAgent):
{ {
"type": "tool_result", "type": "tool_result",
"tool_use_id": tool_use_id, "tool_use_id": tool_use_id,
"content": None, # Updated by environment step "content": None, # Updated after environment step
"is_error": False, # Updated by environment step "is_error": False, # Updated after environment step
} }
) )
@@ -206,13 +206,9 @@ class AnthropicOperatorAgent(OperatorAgent):
rendered_response=rendered_response, rendered_response=rendered_response,
) )
def add_action_results( def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult):
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None if not agent_action.action_results:
):
if not agent_action.action_results and not summarize_prompt:
return return
elif not agent_action.action_results:
agent_action.action_results = []
# Update action results with results of applying suggested actions on the environment # Update action results with results of applying suggested actions on the environment
for idx, env_step in enumerate(env_steps): for idx, env_step in enumerate(env_steps):
@@ -236,10 +232,6 @@ class AnthropicOperatorAgent(OperatorAgent):
if env_step.error: if env_step.error:
action_result["is_error"] = True action_result["is_error"] = True
# If summarize prompt provided, append as text within the tool results user message
if summarize_prompt:
agent_action.action_results.append({"type": "text", "text": summarize_prompt})
# Append tool results to the message history # Append tool results to the message history
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)] self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]

View File

@@ -25,26 +25,26 @@ class AgentMessage(BaseModel):
class OperatorAgent(ABC): class OperatorAgent(ABC):
def __init__(self, chat_model: ChatModel, max_iterations: int, tracer: dict): def __init__(self, query: str, vision_model: ChatModel, max_iterations: int, tracer: dict):
self.chat_model = chat_model self.query = query
self.vision_model = vision_model
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.tracer = tracer self.tracer = tracer
self.messages: List[AgentMessage] = [] self.messages: List[AgentMessage] = []
@abstractmethod @abstractmethod
async def act(self, query: str, current_state: EnvState) -> AgentActResult: async def act(self, current_state: EnvState) -> AgentActResult:
pass pass
@abstractmethod @abstractmethod
def add_action_results( def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
) -> None:
"""Track results of agent actions on the environment.""" """Track results of agent actions on the environment."""
pass pass
async def summarize(self, query: str, current_state: EnvState) -> str: async def summarize(self, summarize_prompt: str, current_state: EnvState) -> str:
"""Summarize the agent's actions and results.""" """Summarize the agent's actions and results."""
await self.act(query, current_state) self.messages.append(AgentMessage(role="user", content=summarize_prompt))
await self.act(current_state)
if not self.messages: if not self.messages:
return "No actions to summarize." return "No actions to summarize."
return self.compile_response(self.messages[-1].content) return self.compile_response(self.messages[-1].content)
@@ -63,12 +63,12 @@ class OperatorAgent(ABC):
def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0): def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0):
self.tracer["usage"] = get_chat_usage_metrics( self.tracer["usage"] = get_chat_usage_metrics(
self.chat_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage") self.vision_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage")
) )
logger.debug(f"Operator usage by {self.chat_model.model_type}: {self.tracer['usage']}") logger.debug(f"Operator usage by {self.vision_model.model_type}: {self.tracer['usage']}")
def _commit_trace(self): def _commit_trace(self):
self.tracer["chat_model"] = self.chat_model.name self.tracer["chat_model"] = self.vision_model.name
if is_promptrace_enabled() and len(self.messages) > 1: if is_promptrace_enabled() and len(self.messages) > 1:
compiled_messages = [ compiled_messages = [
AgentMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages AgentMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages

View File

@@ -3,7 +3,6 @@ import logging
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion from openai.types.chat import ChatCompletion
from khoj.database.models import ChatModel from khoj.database.models import ChatModel
@@ -18,7 +17,6 @@ from khoj.processor.operator.operator_environment_base import EnvState, EnvStepR
from khoj.routers.helpers import send_message_to_model_wrapper from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import ( from khoj.utils.helpers import (
convert_image_to_png, convert_image_to_png,
get_chat_usage_metrics,
get_openai_async_client, get_openai_async_client,
is_none_or_empty, is_none_or_empty,
) )
@@ -29,34 +27,35 @@ logger = logging.getLogger(__name__)
# --- Binary Operator Agent --- # --- Binary Operator Agent ---
class BinaryOperatorAgent(OperatorAgent): class BinaryOperatorAgent(OperatorAgent):
""" """
An OperatorAgent that uses two LLMs (OpenAI compatible): An OperatorAgent that uses two LLMs:
1. Vision LLM: Determines the next high-level action based on the visual state. 1. Reasoning LLM: Determines the next high-level action based on the objective and current visual reasoning trajectory.
2. Grounding LLM: Converts the high-level action into specific, executable browser actions. 2. Grounding LLM: Converts the high-level action into specific, executable browser actions.
""" """
def __init__( def __init__(
self, self,
vision_chat_model: ChatModel, query: str,
grounding_chat_model: ChatModel, # Assuming a second model is provided/configured reasoning_model: ChatModel,
grounding_model: ChatModel,
max_iterations: int, max_iterations: int,
tracer: dict, tracer: dict,
): ):
super().__init__(vision_chat_model, max_iterations, tracer) # Use vision model for primary tracking super().__init__(query, reasoning_model, max_iterations, tracer) # Use reasoning model for primary tracking
self.vision_chat_model = vision_chat_model self.reasoning_model = reasoning_model
self.grounding_chat_model = grounding_chat_model self.grounding_model = grounding_model
# Initialize OpenAI clients # Initialize openai api compatible client for grounding model
self.grounding_client: AsyncOpenAI = get_openai_async_client( self.grounding_client = get_openai_async_client(
grounding_chat_model.ai_model_api.api_key, grounding_chat_model.ai_model_api.api_base_url grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url
) )
async def act(self, query: str, current_state: EnvState) -> AgentActResult: async def act(self, current_state: EnvState) -> AgentActResult:
""" """
Uses a two-step LLM process to determine and structure the next action. Uses a two-step LLM process to determine and structure the next action.
""" """
self._commit_trace() # Commit trace before next action self._commit_trace() # Commit trace before next action
# --- Step 1: Reasoning LLM determines high-level action --- # --- Step 1: Reasoning LLM determines high-level action ---
reasoner_response = await self.act_reason(query, current_state) reasoner_response = await self.act_reason(current_state)
natural_language_action = reasoner_response["message"] natural_language_action = reasoner_response["message"]
if reasoner_response["type"] == "error": if reasoner_response["type"] == "error":
logger.error(natural_language_action) logger.error(natural_language_action)
@@ -75,7 +74,7 @@ class BinaryOperatorAgent(OperatorAgent):
# --- Step 2: Grounding LLM converts NL action to structured action --- # --- Step 2: Grounding LLM converts NL action to structured action ---
return await self.act_ground(natural_language_action, current_state) return await self.act_ground(natural_language_action, current_state)
async def act_reason(self, query: str, current_state: EnvState) -> dict[str, str]: async def act_reason(self, current_state: EnvState) -> dict[str, str]:
""" """
Uses the reasoning LLM to determine the next high-level action based on the operation trajectory. Uses the reasoning LLM to determine the next high-level action based on the operation trajectory.
""" """
@@ -118,12 +117,12 @@ Focus on the visual action and provide all necessary context.
""".strip() """.strip()
if is_none_or_empty(self.messages): if is_none_or_empty(self.messages):
query_text = f"**Main Objective**: {query}" query_text = f"**Main Objective**: {self.query}"
query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"] query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"]
first_message_content = construct_structured_message( first_message_content = construct_structured_message(
message=query_text, message=query_text,
images=query_screenshot, images=query_screenshot,
model_type=self.vision_chat_model.model_type, model_type=self.reasoning_model.model_type,
vision_enabled=True, vision_enabled=True,
) )
current_message = AgentMessage(role="user", content=first_message_content) current_message = AgentMessage(role="user", content=first_message_content)
@@ -140,7 +139,7 @@ Focus on the visual action and provide all necessary context.
query_images=query_screenshot, query_images=query_screenshot,
system_message=reasoning_system_prompt, system_message=reasoning_system_prompt,
conversation_log=visual_reasoner_history, conversation_log=visual_reasoner_history,
agent_chat_model=self.vision_chat_model, agent_chat_model=self.reasoning_model,
tracer=self.tracer, tracer=self.tracer,
) )
self.messages.append(current_message) self.messages.append(current_message)
@@ -371,7 +370,7 @@ back() # Use this to go back to the previous page.
try: try:
grounding_response: ChatCompletion = await self.grounding_client.chat.completions.create( grounding_response: ChatCompletion = await self.grounding_client.chat.completions.create(
model=self.grounding_chat_model.name, model=self.grounding_model.name,
messages=grounding_messages_for_api, messages=grounding_messages_for_api,
tools=grounding_tools, tools=grounding_tools,
tool_choice="required", tool_choice="required",
@@ -465,14 +464,12 @@ back() # Use this to go back to the previous page.
rendered_response=rendered_response, rendered_response=rendered_response,
) )
def add_action_results( def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
) -> None:
""" """
Adds the results of executed actions back into the message history, Adds the results of executed actions back into the message history,
formatted for the next OpenAI vision LLM call. formatted for the next OpenAI vision LLM call.
""" """
if not agent_action.action_results and not summarize_prompt: if not agent_action.action_results:
return return
tool_outputs = [] tool_outputs = []
@@ -493,44 +490,38 @@ back() # Use this to go back to the previous page.
tool_output_content = construct_structured_message( tool_output_content = construct_structured_message(
message=tool_outputs_str, message=tool_outputs_str,
images=[formatted_screenshot], images=[formatted_screenshot],
model_type=self.vision_chat_model.model_type, model_type=self.reasoning_model.model_type,
vision_enabled=True, vision_enabled=True,
) )
self.messages.append(AgentMessage(role="environment", content=tool_output_content)) self.messages.append(AgentMessage(role="environment", content=tool_output_content))
# Append summarize prompt if provided async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str:
if summarize_prompt:
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
async def summarize(self, query: str, env_state: EnvState) -> str:
# Construct vision LLM input following OpenAI format
conversation_history = self._format_message_for_api(self.messages) conversation_history = self._format_message_for_api(self.messages)
try: try:
summary = await send_message_to_model_wrapper( summary = await send_message_to_model_wrapper(
query=query, query=summarize_prompt,
conversation_log=conversation_history, conversation_log=conversation_history,
agent_chat_model=self.vision_chat_model, agent_chat_model=self.reasoning_model,
tracer=self.tracer, tracer=self.tracer,
) )
# Set summary to last action message
# Return last action message if no summary
if not summary: if not summary:
return self.compile_response(self.messages[-1].content) # Compile the last action message raise ValueError("Summary is empty.")
except Exception as e:
logger.error(f"Error calling Reasoning LLM for summary: {e}")
summary = "\n".join([self._get_message_text(msg) for msg in self.messages])
# Append summary messages to history # Append summary messages to history
trigger_summary = AgentMessage(role="user", content=query) trigger_summary = AgentMessage(role="user", content=summarize_prompt)
summary_message = AgentMessage(role="assistant", content=summary) summary_message = AgentMessage(role="assistant", content=summary)
self.messages.extend([trigger_summary, summary_message]) self.messages.extend([trigger_summary, summary_message])
return summary return summary
except Exception as e:
logger.error(f"Error calling Vision LLM for summary: {e}")
return f"Error generating summary: {e}"
def compile_response(self, response_content: Union[str, List, dict]) -> str: def compile_response(self, response_content: Union[str, List, dict]) -> str:
"""Compile response content into a string, handling OpenAI message structures.""" """Compile response content into a string, handling OpenAI message structures."""
if isinstance(response_content, str): if isinstance(response_content, str):
return response_content # Simple text (e.g., initial user query, vision response) return response_content
if isinstance(response_content, dict) and response_content.get("role") == "assistant": if isinstance(response_content, dict) and response_content.get("role") == "assistant":
# Grounding LLM response message (might contain tool calls) # Grounding LLM response message (might contain tool calls)
@@ -544,7 +535,7 @@ back() # Use this to go back to the previous page.
compiled.append( compiled.append(
f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}" f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}"
) )
return "\n- ".join(filter(None, compiled)) or "[Assistant Message]" return "\n- ".join(filter(None, compiled))
if isinstance(response_content, list): # Tool results list if isinstance(response_content, list): # Tool results list
compiled = ["**Tool Results**:"] compiled = ["**Tool Results**:"]

View File

@@ -20,9 +20,9 @@ logger = logging.getLogger(__name__)
# --- Anthropic Operator Agent --- # --- Anthropic Operator Agent ---
class OpenAIOperatorAgent(OperatorAgent): class OpenAIOperatorAgent(OperatorAgent):
async def act(self, query: str, current_state: EnvState) -> AgentActResult: async def act(self, current_state: EnvState) -> AgentActResult:
client = get_openai_async_client( client = get_openai_async_client(
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
) )
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:" safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
safety_check_message = None safety_check_message = None
@@ -80,7 +80,7 @@ class OpenAIOperatorAgent(OperatorAgent):
] ]
if is_none_or_empty(self.messages): if is_none_or_empty(self.messages):
self.messages = [AgentMessage(role="user", content=query)] self.messages = [AgentMessage(role="user", content=self.query)]
messages_for_api = self._format_message_for_api(self.messages) messages_for_api = self._format_message_for_api(self.messages)
response: Response = await client.responses.create( response: Response = await client.responses.create(
@@ -168,7 +168,7 @@ class OpenAIOperatorAgent(OperatorAgent):
action_results.append( action_results.append(
{ {
"type": f"{block.type}_output", "type": f"{block.type}_output",
"output": content, # Updated by environment step "output": content, # Updated after environment step
"call_id": last_call_id, "call_id": last_call_id,
} }
) )
@@ -181,10 +181,8 @@ class OpenAIOperatorAgent(OperatorAgent):
rendered_response=rendered_response, rendered_response=rendered_response,
) )
def add_action_results( def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None if not agent_action.action_results:
) -> None:
if not agent_action.action_results and not summarize_prompt:
return return
# Update action results with results of applying suggested actions on the environment # Update action results with results of applying suggested actions on the environment
@@ -209,11 +207,7 @@ class OpenAIOperatorAgent(OperatorAgent):
# Add text data # Add text data
action_result["output"] = result_content action_result["output"] = result_content
if agent_action.action_results:
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)] self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
# Append summarize prompt as a user message after tool results
if summarize_prompt:
self.messages += [AgentMessage(role="user", content=summarize_prompt)]
def _format_message_for_api(self, messages: list[AgentMessage]) -> list: def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
"""Format the message for OpenAI API.""" """Format the message for OpenAI API."""