Use any supported vision model as reasoner for binary operator agent
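The reasoner half of the binary operator agent no longer talks to a dedicated AsyncOpenAI client. Its calls are routed through Khoj's shared send_message_to_model_wrapper helper, so any vision-enabled chat model configured in Khoj can drive the reasoning step, while visual grounding still requires the OpenAI-compatible ui-tars-1.5-7b model. As a rough sketch of the new reasoner call path (taken from the diff below, not a separate API):

    natural_language_action = await send_message_to_model_wrapper(
        query=query_text,
        query_images=query_screenshot,
        system_message=vision_system_prompt,
        conversation_log=visual_reasoner_history,
        agent_chat_model=self.vision_chat_model,  # any vision-enabled ChatModel
        tracer=self.tracer,
    )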
@@ -53,7 +53,6 @@ async def operate_browser(
         operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer)
     else:
         grounding_model_name = "ui-tars-1.5-7b"
-        reasoning_model = await ConversationAdapters.aget_chat_model_by_name(reasoning_model.name)
         grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
         if (
             not grounding_model
@@ -61,8 +60,6 @@ async def operate_browser(
             or grounding_model.model_type != ChatModel.ModelType.OPENAI
         ):
             raise ValueError("No supported visual grounding model for binary operator agent found.")
-        if not reasoning_model or not reasoning_model.vision_enabled:
-            raise ValueError("No supported visual reasoning model for binary operator agent found.")
         operator_agent = BinaryOperatorAgent(reasoning_model, grounding_model, max_iterations, tracer)
 
     # Initialize Environment
@@ -7,6 +7,7 @@ from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletion
 
 from khoj.database.models import ChatModel
+from khoj.processor.conversation.utils import construct_structured_message
 from khoj.processor.operator.operator_actions import *
 from khoj.processor.operator.operator_agent_base import (
     AgentActResult,
@@ -14,6 +15,7 @@ from khoj.processor.operator.operator_agent_base import (
     OperatorAgent,
 )
 from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
+from khoj.routers.helpers import send_message_to_model_wrapper
 from khoj.utils.helpers import (
     convert_image_to_png,
     get_chat_usage_metrics,
@@ -43,14 +45,9 @@ class BinaryOperatorAgent(OperatorAgent):
         self.vision_chat_model = vision_chat_model
         self.grounding_chat_model = grounding_chat_model
         # Initialize OpenAI clients
-        self.vision_client: AsyncOpenAI = get_openai_async_client(
-            vision_chat_model.ai_model_api.api_key, vision_chat_model.ai_model_api.api_base_url
-        )
         self.grounding_client: AsyncOpenAI = get_openai_async_client(
             grounding_chat_model.ai_model_api.api_key, grounding_chat_model.ai_model_api.api_base_url
         )
-        self.vision_usage = {}
-        self.grounding_usage = {}
 
     async def act(self, query: str, current_state: EnvState) -> AgentActResult:
         """
@@ -115,43 +112,37 @@ Focus on the visual action and provide all necessary context.
 """.strip()
 
         if is_none_or_empty(self.messages):
-            self.messages = [
-                AgentMessage(role="system", content=vision_system_prompt),
-                AgentMessage(
-                    role="user",
-                    content=[
-                        {
-                            "type": "text",
-                            "text": query,
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}",
-                                "detail": "high",
-                            },
-                        },
-                    ],
-                ),
-            ]
-        # Construct vision LLM input following OpenAI format
-        vision_messages_for_api = self._format_message_for_api(self.messages) # Get history
-        try:
-            vision_response: ChatCompletion = await self.vision_client.chat.completions.create(
-                model=self.vision_chat_model.name,
-                messages=vision_messages_for_api,
-                # max_tokens=250, # Allow for more detailed description
-                temperature=1.0,
+            query_text = query
+            query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"]
+            first_message_content = construct_structured_message(
+                message=query,
+                images=query_screenshot,
+                model_type=self.vision_chat_model.model_type,
+                vision_enabled=True,
             )
-            logger.debug(f"Vision LLM response: {vision_response.model_dump_json()}")
-            natural_language_action = vision_response.choices[0].message.content
+            current_message = AgentMessage(role="user", content=first_message_content)
+        else:
+            current_message = self.messages.pop()
+            query_text = self._get_message_text(current_message)
+            query_screenshot = self._get_message_images(current_message)
+
+        # Construct input for visual reasoner history
+        visual_reasoner_history = self._format_message_for_api(self.messages)
+        try:
+            natural_language_action = await send_message_to_model_wrapper(
+                query=query_text,
+                query_images=query_screenshot,
+                system_message=vision_system_prompt,
+                conversation_log=visual_reasoner_history,
+                agent_chat_model=self.vision_chat_model,
+                tracer=self.tracer,
+            )
+            self.messages.append(current_message)
             self.messages.append(AgentMessage(role="assistant", content=natural_language_action))
 
             if natural_language_action == "DONE":
                 return {"type": "done", "message": "Completed task."}
 
-            # Update usage for vision model
-            # self._update_vision_usage(vision_response.usage.prompt_tokens, vision_response.usage.completion_tokens)
             logger.info(f"Vision LLM suggested action: {natural_language_action}")
 
         except Exception as e:
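For context, the structured message content built by construct_structured_message for a vision-enabled model is assumed to be a list of OpenAI-style content parts; the _get_message_text and _get_message_images helpers added later in this diff unpack exactly that shape. A minimal illustration with made-up values:

    # Assumed shape only, not part of the diff
    first_message_content = [
        {"type": "text", "text": "Find the pricing page and open it"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,<screenshot>"}},
    ]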
@@ -468,8 +459,8 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
                 logger.warning("Grounding LLM did not produce a tool call.")
             rendered_response = f"**Thought (Vision)**: {natural_language_action}\n- **Response (Grounding)**: {grounding_message.content or '[No tool call]'}"
 
-            # Update usage for grounding model
-            # self._update_grounding_usage(grounding_response.usage.prompt_tokens, grounding_response.usage.completion_tokens)
+            # Update usage by grounding model
+            self._update_usage(grounding_response.usage.prompt_tokens, grounding_response.usage.completion_tokens)
         except Exception as e:
             logger.error(f"Error calling Grounding LLM: {e}")
             rendered_response = (
@@ -503,20 +494,15 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
 
         # Append tool results message to history
         if tool_outputs:
-            tool_output_strs = "\n".join([f" - {idx}: {str(item)}" for idx, item in enumerate(tool_outputs)])
-            tool_output_content = [
-                {
-                    "type": "text",
-                    "text": f"**Action Results**:\n{tool_output_strs}",
-                },
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/png;base64,{convert_image_to_png(env_step.screenshot_base64)}",
-                        "detail": "high",
-                    },
-                },
-            ]
+            tool_outputs_list = "\n".join([f"- {idx}: {str(item)}" for idx, item in enumerate(tool_outputs)])
+            tool_outputs_str = "**Action Results**:\n" + tool_outputs_list
+            formatted_screenshot = f"data:image/png;base64,{convert_image_to_png(env_step.screenshot_base64)}"
+            tool_output_content = construct_structured_message(
+                message=tool_outputs_str,
+                images=[formatted_screenshot],
+                model_type=self.vision_chat_model.model_type,
+                vision_enabled=True,
+            )
             self.messages.append(AgentMessage(role="environment", content=tool_output_content))
 
         # Append summarize prompt if provided
@@ -525,23 +511,21 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
 
     async def summarize(self, query: str, env_state: EnvState) -> str:
         # Construct vision LLM input following OpenAI format
-        trigger_summary = AgentMessage(role="user", content=query)
-        vision_messages_for_api = self._format_message_for_api(self.messages + [trigger_summary])
+        conversation_history = self._format_message_for_api(self.messages)
        try:
-            summary_response: ChatCompletion = await self.vision_client.chat.completions.create(
-                model=self.vision_chat_model.name,
-                messages=vision_messages_for_api,
-                # max_tokens=250, # Allow for more detailed description
-                temperature=1.0,
+            summary = await send_message_to_model_wrapper(
+                query=query,
+                conversation_log=conversation_history,
+                agent_chat_model=self.vision_chat_model,
+                tracer=self.tracer,
             )
-            logger.debug(f"Vision LLM summary response: {summary_response.model_dump_json()}")
-            summary = summary_response.choices[0].message.content
 
             # Return last action message if no summary
             if not summary:
                 return self.compile_response(self.messages[-1].content) # Compile the last action message
 
             # Append summary messages to history
+            trigger_summary = AgentMessage(role="user", content=query)
             summary_message = AgentMessage(role="assistant", content=summary)
             self.messages.extend([trigger_summary, summary_message])
 
@@ -587,46 +571,29 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
         # For now, rely on the structure built during the 'act' phase.
         return response # The rendered_response is already built in act()
 
+    def _get_message_text(self, message: AgentMessage) -> str:
+        if isinstance(message.content, list):
+            return "\n".join([item["text"] for item in message.content if item["type"] == "text"])
+        return message.content
+
+    def _get_message_images(self, message: AgentMessage) -> List[str]:
+        images = []
+        if isinstance(message.content, list):
+            images = [item["image_url"]["url"] for item in message.content if item["type"] == "image_url"]
+        return images
+
     def _format_message_for_api(self, messages: list[AgentMessage]) -> List[dict]:
-        """Format message history for OpenAI API calls."""
-        formatted_messages = []
-        for message in messages:
-            role = message.role
-            content = message.content
-
-            if role == "environment": # Handle action results
-                formatted_messages.append({"role": "user", "content": content})
-            else:
-                formatted_messages.append({"role": role, "content": content})
-        return formatted_messages
-
-    def _update_vision_usage(self, input_tokens: int, output_tokens: int):
-        self.vision_usage = get_chat_usage_metrics(
-            self.vision_chat_model.name, input_tokens, output_tokens, usage=self.vision_usage
-        )
-        self._combine_usage()
-
-    def _update_grounding_usage(self, input_tokens: int, output_tokens: int):
-        self.grounding_usage = get_chat_usage_metrics(
-            self.grounding_chat_model.name, input_tokens, output_tokens, usage=self.grounding_usage
-        )
-        self._combine_usage()
-
-    def _combine_usage(self):
-        """Combine usage from both models into the main tracer."""
-        combined = {}
-        for usage_dict in [self.vision_usage, self.grounding_usage]:
-            for model, metrics in usage_dict.items():
-                if model not in combined:
-                    combined[model] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
-                combined[model]["input_tokens"] += metrics.get("input_tokens", 0)
-                combined[model]["output_tokens"] += metrics.get("output_tokens", 0)
-                combined[model]["total_tokens"] += metrics.get("total_tokens", 0)
-        self.tracer["usage"] = combined
-        logger.debug(f"Combined Operator usage: {self.tracer['usage']}")
+        """Format operator agent messages into the Khoj conversation history format."""
+        formatted_messages = [
+            {
+                "message": self._get_message_text(message),
+                "images": self._get_message_images(message),
+                "by": "you" if message.role in ["user", "environment"] else message.role,
+            }
+            for message in messages
+        ]
+        return {"chat": formatted_messages}
 
     def reset(self):
         """Reset the agent state."""
         super().reset()
-        self.vision_usage = {}
-        self.grounding_usage = {}
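With the rework above, _format_message_for_api emits Khoj's conversation-history shape rather than raw OpenAI chat messages; this is the structure send_message_to_model_wrapper receives as conversation_log. A minimal sketch with made-up values:

    # Assumed example of the returned history, based on the keys used in the diff
    {
        "chat": [
            {"message": "Open the pricing page", "images": ["data:image/png;base64,<screenshot>"], "by": "you"},
            {"message": "Click the 'Pricing' link in the top navigation bar", "images": [], "by": "assistant"},
        ]
    }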