From 680c2261371484505100d7ede69d6a23eb7615f1 Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Wed, 7 May 2025 19:25:48 -0600
Subject: [PATCH] Use any supported vision model as reasoner for binary operator agent

---
 .../processor/operator/operate_browser.py    |   3 -
 .../operator/operator_agent_binary.py        | 169 +++++++-----------
 2 files changed, 68 insertions(+), 104 deletions(-)

diff --git a/src/khoj/processor/operator/operate_browser.py b/src/khoj/processor/operator/operate_browser.py
index 6de8e760..23bc2123 100644
--- a/src/khoj/processor/operator/operate_browser.py
+++ b/src/khoj/processor/operator/operate_browser.py
@@ -53,7 +53,6 @@ async def operate_browser(
         operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer)
     else:
         grounding_model_name = "ui-tars-1.5-7b"
-        reasoning_model = await ConversationAdapters.aget_chat_model_by_name(reasoning_model.name)
         grounding_model = await ConversationAdapters.aget_chat_model_by_name(grounding_model_name)
         if (
             not grounding_model
@@ -61,8 +60,6 @@ async def operate_browser(
             or grounding_model.model_type != ChatModel.ModelType.OPENAI
         ):
             raise ValueError("No supported visual grounding model for binary operator agent found.")
-        if not reasoning_model or not reasoning_model.vision_enabled:
-            raise ValueError("No supported visual reasoning model for binary operator agent found.")
         operator_agent = BinaryOperatorAgent(reasoning_model, grounding_model, max_iterations, tracer)
 
     # Initialize Environment
diff --git a/src/khoj/processor/operator/operator_agent_binary.py b/src/khoj/processor/operator/operator_agent_binary.py
index 112b074b..487d70cf 100644
--- a/src/khoj/processor/operator/operator_agent_binary.py
+++ b/src/khoj/processor/operator/operator_agent_binary.py
@@ -7,6 +7,7 @@ from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletion
 
 from khoj.database.models import ChatModel
+from khoj.processor.conversation.utils import construct_structured_message
 from khoj.processor.operator.operator_actions import *
 from khoj.processor.operator.operator_agent_base import (
     AgentActResult,
@@ -14,6 +15,7 @@ from khoj.processor.operator.operator_agent_base import (
     OperatorAgent,
 )
 from khoj.processor.operator.operator_environment_base import EnvState, EnvStepResult
+from khoj.routers.helpers import send_message_to_model_wrapper
 from khoj.utils.helpers import (
     convert_image_to_png,
     get_chat_usage_metrics,
@@ -43,14 +45,9 @@ class BinaryOperatorAgent(OperatorAgent):
         self.vision_chat_model = vision_chat_model
         self.grounding_chat_model = grounding_chat_model
         # Initialize OpenAI clients
-        self.vision_client: AsyncOpenAI = get_openai_async_client(
-            vision_chat_model.ai_model_api.api_key, vision_chat_model.ai_model_api.api_base_url
-        )
         self.grounding_client: AsyncOpenAI = get_openai_async_client(
             grounding_chat_model.ai_model_api.api_key, grounding_chat_model.ai_model_api.api_base_url
         )
-        self.vision_usage = {}
-        self.grounding_usage = {}
 
     async def act(self, query: str, current_state: EnvState) -> AgentActResult:
         """
@@ -115,43 +112,37 @@ Focus on the visual action and provide all necessary context.
""".strip() if is_none_or_empty(self.messages): - self.messages = [ - AgentMessage(role="system", content=vision_system_prompt), - AgentMessage( - role="user", - content=[ - { - "type": "text", - "text": query, - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}", - "detail": "high", - }, - }, - ], - ), - ] - # Construct vision LLM input following OpenAI format - vision_messages_for_api = self._format_message_for_api(self.messages) # Get history - try: - vision_response: ChatCompletion = await self.vision_client.chat.completions.create( - model=self.vision_chat_model.name, - messages=vision_messages_for_api, - # max_tokens=250, # Allow for more detailed description - temperature=1.0, + query_text = query + query_screenshot = [f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"] + first_message_content = construct_structured_message( + message=query, + images=query_screenshot, + model_type=self.vision_chat_model.model_type, + vision_enabled=True, ) - logger.debug(f"Vision LLM response: {vision_response.model_dump_json()}") - natural_language_action = vision_response.choices[0].message.content + current_message = AgentMessage(role="user", content=first_message_content) + else: + current_message = self.messages.pop() + query_text = self._get_message_text(current_message) + query_screenshot = self._get_message_images(current_message) + + # Construct input for visual reasoner history + visual_reasoner_history = self._format_message_for_api(self.messages) + try: + natural_language_action = await send_message_to_model_wrapper( + query=query_text, + query_images=query_screenshot, + system_message=vision_system_prompt, + conversation_log=visual_reasoner_history, + agent_chat_model=self.vision_chat_model, + tracer=self.tracer, + ) + self.messages.append(current_message) self.messages.append(AgentMessage(role="assistant", content=natural_language_action)) if natural_language_action == "DONE": return {"type": "done", "message": "Completed task."} - # Update usage for vision model - # self._update_vision_usage(vision_response.usage.prompt_tokens, vision_response.usage.completion_tokens) logger.info(f"Vision LLM suggested action: {natural_language_action}") except Exception as e: @@ -468,8 +459,8 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par logger.warning("Grounding LLM did not produce a tool call.") rendered_response = f"**Thought (Vision)**: {natural_language_action}\n- **Response (Grounding)**: {grounding_message.content or '[No tool call]'}" - # Update usage for grounding model - # self._update_grounding_usage(grounding_response.usage.prompt_tokens, grounding_response.usage.completion_tokens) + # Update usage by grounding model + self._update_usage(grounding_response.usage.prompt_tokens, grounding_response.usage.completion_tokens) except Exception as e: logger.error(f"Error calling Grounding LLM: {e}") rendered_response = ( @@ -503,20 +494,15 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par # Append tool results message to history if tool_outputs: - tool_output_strs = "\n".join([f" - {idx}: {str(item)}" for idx, item in enumerate(tool_outputs)]) - tool_output_content = [ - { - "type": "text", - "text": f"**Action Results**:\n{tool_output_strs}", - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/png;base64,{convert_image_to_png(env_step.screenshot_base64)}", - "detail": "high", - }, - }, - ] + 
+            tool_outputs_list = "\n".join([f"- {idx}: {str(item)}" for idx, item in enumerate(tool_outputs)])
+            tool_outputs_str = "**Action Results**:\n" + tool_outputs_list
+            formatted_screenshot = f"data:image/png;base64,{convert_image_to_png(env_step.screenshot_base64)}"
+            tool_output_content = construct_structured_message(
+                message=tool_outputs_str,
+                images=[formatted_screenshot],
+                model_type=self.vision_chat_model.model_type,
+                vision_enabled=True,
+            )
             self.messages.append(AgentMessage(role="environment", content=tool_output_content))
 
         # Append summarize prompt if provided
@@ -525,23 +511,21 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
     async def summarize(self, query: str, env_state: EnvState) -> str:
-        # Construct vision LLM input following OpenAI format
-        trigger_summary = AgentMessage(role="user", content=query)
-        vision_messages_for_api = self._format_message_for_api(self.messages + [trigger_summary])
+        # Format conversation history for the visual reasoner
+        conversation_history = self._format_message_for_api(self.messages)
         try:
-            summary_response: ChatCompletion = await self.vision_client.chat.completions.create(
-                model=self.vision_chat_model.name,
-                messages=vision_messages_for_api,
-                # max_tokens=250,  # Allow for more detailed description
-                temperature=1.0,
+            summary = await send_message_to_model_wrapper(
+                query=query,
+                conversation_log=conversation_history,
+                agent_chat_model=self.vision_chat_model,
+                tracer=self.tracer,
             )
-            logger.debug(f"Vision LLM summary response: {summary_response.model_dump_json()}")
-            summary = summary_response.choices[0].message.content
 
             # Return last action message if no summary
             if not summary:
                 return self.compile_response(self.messages[-1].content)  # Compile the last action message
 
             # Append summary messages to history
+            trigger_summary = AgentMessage(role="user", content=query)
             summary_message = AgentMessage(role="assistant", content=summary)
             self.messages.extend([trigger_summary, summary_message])
@@ -587,46 +571,29 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
         # For now, rely on the structure built during the 'act' phase.
         return response  # The rendered_response is already built in act()
 
+    def _get_message_text(self, message: AgentMessage) -> str:
+        if isinstance(message.content, list):
+            return "\n".join([item["text"] for item in message.content if item["type"] == "text"])
+        return message.content
+
+    def _get_message_images(self, message: AgentMessage) -> List[str]:
+        images = []
+        if isinstance(message.content, list):
+            images = [item["image_url"]["url"] for item in message.content if item["type"] == "image_url"]
+        return images
+
-    def _format_message_for_api(self, messages: list[AgentMessage]) -> List[dict]:
-        """Format message history for OpenAI API calls."""
-        formatted_messages = []
-        for message in messages:
-            role = message.role
-            content = message.content
-
-            if role == "environment":  # Handle action results
-                formatted_messages.append({"role": "user", "content": content})
-            else:
-                formatted_messages.append({"role": role, "content": content})
-        return formatted_messages
-
-    def _update_vision_usage(self, input_tokens: int, output_tokens: int):
-        self.vision_usage = get_chat_usage_metrics(
-            self.vision_chat_model.name, input_tokens, output_tokens, usage=self.vision_usage
-        )
-        self._combine_usage()
-
-    def _update_grounding_usage(self, input_tokens: int, output_tokens: int):
-        self.grounding_usage = get_chat_usage_metrics(
-            self.grounding_chat_model.name, input_tokens, output_tokens, usage=self.grounding_usage
-        )
-        self._combine_usage()
-
-    def _combine_usage(self):
-        """Combine usage from both models into the main tracer."""
-        combined = {}
-        for usage_dict in [self.vision_usage, self.grounding_usage]:
-            for model, metrics in usage_dict.items():
-                if model not in combined:
-                    combined[model] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
-                combined[model]["input_tokens"] += metrics.get("input_tokens", 0)
-                combined[model]["output_tokens"] += metrics.get("output_tokens", 0)
-                combined[model]["total_tokens"] += metrics.get("total_tokens", 0)
-        self.tracer["usage"] = combined
-        logger.debug(f"Combined Operator usage: {self.tracer['usage']}")
+    def _format_message_for_api(self, messages: list[AgentMessage]) -> dict:
+        """Format operator agent messages into the Khoj conversation history format."""
+        formatted_messages = [
+            {
+                "message": self._get_message_text(message),
+                "images": self._get_message_images(message),
+                "by": "you" if message.role in ["user", "environment"] else message.role,
+            }
+            for message in messages
+        ]
+        return {"chat": formatted_messages}
 
     def reset(self):
         """Reset the agent state."""
         super().reset()
-        self.vision_usage = {}
-        self.grounding_usage = {}
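-- 

Note for reviewers: a minimal, self-contained sketch of the conversation
history shape the new _format_message_for_api produces for
send_message_to_model_wrapper. The AgentMessage here is a local dataclass
stand-in for illustration, not the Khoj class; only the {"chat": [...]}
layout is taken from the patch above.

    from dataclasses import dataclass
    from typing import List, Union

    @dataclass
    class AgentMessage:
        # Stand-in for khoj's AgentMessage (lives in operator_agent_base)
        role: str
        content: Union[str, List[dict]]

    def get_text(message: AgentMessage) -> str:
        # Mirrors _get_message_text: join the text parts of structured content
        if isinstance(message.content, list):
            return "\n".join(item["text"] for item in message.content if item["type"] == "text")
        return message.content

    def get_images(message: AgentMessage) -> List[str]:
        # Mirrors _get_message_images: collect image data URLs, if any
        if isinstance(message.content, list):
            return [item["image_url"]["url"] for item in message.content if item["type"] == "image_url"]
        return []

    def format_for_khoj_history(messages: List[AgentMessage]) -> dict:
        # Mirrors the new _format_message_for_api: both "user" and "environment"
        # turns are attributed to "you" in the Khoj conversation log
        return {
            "chat": [
                {
                    "message": get_text(m),
                    "images": get_images(m),
                    "by": "you" if m.role in ("user", "environment") else m.role,
                }
                for m in messages
            ]
        }

    history = [
        AgentMessage(role="user", content=[{"type": "text", "text": "Open the pricing page"}]),
        AgentMessage(role="assistant", content="Click the 'Pricing' link in the top navbar"),
    ]
    print(format_for_khoj_history(history))
    # {'chat': [{'message': 'Open the pricing page', 'images': [], 'by': 'you'},
    #           {'message': "Click the 'Pricing' link in the top navbar", 'images': [], 'by': 'assistant'}]}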
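And the two-model control flow this patch leaves in place, reduced to a
skeleton. The reason/ground callables are hypothetical stand-ins for
send_message_to_model_wrapper and the raw grounding client call; only the
ordering and the "DONE" short-circuit are taken from act().

    import asyncio
    from typing import Awaitable, Callable

    async def binary_act(
        query: str,
        screenshot: str,
        reason: Callable[..., Awaitable[str]],   # stand-in: any vision-enabled reasoner
        ground: Callable[..., Awaitable[dict]],  # stand-in: grounding model (ui-tars-1.5-7b)
    ) -> dict:
        # 1. The vision reasoner proposes the next step in natural language
        action = await reason(query=query, images=[screenshot])
        if action == "DONE":
            return {"type": "done", "message": "Completed task."}
        # 2. The grounding model turns that description into a concrete UI action
        return await ground(instruction=action)

    async def demo():
        # Trivial stand-ins that resolve immediately, for illustration only
        reason = lambda **kw: asyncio.sleep(0, result="Click the search box")
        ground = lambda **kw: asyncio.sleep(0, result={"type": "click", "x": 120, "y": 48})
        print(await binary_act("find the docs", "data:image/png;base64,...", reason, ground))

    asyncio.run(demo())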