mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Use any supported vision model as reasoner for binary operator agent
This commit is contained in:
@@ -53,7 +53,6 @@ async def operate_browser(
|
||||
operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer)
|
||||
else:
|
||||
grounding_model_name = "ui-tars-1.5-7b"
|
||||
reasoning_model = await ConversationAdapters.aget_chat_model_by_name(reasoning_model.name)
|
||||
grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
|
||||
if (
|
||||
not grounding_model
|
||||
@@ -61,8 +60,6 @@ async def operate_browser(
|
||||
or grounding_model.model_type != ChatModel.ModelType.OPENAI
|
||||
):
|
||||
raise ValueError("No supported visual grounding model for binary operator agent found.")
|
||||
if not reasoning_model or not reasoning_model.vision_enabled:
|
||||
raise ValueError("No supported visual reasoning model for binary operator agent found.")
|
||||
operator_agent = BinaryOperatorAgent(reasoning_model, grounding_model, max_iterations, tracer)
|
||||
|
||||
# Initialize Environment
|
||||
|
||||
@@ -7,6 +7,7 @@ from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.processor.conversation.utils import construct_structured_message
|
||||
from khoj.processor.operator.operator_actions import *
|
||||
from khoj.processor.operator.operator_agent_base import (
|
||||
AgentActResult,
|
||||
@@ -14,6 +15,7 @@ from khoj.processor.operator.operator_agent_base import (
|
||||
OperatorAgent,
|
||||
)
|
||||
from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
|
||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||
from khoj.utils.helpers import (
|
||||
convert_image_to_png,
|
||||
get_chat_usage_metrics,
|
||||
@@ -43,14 +45,9 @@ class BinaryOperatorAgent(OperatorAgent):
|
||||
self.vision_chat_model = vision_chat_model
|
||||
self.grounding_chat_model = grounding_chat_model
|
||||
# Initialize OpenAI clients
|
||||
self.vision_client: AsyncOpenAI = get_openai_async_client(
|
||||
vision_chat_model.ai_model_api.api_key, vision_chat_model.ai_model_api.api_base_url
|
||||
)
|
||||
self.grounding_client: AsyncOpenAI = get_openai_async_client(
|
||||
grounding_chat_model.ai_model_api.api_key, grounding_chat_model.ai_model_api.api_base_url
|
||||
)
|
||||
self.vision_usage = {}
|
||||
self.grounding_usage = {}
|
||||
|
||||
async def act(self, query: str, current_state: EnvState) -> AgentActResult:
|
||||
"""
|
||||
@@ -115,43 +112,37 @@ Focus on the visual action and provide all necessary context.
|
||||
""".strip()
|
||||
|
||||
if is_none_or_empty(self.messages):
|
||||
self.messages = [
|
||||
AgentMessage(role="system", content=vision_system_prompt),
|
||||
AgentMessage(
|
||||
role="user",
|
||||
content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": query,
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
],
|
||||
),
|
||||
]
|
||||
# Construct vision LLM input following OpenAI format
|
||||
vision_messages_for_api = self._format_message_for_api(self.messages) # Get history
|
||||
try:
|
||||
vision_response: ChatCompletion = await self.vision_client.chat.completions.create(
|
||||
model=self.vision_chat_model.name,
|
||||
messages=vision_messages_for_api,
|
||||
# max_tokens=250, # Allow for more detailed description
|
||||
temperature=1.0,
|
||||
query_text = query
|
||||
query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"]
|
||||
first_message_content = construct_structured_message(
|
||||
message=query,
|
||||
images=query_screenshot,
|
||||
model_type=self.vision_chat_model.model_type,
|
||||
vision_enabled=True,
|
||||
)
|
||||
logger.debug(f"Vision LLM response: {vision_response.model_dump_json()}")
|
||||
natural_language_action = vision_response.choices[0].message.content
|
||||
current_message = AgentMessage(role="user", content=first_message_content)
|
||||
else:
|
||||
current_message = self.messages.pop()
|
||||
query_text = self._get_message_text(current_message)
|
||||
query_screenshot = self._get_message_images(current_message)
|
||||
|
||||
# Construct input for visual reasoner history
|
||||
visual_reasoner_history = self._format_message_for_api(self.messages)
|
||||
try:
|
||||
natural_language_action = await send_message_to_model_wrapper(
|
||||
query=query_text,
|
||||
query_images=query_screenshot,
|
||||
system_message=vision_system_prompt,
|
||||
conversation_log=visual_reasoner_history,
|
||||
agent_chat_model=self.vision_chat_model,
|
||||
tracer=self.tracer,
|
||||
)
|
||||
self.messages.append(current_message)
|
||||
self.messages.append(AgentMessage(role="assistant", content=natural_language_action))
|
||||
|
||||
if natural_language_action == "DONE":
|
||||
return {"type": "done", "message": "Completed task."}
|
||||
|
||||
# Update usage for vision model
|
||||
# self._update_vision_usage(vision_response.usage.prompt_tokens, vision_response.usage.completion_tokens)
|
||||
logger.info(f"Vision LLM suggested action: {natural_language_action}")
|
||||
|
||||
except Exception as e:
|
||||
@@ -468,8 +459,8 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
|
||||
logger.warning("Grounding LLM did not produce a tool call.")
|
||||
rendered_response = f"**Thought (Vision)**: {natural_language_action}\n- **Response (Grounding)**: {grounding_message.content or '[No tool call]'}"
|
||||
|
||||
# Update usage for grounding model
|
||||
# self._update_grounding_usage(grounding_response.usage.prompt_tokens, grounding_response.usage.completion_tokens)
|
||||
# Update usage by grounding model
|
||||
self._update_usage(grounding_response.usage.prompt_tokens, grounding_response.usage.completion_tokens)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Grounding LLM: {e}")
|
||||
rendered_response = (
|
||||
@@ -503,20 +494,15 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
|
||||
|
||||
# Append tool results message to history
|
||||
if tool_outputs:
|
||||
tool_output_strs = "\n".join([f" - {idx}: {str(item)}" for idx, item in enumerate(tool_outputs)])
|
||||
tool_output_content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"**Action Results**:\n{tool_output_strs}",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{convert_image_to_png(env_step.screenshot_base64)}",
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
]
|
||||
tool_outputs_list = "\n".join([f"- {idx}: {str(item)}" for idx, item in enumerate(tool_outputs)])
|
||||
tool_outputs_str = "**Action Results**:\n" + tool_outputs_list
|
||||
formatted_screenshot = f"data:image/png;base64,{convert_image_to_png(env_step.screenshot_base64)}"
|
||||
tool_output_content = construct_structured_message(
|
||||
message=tool_outputs_str,
|
||||
images=[formatted_screenshot],
|
||||
model_type=self.vision_chat_model.model_type,
|
||||
vision_enabled=True,
|
||||
)
|
||||
self.messages.append(AgentMessage(role="environment", content=tool_output_content))
|
||||
|
||||
# Append summarize prompt if provided
|
||||
@@ -525,23 +511,21 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
|
||||
|
||||
async def summarize(self, query: str, env_state: EnvState) -> str:
|
||||
# Construct vision LLM input following OpenAI format
|
||||
trigger_summary = AgentMessage(role="user", content=query)
|
||||
vision_messages_for_api = self._format_message_for_api(self.messages + [trigger_summary])
|
||||
conversation_history = self._format_message_for_api(self.messages)
|
||||
try:
|
||||
summary_response: ChatCompletion = await self.vision_client.chat.completions.create(
|
||||
model=self.vision_chat_model.name,
|
||||
messages=vision_messages_for_api,
|
||||
# max_tokens=250, # Allow for more detailed description
|
||||
temperature=1.0,
|
||||
summary = await send_message_to_model_wrapper(
|
||||
query=query,
|
||||
conversation_log=conversation_history,
|
||||
agent_chat_model=self.vision_chat_model,
|
||||
tracer=self.tracer,
|
||||
)
|
||||
logger.debug(f"Vision LLM summary response: {summary_response.model_dump_json()}")
|
||||
summary = summary_response.choices[0].message.content
|
||||
|
||||
# Return last action message if no summary
|
||||
if not summary:
|
||||
return self.compile_response(self.messages[-1].content) # Compile the last action message
|
||||
|
||||
# Append summary messages to history
|
||||
trigger_summary = AgentMessage(role="user", content=query)
|
||||
summary_message = AgentMessage(role="assistant", content=summary)
|
||||
self.messages.extend([trigger_summary, summary_message])
|
||||
|
||||
@@ -587,46 +571,29 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
|
||||
# For now, rely on the structure built during the 'act' phase.
|
||||
return response # The rendered_response is already built in act()
|
||||
|
||||
def _get_message_text(self, message: AgentMessage) -> str:
|
||||
if isinstance(message.content, list):
|
||||
return "\n".join([item["text"] for item in message.content if item["type"] == "text"])
|
||||
return message.content
|
||||
|
||||
def _get_message_images(self, message: AgentMessage) -> List[str]:
|
||||
images = []
|
||||
if isinstance(message.content, list):
|
||||
images = [item["image_url"]["url"] for item in message.content if item["type"] == "image_url"]
|
||||
return images
|
||||
|
||||
def _format_message_for_api(self, messages: list[AgentMessage]) -> List[dict]:
|
||||
"""Format message history for OpenAI API calls."""
|
||||
formatted_messages = []
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content = message.content
|
||||
|
||||
if role == "environment": # Handle action results
|
||||
formatted_messages.append({"role": "user", "content": content})
|
||||
else:
|
||||
formatted_messages.append({"role": role, "content": content})
|
||||
return formatted_messages
|
||||
|
||||
def _update_vision_usage(self, input_tokens: int, output_tokens: int):
|
||||
self.vision_usage = get_chat_usage_metrics(
|
||||
self.vision_chat_model.name, input_tokens, output_tokens, usage=self.vision_usage
|
||||
)
|
||||
self._combine_usage()
|
||||
|
||||
def _update_grounding_usage(self, input_tokens: int, output_tokens: int):
|
||||
self.grounding_usage = get_chat_usage_metrics(
|
||||
self.grounding_chat_model.name, input_tokens, output_tokens, usage=self.grounding_usage
|
||||
)
|
||||
self._combine_usage()
|
||||
|
||||
def _combine_usage(self):
|
||||
"""Combine usage from both models into the main tracer."""
|
||||
combined = {}
|
||||
for usage_dict in [self.vision_usage, self.grounding_usage]:
|
||||
for model, metrics in usage_dict.items():
|
||||
if model not in combined:
|
||||
combined[model] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
combined[model]["input_tokens"] += metrics.get("input_tokens", 0)
|
||||
combined[model]["output_tokens"] += metrics.get("output_tokens", 0)
|
||||
combined[model]["total_tokens"] += metrics.get("total_tokens", 0)
|
||||
self.tracer["usage"] = combined
|
||||
logger.debug(f"Combined Operator usage: {self.tracer['usage']}")
|
||||
"""Format operator agent messages into the Khoj conversation history format."""
|
||||
formatted_messages = [
|
||||
{
|
||||
"message": self._get_message_text(message),
|
||||
"images": self._get_message_images(message),
|
||||
"by": "you" if message.role in ["user", "environment"] else message.role,
|
||||
}
|
||||
for message in messages
|
||||
]
|
||||
return {"chat": formatted_messages}
|
||||
|
||||
def reset(self):
|
||||
"""Reset the agent state."""
|
||||
super().reset()
|
||||
self.vision_usage = {}
|
||||
self.grounding_usage = {}
|
||||
|
||||
Reference in New Issue
Block a user