Use any supported vision model as reasoner for binary operator agent

Debanjum
2025-05-07 19:25:48 -06:00
parent 3839d83b90
commit 680c226137
2 changed files with 66 additions and 102 deletions
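In short: the binary operator agent's visual reasoning step now goes through Khoj's provider-agnostic send_message_to_model_wrapper instead of a hand-rolled AsyncOpenAI client, so any configured vision-capable chat model can serve as the reasoner; only the grounding step still talks to an OpenAI-compatible UI-TARS endpoint directly. The heart of the swap, excerpted from the second file's diff below:

    # Before: reasoner hard-wired to an OpenAI-compatible client
    vision_response: ChatCompletion = await self.vision_client.chat.completions.create(
        model=self.vision_chat_model.name,
        messages=vision_messages_for_api,
        temperature=1.0,
    )
    natural_language_action = vision_response.choices[0].message.content

    # After: provider-agnostic wrapper; token usage flows through the shared tracer
    natural_language_action = await send_message_to_model_wrapper(
        query=query_text,
        query_images=query_screenshot,
        system_message=vision_system_prompt,
        conversation_log=visual_reasoner_history,
        agent_chat_model=self.vision_chat_model,
        tracer=self.tracer,
    )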

View File

@@ -53,7 +53,6 @@ async def operate_browser(
         operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer)
     else:
         grounding_model_name = "ui-tars-1.5-7b"
-        reasoning_model = await ConversationAdapters.aget_chat_model_by_name(reasoning_model.name)
         grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
         if (
             not grounding_model
@@ -61,8 +60,6 @@ async def operate_browser(
             or grounding_model.model_type != ChatModel.ModelType.OPENAI
         ):
             raise ValueError("No supported visual grounding model for binary operator agent found.")
-        if not reasoning_model or not reasoning_model.vision_enabled:
-            raise ValueError("No supported visual reasoning model for binary operator agent found.")
         operator_agent = BinaryOperatorAgent(reasoning_model, grounding_model, max_iterations, tracer)

     # Initialize Environment
View File

@@ -7,6 +7,7 @@ from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletion

 from khoj.database.models import ChatModel
+from khoj.processor.conversation.utils import construct_structured_message
 from khoj.processor.operator.operator_actions import *
 from khoj.processor.operator.operator_agent_base import (
     AgentActResult,
@@ -14,6 +15,7 @@ from khoj.processor.operator.operator_agent_base import (
     OperatorAgent,
 )
 from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
+from khoj.routers.helpers import send_message_to_model_wrapper
 from khoj.utils.helpers import (
     convert_image_to_png,
     get_chat_usage_metrics,
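The two added imports carry the commit's intent: construct_structured_message builds model-appropriate multimodal content, and send_message_to_model_wrapper dispatches to whichever provider backs the chosen chat model. A hedged sketch of the structured-message shape, inferred from the hand-built content list this commit deletes (the exact return value is an assumption, not documented API):

    from khoj.database.models import ChatModel
    from khoj.processor.conversation.utils import construct_structured_message

    content = construct_structured_message(
        message="Click the blue 'Submit' button",
        images=["data:image/png;base64,iVBORw0KGgo..."],  # truncated, illustrative
        model_type=ChatModel.ModelType.OPENAI,
        vision_enabled=True,
    )
    # Expected to resemble the OpenAI-style parts the deleted code built by hand:
    # [{"type": "text", "text": "Click the blue 'Submit' button"},
    #  {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}}]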
@@ -43,14 +45,9 @@ class BinaryOperatorAgent(OperatorAgent):
         self.vision_chat_model = vision_chat_model
         self.grounding_chat_model = grounding_chat_model
         # Initialize OpenAI clients
-        self.vision_client: AsyncOpenAI = get_openai_async_client(
-            vision_chat_model.ai_model_api.api_key, vision_chat_model.ai_model_api.api_base_url
-        )
         self.grounding_client: AsyncOpenAI = get_openai_async_client(
             grounding_chat_model.ai_model_api.api_key, grounding_chat_model.ai_model_api.api_base_url
         )
-        self.vision_usage = {}
-        self.grounding_usage = {}

     async def act(self, query: str, current_state: EnvState) -> AgentActResult:
         """
@@ -115,43 +112,37 @@ Focus on the visual action and provide all necessary context.
""".strip()
if is_none_or_empty(self.messages):
self.messages = [
AgentMessage(role="system", content=vision_system_prompt),
AgentMessage(
role="user",
content=[
{
"type": "text",
"text": query,
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}",
"detail": "high",
},
},
],
),
]
# Construct vision LLM input following OpenAI format
vision_messages_for_api = self._format_message_for_api(self.messages) # Get history
try:
vision_response: ChatCompletion = await self.vision_client.chat.completions.create(
model=self.vision_chat_model.name,
messages=vision_messages_for_api,
# max_tokens=250, # Allow for more detailed description
temperature=1.0,
query_text = query
query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"]
first_message_content = construct_structured_message(
message=query,
images=query_screenshot,
model_type=self.vision_chat_model.model_type,
vision_enabled=True,
)
logger.debug(f"Vision LLM response: {vision_response.model_dump_json()}")
natural_language_action = vision_response.choices[0].message.content
current_message = AgentMessage(role="user", content=first_message_content)
else:
current_message = self.messages.pop()
query_text = self._get_message_text(current_message)
query_screenshot = self._get_message_images(current_message)
# Construct input for visual reasoner history
visual_reasoner_history = self._format_message_for_api(self.messages)
try:
natural_language_action = await send_message_to_model_wrapper(
query=query_text,
query_images=query_screenshot,
system_message=vision_system_prompt,
conversation_log=visual_reasoner_history,
agent_chat_model=self.vision_chat_model,
tracer=self.tracer,
)
self.messages.append(current_message)
self.messages.append(AgentMessage(role="assistant", content=natural_language_action))
if natural_language_action == "DONE":
return {"type": "done", "message": "Completed task."}
# Update usage for vision model
# self._update_vision_usage(vision_response.usage.prompt_tokens, vision_response.usage.completion_tokens)
logger.info(f"Vision LLM suggested action: {natural_language_action}")
except Exception as e:
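One subtlety in the new act flow above: the latest user message is popped off self.messages so its text and images can be passed as the wrapper's query, while the remaining messages become the conversation log; it is then re-appended together with the assistant reply. Schematically (a paraphrase of the diff, not new behavior):

    # history before the call: [..., user_N]
    current_message = self.messages.pop()                  # detach user_N
    history = self._format_message_for_api(self.messages)  # log without user_N
    # natural_language_action = await send_message_to_model_wrapper(query=..., conversation_log=history, ...)
    self.messages.append(current_message)                  # restore user_N
    self.messages.append(AgentMessage(role="assistant", content=natural_language_action))
    # history after the call: [..., user_N, assistant_reply]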
@@ -468,8 +459,8 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
                 logger.warning("Grounding LLM did not produce a tool call.")
                 rendered_response = f"**Thought (Vision)**: {natural_language_action}\n- **Response (Grounding)**: {grounding_message.content or '[No tool call]'}"

-            # Update usage for grounding model
-            # self._update_grounding_usage(grounding_response.usage.prompt_tokens, grounding_response.usage.completion_tokens)
+            # Update usage by grounding model
+            self._update_usage(grounding_response.usage.prompt_tokens, grounding_response.usage.completion_tokens)
         except Exception as e:
             logger.error(f"Error calling Grounding LLM: {e}")
             rendered_response = (
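The per-model usage plumbing gives way to a single _update_usage call here; the grounding response still comes from a raw OpenAI client, so its token counts must be recorded manually, while the vision side is handled inside send_message_to_model_wrapper via the tracer. _update_usage's body is not shown in this diff; a plausible sketch, assuming it mirrors the deleted _update_grounding_usage but accumulates directly into the shared tracer:

    def _update_usage(self, input_tokens: int, output_tokens: int):
        # ASSUMED implementation, not the actual base-class code: reuse
        # get_chat_usage_metrics (imported above) and keep the running
        # total on the tracer that both models now share.
        self.tracer["usage"] = get_chat_usage_metrics(
            self.grounding_chat_model.name, input_tokens, output_tokens, usage=self.tracer.get("usage")
        )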
@@ -503,20 +494,15 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
         # Append tool results message to history
         if tool_outputs:
-            tool_output_strs = "\n".join([f" - {idx}: {str(item)}" for idx, item in enumerate(tool_outputs)])
-            tool_output_content = [
-                {
-                    "type": "text",
-                    "text": f"**Action Results**:\n{tool_output_strs}",
-                },
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/png;base64,{convert_image_to_png(env_step.screenshot_base64)}",
-                        "detail": "high",
-                    },
-                },
-            ]
+            tool_outputs_list = "\n".join([f"- {idx}: {str(item)}" for idx, item in enumerate(tool_outputs)])
+            tool_outputs_str = "**Action Results**:\n" + tool_outputs_list
+            formatted_screenshot = f"data:image/png;base64,{convert_image_to_png(env_step.screenshot_base64)}"
+            tool_output_content = construct_structured_message(
+                message=tool_outputs_str,
+                images=[formatted_screenshot],
+                model_type=self.vision_chat_model.model_type,
+                vision_enabled=True,
+            )
             self.messages.append(AgentMessage(role="environment", content=tool_output_content))

         # Append summarize prompt if provided
@@ -525,23 +511,21 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
     async def summarize(self, query: str, env_state: EnvState) -> str:
         # Construct vision LLM input following OpenAI format
-        trigger_summary = AgentMessage(role="user", content=query)
-        vision_messages_for_api = self._format_message_for_api(self.messages + [trigger_summary])
+        conversation_history = self._format_message_for_api(self.messages)
         try:
-            summary_response: ChatCompletion = await self.vision_client.chat.completions.create(
-                model=self.vision_chat_model.name,
-                messages=vision_messages_for_api,
-                # max_tokens=250,  # Allow for more detailed description
-                temperature=1.0,
-            )
-            logger.debug(f"Vision LLM summary response: {summary_response.model_dump_json()}")
-            summary = summary_response.choices[0].message.content
+            summary = await send_message_to_model_wrapper(
+                query=query,
+                conversation_log=conversation_history,
+                agent_chat_model=self.vision_chat_model,
+                tracer=self.tracer,
+            )

             # Return last action message if no summary
             if not summary:
                 return self.compile_response(self.messages[-1].content)  # Compile the last action message

             # Append summary messages to history
+            trigger_summary = AgentMessage(role="user", content=query)
             summary_message = AgentMessage(role="assistant", content=summary)
             self.messages.extend([trigger_summary, summary_message])
@@ -587,46 +571,29 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
         # For now, rely on the structure built during the 'act' phase.
         return response  # The rendered_response is already built in act()

+    def _get_message_text(self, message: AgentMessage) -> str:
+        if isinstance(message.content, list):
+            return "\n".join([item["text"] for item in message.content if item["type"] == "text"])
+        return message.content
+
+    def _get_message_images(self, message: AgentMessage) -> List[str]:
+        images = []
+        if isinstance(message.content, list):
+            images = [item["image_url"]["url"] for item in message.content if item["type"] == "image_url"]
+        return images
+
     def _format_message_for_api(self, messages: list[AgentMessage]) -> List[dict]:
-        """Format message history for OpenAI API calls."""
-        formatted_messages = []
-        for message in messages:
-            role = message.role
-            content = message.content
-            if role == "environment":  # Handle action results
-                formatted_messages.append({"role": "user", "content": content})
-            else:
-                formatted_messages.append({"role": role, "content": content})
-        return formatted_messages
-
-    def _update_vision_usage(self, input_tokens: int, output_tokens: int):
-        self.vision_usage = get_chat_usage_metrics(
-            self.vision_chat_model.name, input_tokens, output_tokens, usage=self.vision_usage
-        )
-        self._combine_usage()
-
-    def _update_grounding_usage(self, input_tokens: int, output_tokens: int):
-        self.grounding_usage = get_chat_usage_metrics(
-            self.grounding_chat_model.name, input_tokens, output_tokens, usage=self.grounding_usage
-        )
-        self._combine_usage()
-
-    def _combine_usage(self):
-        """Combine usage from both models into the main tracer."""
-        combined = {}
-        for usage_dict in [self.vision_usage, self.grounding_usage]:
-            for model, metrics in usage_dict.items():
-                if model not in combined:
-                    combined[model] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
-                combined[model]["input_tokens"] += metrics.get("input_tokens", 0)
-                combined[model]["output_tokens"] += metrics.get("output_tokens", 0)
-                combined[model]["total_tokens"] += metrics.get("total_tokens", 0)
-        self.tracer["usage"] = combined
-        logger.debug(f"Combined Operator usage: {self.tracer['usage']}")
+        """Format operator agent messages into the Khoj conversation history format."""
+        formatted_messages = [
+            {
+                "message": self._get_message_text(message),
+                "images": self._get_message_images(message),
+                "by": "you" if message.role in ["user", "environment"] else message.role,
+            }
+            for message in messages
+        ]
+        return {"chat": formatted_messages}

     def reset(self):
         """Reset the agent state."""
         super().reset()
-        self.vision_usage = {}
-        self.grounding_usage = {}
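For reference, the reworked _format_message_for_api now emits Khoj's conversation-log shape, {"chat": [...]}, rather than OpenAI chat messages, which matches what send_message_to_model_wrapper receives as conversation_log. An illustrative round trip, assuming AgentMessage is the simple role/content container its usage in this diff suggests:

    from khoj.processor.operator.operator_agent_base import AgentMessage

    messages = [
        AgentMessage(role="user", content="Open the settings page"),
        AgentMessage(role="assistant", content="Click the gear icon in the top right."),
        AgentMessage(role="environment", content="**Action Results**:\n- 0: clicked (912, 48)"),
    ]
    # agent._format_message_for_api(messages) would return:
    # {
    #     "chat": [
    #         {"message": "Open the settings page", "images": [], "by": "you"},
    #         {"message": "Click the gear icon in the top right.", "images": [], "by": "assistant"},
    #         {"message": "**Action Results**:\n- 0: clicked (912, 48)", "images": [], "by": "you"},
    #     ]
    # }

Note that both "user" and "environment" turns map to "by": "you", so environment feedback is stored as if it were user input in the conversation history.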