mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Set operator query on init. Pass summarize prompt to summarize func
The initial user query isn't updated during an operator run. So set it when initializing the operator agent. Instead of passing it on every call to act. Pass summarize prompt directly to the summarize function. Let it construct the summarize message to query vision model with. Previously it was being passed to the add_action_results func as previous implementation that did not use a separate summarize func. Also rename chat_model to vision_model for a more pertinent var name. These changes make the code cleaner and implementation more readable.
This commit is contained in:
@@ -47,10 +47,10 @@ async def operate_browser(
|
||||
# Initialize Agent
|
||||
max_iterations = 40 # TODO: Configurable?
|
||||
operator_agent: OperatorAgent
|
||||
if chat_model.name.startswith("gpt-"):
|
||||
operator_agent = OpenAIOperatorAgent(chat_model, max_iterations, tracer)
|
||||
elif chat_model.name.startswith("claude-"):
|
||||
operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer)
|
||||
if reasoning_model.name.startswith("gpt-"):
|
||||
operator_agent = OpenAIOperatorAgent(query, reasoning_model, max_iterations, tracer)
|
||||
elif reasoning_model.name.startswith("claude-"):
|
||||
operator_agent = AnthropicOperatorAgent(query, reasoning_model, max_iterations, tracer)
|
||||
else:
|
||||
grounding_model_name = "ui-tars-1.5-7b"
|
||||
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
|
||||
@@ -60,7 +60,7 @@ async def operate_browser(
|
||||
or grounding_model.model_type != ChatModel.ModelType.OPENAI
|
||||
):
|
||||
raise ValueError("No supported visual grounding model for binary operator agent found.")
|
||||
operator_agent = BinaryOperatorAgent(reasoning_model, grounding_model, max_iterations, tracer)
|
||||
operator_agent = BinaryOperatorAgent(query, reasoning_model, grounding_model, max_iterations, tracer)
|
||||
|
||||
# Initialize Environment
|
||||
if send_status_func:
|
||||
@@ -87,7 +87,7 @@ async def operate_browser(
|
||||
browser_state = await environment.get_state()
|
||||
|
||||
# 2. Agent decides action(s)
|
||||
agent_result = await operator_agent.act(query, browser_state)
|
||||
agent_result = await operator_agent.act(browser_state)
|
||||
|
||||
# Render status update
|
||||
rendered_response = agent_result.rendered_response
|
||||
@@ -118,8 +118,8 @@ async def operate_browser(
|
||||
break
|
||||
if task_completed or trigger_iteration_limit:
|
||||
# Summarize results of operator run on last iteration
|
||||
operator_agent.add_action_results(env_steps, agent_result, summarize_prompt)
|
||||
summary_message = await operator_agent.summarize(query, browser_state)
|
||||
operator_agent.add_action_results(env_steps, agent_result)
|
||||
summary_message = await operator_agent.summarize(summarize_prompt, browser_state)
|
||||
logger.info(f"Task completed: {task_completed}, Iteration limit: {trigger_iteration_limit}")
|
||||
break
|
||||
|
||||
|
||||
@@ -20,9 +20,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Anthropic Operator Agent ---
|
||||
class AnthropicOperatorAgent(OperatorAgent):
|
||||
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
client = get_anthropic_async_client(
|
||||
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
|
||||
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
|
||||
)
|
||||
tool_version = "2025-01-24"
|
||||
betas = [f"computer-use-{tool_version}", "token-efficient-tools-2025-02-19"]
|
||||
@@ -51,7 +51,7 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
</IMPORTANT>
|
||||
"""
|
||||
if is_none_or_empty(self.messages):
|
||||
self.messages = [AgentMessage(role="user", content=query)]
|
||||
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||
|
||||
tools = [
|
||||
{
|
||||
@@ -77,13 +77,13 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
]
|
||||
|
||||
thinking = {"type": "disabled"}
|
||||
if self.chat_model.name.startswith("claude-3-7"):
|
||||
if self.vision_model.name.startswith("claude-3-7"):
|
||||
thinking = {"type": "enabled", "budget_tokens": 1024}
|
||||
|
||||
messages_for_api = self._format_message_for_api(self.messages)
|
||||
response = await client.beta.messages.create(
|
||||
messages=messages_for_api,
|
||||
model=self.chat_model.name,
|
||||
model=self.vision_model.name,
|
||||
system=system_prompt,
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
@@ -187,8 +187,8 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
"content": None, # Updated by environment step
|
||||
"is_error": False, # Updated by environment step
|
||||
"content": None, # Updated after environment step
|
||||
"is_error": False, # Updated after environment step
|
||||
}
|
||||
)
|
||||
|
||||
@@ -206,13 +206,9 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
rendered_response=rendered_response,
|
||||
)
|
||||
|
||||
def add_action_results(
|
||||
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
||||
):
|
||||
if not agent_action.action_results and not summarize_prompt:
|
||||
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult):
|
||||
if not agent_action.action_results:
|
||||
return
|
||||
elif not agent_action.action_results:
|
||||
agent_action.action_results = []
|
||||
|
||||
# Update action results with results of applying suggested actions on the environment
|
||||
for idx, env_step in enumerate(env_steps):
|
||||
@@ -236,10 +232,6 @@ class AnthropicOperatorAgent(OperatorAgent):
|
||||
if env_step.error:
|
||||
action_result["is_error"] = True
|
||||
|
||||
# If summarize prompt provided, append as text within the tool results user message
|
||||
if summarize_prompt:
|
||||
agent_action.action_results.append({"type": "text", "text": summarize_prompt})
|
||||
|
||||
# Append tool results to the message history
|
||||
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
||||
|
||||
|
||||
@@ -25,26 +25,26 @@ class AgentMessage(BaseModel):
|
||||
|
||||
|
||||
class OperatorAgent(ABC):
|
||||
def __init__(self, chat_model: ChatModel, max_iterations: int, tracer: dict):
|
||||
self.chat_model = chat_model
|
||||
def __init__(self, query: str, vision_model: ChatModel, max_iterations: int, tracer: dict):
|
||||
self.query = query
|
||||
self.vision_model = vision_model
|
||||
self.max_iterations = max_iterations
|
||||
self.tracer = tracer
|
||||
self.messages: List[AgentMessage] = []
|
||||
|
||||
@abstractmethod
|
||||
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_action_results(
|
||||
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
||||
) -> None:
|
||||
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
|
||||
"""Track results of agent actions on the environment."""
|
||||
pass
|
||||
|
||||
async def summarize(self, query: str, current_state: EnvState) -> str:
|
||||
async def summarize(self, summarize_prompt: str, current_state: EnvState) -> str:
|
||||
"""Summarize the agent's actions and results."""
|
||||
await self.act(query, current_state)
|
||||
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
|
||||
await self.act(current_state)
|
||||
if not self.messages:
|
||||
return "No actions to summarize."
|
||||
return self.compile_response(self.messages[-1].content)
|
||||
@@ -63,12 +63,12 @@ class OperatorAgent(ABC):
|
||||
|
||||
def _update_usage(self, input_tokens: int, output_tokens: int, cache_read: int = 0, cache_write: int = 0):
|
||||
self.tracer["usage"] = get_chat_usage_metrics(
|
||||
self.chat_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage")
|
||||
self.vision_model.name, input_tokens, output_tokens, cache_read, cache_write, usage=self.tracer.get("usage")
|
||||
)
|
||||
logger.debug(f"Operator usage by {self.chat_model.model_type}: {self.tracer['usage']}")
|
||||
logger.debug(f"Operator usage by {self.vision_model.model_type}: {self.tracer['usage']}")
|
||||
|
||||
def _commit_trace(self):
|
||||
self.tracer["chat_model"] = self.chat_model.name
|
||||
self.tracer["chat_model"] = self.vision_model.name
|
||||
if is_promptrace_enabled() and len(self.messages) > 1:
|
||||
compiled_messages = [
|
||||
AgentMessage(role=msg.role, content=self.compile_response(msg.content)) for msg in self.messages
|
||||
|
||||
@@ -3,7 +3,6 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
@@ -18,7 +17,6 @@ from khoj.processor.operator.operator_environment_base import EnvState, EnvStepR
|
||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||
from khoj.utils.helpers import (
|
||||
convert_image_to_png,
|
||||
get_chat_usage_metrics,
|
||||
get_openai_async_client,
|
||||
is_none_or_empty,
|
||||
)
|
||||
@@ -29,34 +27,35 @@ logger = logging.getLogger(__name__)
|
||||
# --- Binary Operator Agent ---
|
||||
class BinaryOperatorAgent(OperatorAgent):
|
||||
"""
|
||||
An OperatorAgent that uses two LLMs (OpenAI compatible):
|
||||
1. Vision LLM: Determines the next high-level action based on the visual state.
|
||||
An OperatorAgent that uses two LLMs:
|
||||
1. Reasoning LLM: Determines the next high-level action based on the objective and current visual reasoning trajectory.
|
||||
2. Grounding LLM: Converts the high-level action into specific, executable browser actions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_chat_model: ChatModel,
|
||||
grounding_chat_model: ChatModel, # Assuming a second model is provided/configured
|
||||
query: str,
|
||||
reasoning_model: ChatModel,
|
||||
grounding_model: ChatModel,
|
||||
max_iterations: int,
|
||||
tracer: dict,
|
||||
):
|
||||
super().__init__(vision_chat_model, max_iterations, tracer) # Use vision model for primary tracking
|
||||
self.vision_chat_model = vision_chat_model
|
||||
self.grounding_chat_model = grounding_chat_model
|
||||
# Initialize OpenAI clients
|
||||
self.grounding_client: AsyncOpenAI = get_openai_async_client(
|
||||
grounding_chat_model.ai_model_api.api_key, grounding_chat_model.ai_model_api.api_base_url
|
||||
super().__init__(query, reasoning_model, max_iterations, tracer) # Use reasoning model for primary tracking
|
||||
self.reasoning_model = reasoning_model
|
||||
self.grounding_model = grounding_model
|
||||
# Initialize openai api compatible client for grounding model
|
||||
self.grounding_client = get_openai_async_client(
|
||||
grounding_model.ai_model_api.api_key, grounding_model.ai_model_api.api_base_url
|
||||
)
|
||||
|
||||
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
"""
|
||||
Uses a two-step LLM process to determine and structure the next action.
|
||||
"""
|
||||
self._commit_trace() # Commit trace before next action
|
||||
|
||||
# --- Step 1: Reasoning LLM determines high-level action ---
|
||||
reasoner_response = await self.act_reason(query, current_state)
|
||||
reasoner_response = await self.act_reason(current_state)
|
||||
natural_language_action = reasoner_response["message"]
|
||||
if reasoner_response["type"] == "error":
|
||||
logger.error(natural_language_action)
|
||||
@@ -75,7 +74,7 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
# --- Step 2: Grounding LLM converts NL action to structured action ---
|
||||
return await self.act_ground(natural_language_action, current_state)
|
||||
|
||||
async def act_reason(self, query: str, current_state: EnvState) -> dict[str, str]:
|
||||
async def act_reason(self, current_state: EnvState) -> dict[str, str]:
|
||||
"""
|
||||
Uses the reasoning LLM to determine the next high-level action based on the operation trajectory.
|
||||
"""
|
||||
@@ -118,12 +117,12 @@ Focus on the visual action and provide all necessary context.
|
||||
""".strip()
|
||||
|
||||
if is_none_or_empty(self.messages):
|
||||
query_text = f"**Main Objective**: {query}"
|
||||
query_text = f"**Main Objective**: {self.query}"
|
||||
query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"]
|
||||
first_message_content = construct_structured_message(
|
||||
message=query_text,
|
||||
images=query_screenshot,
|
||||
model_type=self.vision_chat_model.model_type,
|
||||
model_type=self.reasoning_model.model_type,
|
||||
vision_enabled=True,
|
||||
)
|
||||
current_message = AgentMessage(role="user", content=first_message_content)
|
||||
@@ -140,7 +139,7 @@ Focus on the visual action and provide all necessary context.
|
||||
query_images=query_screenshot,
|
||||
system_message=reasoning_system_prompt,
|
||||
conversation_log=visual_reasoner_history,
|
||||
agent_chat_model=self.vision_chat_model,
|
||||
agent_chat_model=self.reasoning_model,
|
||||
tracer=self.tracer,
|
||||
)
|
||||
self.messages.append(current_message)
|
||||
@@ -371,7 +370,7 @@ back() # Use this to go back to the previous page.
|
||||
|
||||
try:
|
||||
grounding_response: ChatCompletion = await self.grounding_client.chat.completions.create(
|
||||
model=self.grounding_chat_model.name,
|
||||
model=self.grounding_model.name,
|
||||
messages=grounding_messages_for_api,
|
||||
tools=grounding_tools,
|
||||
tool_choice="required",
|
||||
@@ -465,14 +464,12 @@ back() # Use this to go back to the previous page.
|
||||
rendered_response=rendered_response,
|
||||
)
|
||||
|
||||
def add_action_results(
|
||||
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
||||
) -> None:
|
||||
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
|
||||
"""
|
||||
Adds the results of executed actions back into the message history,
|
||||
formatted for the next OpenAI vision LLM call.
|
||||
"""
|
||||
if not agent_action.action_results and not summarize_prompt:
|
||||
if not agent_action.action_results:
|
||||
return
|
||||
|
||||
tool_outputs = []
|
||||
@@ -493,44 +490,38 @@ back() # Use this to go back to the previous page.
|
||||
tool_output_content = construct_structured_message(
|
||||
message=tool_outputs_str,
|
||||
images=[formatted_screenshot],
|
||||
model_type=self.vision_chat_model.model_type,
|
||||
model_type=self.reasoning_model.model_type,
|
||||
vision_enabled=True,
|
||||
)
|
||||
self.messages.append(AgentMessage(role="environment", content=tool_output_content))
|
||||
|
||||
# Append summarize prompt if provided
|
||||
if summarize_prompt:
|
||||
self.messages.append(AgentMessage(role="user", content=summarize_prompt))
|
||||
|
||||
async def summarize(self, query: str, env_state: EnvState) -> str:
|
||||
# Construct vision LLM input following OpenAI format
|
||||
async def summarize(self, summarize_prompt: str, env_state: EnvState) -> str:
|
||||
conversation_history = self._format_message_for_api(self.messages)
|
||||
try:
|
||||
summary = await send_message_to_model_wrapper(
|
||||
query=query,
|
||||
query=summarize_prompt,
|
||||
conversation_log=conversation_history,
|
||||
agent_chat_model=self.vision_chat_model,
|
||||
agent_chat_model=self.reasoning_model,
|
||||
tracer=self.tracer,
|
||||
)
|
||||
|
||||
# Return last action message if no summary
|
||||
# Set summary to last action message
|
||||
if not summary:
|
||||
return self.compile_response(self.messages[-1].content) # Compile the last action message
|
||||
|
||||
# Append summary messages to history
|
||||
trigger_summary = AgentMessage(role="user", content=query)
|
||||
summary_message = AgentMessage(role="assistant", content=summary)
|
||||
self.messages.extend([trigger_summary, summary_message])
|
||||
|
||||
return summary
|
||||
raise ValueError("Summary is empty.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Vision LLM for summary: {e}")
|
||||
return f"Error generating summary: {e}"
|
||||
logger.error(f"Error calling Reasoning LLM for summary: {e}")
|
||||
summary = "\n".join([self._get_message_text(msg) for msg in self.messages])
|
||||
|
||||
# Append summary messages to history
|
||||
trigger_summary = AgentMessage(role="user", content=summarize_prompt)
|
||||
summary_message = AgentMessage(role="assistant", content=summary)
|
||||
self.messages.extend([trigger_summary, summary_message])
|
||||
|
||||
return summary
|
||||
|
||||
def compile_response(self, response_content: Union[str, List, dict]) -> str:
|
||||
"""Compile response content into a string, handling OpenAI message structures."""
|
||||
if isinstance(response_content, str):
|
||||
return response_content # Simple text (e.g., initial user query, vision response)
|
||||
return response_content
|
||||
|
||||
if isinstance(response_content, dict) and response_content.get("role") == "assistant":
|
||||
# Grounding LLM response message (might contain tool calls)
|
||||
@@ -544,7 +535,7 @@ back() # Use this to go back to the previous page.
|
||||
compiled.append(
|
||||
f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}"
|
||||
)
|
||||
return "\n- ".join(filter(None, compiled)) or "[Assistant Message]"
|
||||
return "\n- ".join(filter(None, compiled))
|
||||
|
||||
if isinstance(response_content, list): # Tool results list
|
||||
compiled = ["**Tool Results**:"]
|
||||
|
||||
@@ -20,9 +20,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Anthropic Operator Agent ---
|
||||
class OpenAIOperatorAgent(OperatorAgent):
|
||||
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
||||
async def act(self, current_state: EnvState) -> AgentActResult:
|
||||
client = get_openai_async_client(
|
||||
self.chat_model.ai_model_api.api_key, self.chat_model.ai_model_api.api_base_url
|
||||
self.vision_model.ai_model_api.api_key, self.vision_model.ai_model_api.api_base_url
|
||||
)
|
||||
safety_check_prefix = "Say 'continue' after resolving the following safety checks to proceed:"
|
||||
safety_check_message = None
|
||||
@@ -80,7 +80,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
]
|
||||
|
||||
if is_none_or_empty(self.messages):
|
||||
self.messages = [AgentMessage(role="user", content=query)]
|
||||
self.messages = [AgentMessage(role="user", content=self.query)]
|
||||
|
||||
messages_for_api = self._format_message_for_api(self.messages)
|
||||
response: Response = await client.responses.create(
|
||||
@@ -168,7 +168,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
action_results.append(
|
||||
{
|
||||
"type": f"{block.type}_output",
|
||||
"output": content, # Updated by environment step
|
||||
"output": content, # Updated after environment step
|
||||
"call_id": last_call_id,
|
||||
}
|
||||
)
|
||||
@@ -181,10 +181,8 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
rendered_response=rendered_response,
|
||||
)
|
||||
|
||||
def add_action_results(
|
||||
self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None
|
||||
) -> None:
|
||||
if not agent_action.action_results and not summarize_prompt:
|
||||
def add_action_results(self, env_steps: list[EnvStepResult], agent_action: AgentActResult) -> None:
|
||||
if not agent_action.action_results:
|
||||
return
|
||||
|
||||
# Update action results with results of applying suggested actions on the environment
|
||||
@@ -209,11 +207,7 @@ class OpenAIOperatorAgent(OperatorAgent):
|
||||
# Add text data
|
||||
action_result["output"] = result_content
|
||||
|
||||
if agent_action.action_results:
|
||||
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
||||
# Append summarize prompt as a user message after tool results
|
||||
if summarize_prompt:
|
||||
self.messages += [AgentMessage(role="user", content=summarize_prompt)]
|
||||
self.messages += [AgentMessage(role="environment", content=agent_action.action_results)]
|
||||
|
||||
def _format_message_for_api(self, messages: list[AgentMessage]) -> list:
|
||||
"""Format the message for OpenAI API."""
|
||||
|
||||
Reference in New Issue
Block a user