From 833c8ed15037aef47a87f068e01dd15729551cd1 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 5 May 2025 23:22:56 -0600 Subject: [PATCH] Add a flexible operator agent using separate reasoning, grounder models - This operator works with model served over an openai compatible api - It uses separate vision models to reason and ground actions. This improves flexibility in the operator agents that can be created. We do not know need our operator agent ot rely on monolithic models to can both reason over visual data and ground their actions. We can create operator agent from 2 separate models: 1. To reason over screenshots to suggest natural language next action 2. To ground those suggestion into visually grounded actions This allows us to create fully local operators or operators combining the best visual reasoner with the best visual grounder models. --- src/khoj/database/adapters/__init__.py | 2 +- .../processor/operator/browser_operator.py | 638 +++++++++++++++++- src/khoj/utils/helpers.py | 14 + 3 files changed, 642 insertions(+), 12 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 13854681..173d087e 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1191,7 +1191,7 @@ class ConversationAdapters: async def aget_chat_model_by_name(chat_model_name: str, ai_model_api_name: str = None): if ai_model_api_name: return await ChatModel.objects.filter(name=chat_model_name, ai_model_api__name=ai_model_api_name).afirst() - return await ChatModel.objects.filter(name=chat_model_name).afirst() + return await ChatModel.objects.filter(name=chat_model_name).prefetch_related("ai_model_api").afirst() @staticmethod async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: diff --git a/src/khoj/processor/operator/browser_operator.py b/src/khoj/processor/operator/browser_operator.py index 44e94529..30c0ec7b 100644 --- a/src/khoj/processor/operator/browser_operator.py +++ b/src/khoj/processor/operator/browser_operator.py @@ -9,7 +9,9 @@ from datetime import datetime from typing import Any, Callable, List, Literal, Optional, Set, Union import requests -from anthropic.types.beta import BetaContentBlock, BetaMessage +from anthropic.types.beta import BetaContentBlock +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion from openai.types.responses import Response, ResponseOutputItem from playwright.async_api import Browser, Page, Playwright, async_playwright from pydantic import BaseModel @@ -19,6 +21,7 @@ from khoj.database.models import Agent, ChatModel, KhojUser from khoj.processor.conversation.utils import commit_conversation_trace from khoj.routers.helpers import ChatEvent from khoj.utils.helpers import ( + convert_image_to_png, convert_image_to_webp, get_anthropic_async_client, get_chat_usage_metrics, @@ -1177,6 +1180,607 @@ class AnthropicOperatorAgent(OperatorAgent): return "\n- ".join(filter(None, rendered_response)) +# --- Binary Operator Agent --- +class BinaryOperatorAgent(OperatorAgent): + """ + An OperatorAgent that uses two LLMs (OpenAI compatible): + 1. Vision LLM: Determines the next high-level action based on the visual state. + 2. Grounding LLM: Converts the high-level action into specific, executable browser actions. + """ + + def __init__( + self, + vision_chat_model: ChatModel, + grounding_chat_model: ChatModel, # Assuming a second model is provided/configured + max_iterations: int, + tracer: dict, + ): + super().__init__(vision_chat_model, max_iterations, tracer) # Use vision model for primary tracking + self.vision_chat_model = vision_chat_model + self.grounding_chat_model = grounding_chat_model + # Initialize OpenAI clients + self.vision_client: AsyncOpenAI = get_openai_async_client( + vision_chat_model.ai_model_api.api_key, vision_chat_model.ai_model_api.api_base_url + ) + self.grounding_client: AsyncOpenAI = get_openai_async_client( + grounding_chat_model.ai_model_api.api_key, grounding_chat_model.ai_model_api.api_base_url + ) + self.vision_usage = {} + self.grounding_usage = {} + + async def act(self, query: str, current_state: EnvState) -> AgentActResult: + """ + Uses a two-step LLM process to determine and structure the next action. + """ + self._commit_trace() # Commit trace before next action + + # --- Step 1: Reasoning LLM determines high-level action --- + reasoner_response = await self.act_reason(query, current_state) + natural_language_action = reasoner_response["message"] + if reasoner_response["type"] == "error": + logger.error(f"Error in reasoning LLM: {natural_language_action}") + return AgentActResult( + actions=[], + action_results=[], + rendered_response=natural_language_action, + ) + + # --- Step 2: Grounding LLM converts NL action to structured action --- + return await self.act_ground(natural_language_action, current_state) + + async def act_reason(self, query: str, current_state: EnvState) -> dict[str, str]: + """ + Uses the reasoning LLM to determine the next high-level action based on the operation trajectory. + """ + vision_system_prompt = f""" +* You are Khoj, a smart web browsing assistant. You help the user accomplish their task using a web browser. +* You will be given the user's query and screenshots of the current browser state. +* You instruct a tool AI to operate a single Chromium browser page via Playwright. +* The tool AI only has access to the current screenshot and your instructions. It uses your instructions to perform an action on the page. +* It can interact with the web browser to perform tasks like click, right click, double click, type, scroll, drag, wait, goto url, go back to previous page and take screenshots. +* It cannot access the OS or filesystem. +* Make sure you scroll down to see everything before deciding something isn't available. +* Perform web searches using DuckDuckGo. Don't use Google even if requested as the query will fail. + + + +* You are allowed upto {self.max_iterations} iterations to complete the task. +* Do not loop on wait, screenshot for too many turns without taking any action. +* Once you've verified that the task has been completed, just say "DONE" (without the quotes). Do not say anything else. + + +* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. +* The current URL is {current_state.url}. + +Now describe a single high-level action to take next to progress towards the user's goal in detail. +Focus on the visual action and provide all necessary context. + +For Example: +- 'click the blue login button located at the top right corner' +- 'scroll down the page to find the contact section' +- 'type the username example@email.com into the input field labeled Username') +""" + + if is_none_or_empty(self.messages): + self.messages = [ + ChatMessage(role="system", content=vision_system_prompt), + ChatMessage( + role="user", + content=[ + { + "type": "text", + "text": query, + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}" + }, + }, + ], + ), + ] + # Construct vision LLM input following OpenAI format + vision_messages_for_api = self._format_message_for_api(self.messages) # Get history + try: + vision_response: ChatCompletion = await self.vision_client.chat.completions.create( + model=self.vision_chat_model.name, + messages=vision_messages_for_api, + # max_tokens=250, # Allow for more detailed description + temperature=1.0, + ) + logger.debug(f"Vision LLM response: {vision_response.model_dump_json()}") + natural_language_action = vision_response.choices[0].message.content + self.messages.append(ChatMessage(role="assistant", content=natural_language_action)) + + if natural_language_action == "DONE": + return {"type": "done", "message": "Completed task."} + + # Update usage for vision model + # self._update_vision_usage(vision_response.usage.prompt_tokens, vision_response.usage.completion_tokens) + logger.info(f"Vision LLM suggested action: {natural_language_action}") + + except Exception as e: + return {"type": "error", "message": f"Error calling Vision LLM: {e}"} + + return {"type": "action", "message": natural_language_action} + + async def act_ground(self, natural_language_action: str, current_state: EnvState) -> AgentActResult: + """Uses the grounding LLM to convert the high-level action into structured browser actions.""" + actions: List[BrowserAction] = [] + action_results: List[dict] = [] + rendered_response = "No action determined." + grounding_user_prompt = f""" +You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. + +## Output Format +``` +Thought: ... +Action: ... +``` + +## Action Space + +click(start_box='<|box_start|>(x1,y1)<|box_end|>') +left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') +right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') +drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') +hotkey(key='') +type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content. +scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') +wait(duration='time') # Sleep for specified time. Default is 1s and take a screenshot to check for any changes. +goto(url='xxx') # Always use this to navigate to a specific URL. Use escape characters \\', \\", and \\n in url part to ensure we can parse the url in normal python string format. +back() # Use this to go back to the previous page. +finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. + +## Note +- Use English in `Thought` part. +- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part. + +## User Instruction +{natural_language_action} +""" + + # Define tools for the grounding LLM (OpenAI format) + grounding_tools = [ + { + "type": "function", + "function": { + "name": "click", + "description": "Click on a specific coordinate.", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "X coordinate"}, + "y": {"type": "integer", "description": "Y coordinate"}, + "button": { + "type": "string", + "enum": ["left", "right", "middle", "wheel"], + "default": "left", + }, + "modifiers": { + "type": "string", + "description": "Optional modifier keys (e.g., 'Shift', 'Control+Alt')", + "nullable": True, + }, + }, + "required": ["x", "y"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "left_double", + "description": "Double click on a specific coordinate.", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "X coordinate"}, + "y": {"type": "integer", "description": "Y coordinate"}, + }, + "required": ["x", "y"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "right_single", + "description": "Right click on a specific coordinate.", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "X coordinate"}, + "y": {"type": "integer", "description": "Y coordinate"}, + }, + "required": ["x", "y"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "drag", + "description": "Perform a drag-and-drop operation along a path.", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "array", + "items": { + "type": "object", + "properties": { + "x": {"type": "integer"}, + "y": {"type": "integer"}, + }, + "required": ["x", "y"], + }, + "description": "List of points (x, y coordinates) defining the drag path.", + } + }, + "required": ["path"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "hotkey", + "description": "Press a key or key combination.", + "parameters": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": {"type": "string"}, + "description": "List of keys to press (e.g., ['Control', 'a'], ['Enter'])", + } + }, + "required": ["keys"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "type", + "description": "Type text, usually into a focused input field.", + "parameters": { + "type": "object", + "properties": {"content": {"type": "string", "description": "Text to type"}}, + "required": ["content"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "scroll", + "description": "Scroll the page.", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "X coordinate to scroll from"}, + "y": {"type": "integer", "description": "Y coordinate to scroll from"}, + "direction": { + "type": "string", + "enum": ["up", "down", "left", "right"], + "default": "down", + }, + }, + "required": [], # None is strictly required + }, + }, + }, + { + "type": "function", + "function": { + "name": "wait", + "description": "Pause execution for a specified duration.", + "parameters": { + "type": "object", + "properties": { + "duration": {"type": "number", "description": "Duration in seconds", "default": 1.0} + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "goto", + "description": "Navigate to a specific URL.", + "parameters": { + "type": "object", + "properties": {"url": {"type": "string", "description": "Fully qualified URL"}}, + "required": ["url"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "back", + "description": "navigate back to the previous page.", + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "finished", + "description": "If no further actions to take.", + "parameters": { + "type": "object", + "properties": {"content": {"type": "string", "description": "Text to type"}}, + "required": ["content"], + }, + }, + }, + ] + + # Construct grounding LLM input (using only the latest user prompt + image) + # We don't pass the full history here, as grounding depends on the *current* state + NL action + grounding_messages_for_api = [ + { + "role": "user", + "content": [ + {"type": "text", "text": grounding_user_prompt}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{convert_image_to_png(current_state.screenshot)}"}, + }, + ], + } + ] + + try: + grounding_response: ChatCompletion = await self.grounding_client.chat.completions.create( + model=self.grounding_chat_model.name, + messages=grounding_messages_for_api, + tools=grounding_tools, + tool_choice="auto", + temperature=0.0, # Grounding should be precise + max_tokens=1000, # Allow for thoughts + actions + ) + logger.debug(f"Grounding LLM response: {grounding_response.model_dump_json()}") + + grounding_message = grounding_response.choices[0].message + # Parse tool calls + if grounding_message.tool_calls: + # Start rendering with vision output + rendered_parts = [f"**Thought (Vision)**: {natural_language_action}"] + for tool_call in grounding_message.tool_calls: + function_name = tool_call.function.name + try: + arguments = json.loads(tool_call.function.arguments) + action_to_run: Optional[BrowserAction] = None + action_render_str = f"**Action ({function_name})**: {tool_call.function.arguments}" + + if function_name == "click": + action_to_run = ClickAction(**arguments) + elif function_name == "left_double": + action_to_run = DoubleClickAction(**arguments) + elif function_name == "right_single": + action_to_run = ClickAction(button="right", **arguments) + elif function_name == "type": + action_to_run = TypeAction(**arguments) + elif function_name == "scroll": + x = arguments.get("x") + y = arguments.get("y") + direction = arguments.get("direction", "down") + amount = 5 + action_to_run = ScrollAction(scroll_direction=direction, scroll_amount=amount, x=x, y=y) + elif function_name == "hotkey": + action_to_run = KeypressAction(**arguments) + elif function_name == "goto": + action_to_run = GotoAction(**arguments) + elif function_name == "back": + action_to_run = BackAction(**arguments) + elif function_name == "wait": + action_to_run = WaitAction(**arguments) + elif function_name == "screenshot": + action_to_run = ScreenshotAction(**arguments) + elif function_name == "drag": + # Need to convert list of dicts to list of Point objects + path_dicts = arguments.get("path", []) + path_points = [Point(**p) for p in path_dicts] + if path_points: + action_to_run = DragAction(path=path_points) + else: + logger.warning(f"Drag action called with empty path: {arguments}") + action_render_str += " [Skipped - empty path]" + elif function_name == "finished": + action_to_run = None + else: + logger.warning(f"Grounding LLM called unhandled tool: {function_name}") + action_render_str += " [Unhandled]" + + if action_to_run: + actions.append(action_to_run) + # Prepare action result structure (similar to OpenAIOperatorAgent) + action_results.append( + { + "type": "tool_result", + "tool_call_id": tool_call.id, + "content": None, # Updated by environment step + } + ) + rendered_parts.append(action_render_str) + except (json.JSONDecodeError, TypeError, ValueError) as arg_err: + logger.error( + f"Error parsing arguments for tool {function_name}: {arg_err} - Args: {tool_call.function.arguments}" + ) + rendered_parts.append(f"**Error**: Failed to parse arguments for {function_name}") + rendered_response = "\n- ".join(rendered_parts) + else: + # Grounding LLM responded but didn't call a tool + logger.warning("Grounding LLM did not produce a tool call.") + rendered_response = f"**Thought (Vision)**: {natural_language_action}\n- **Response (Grounding)**: {grounding_message.content or '[No tool call]'}" + + # Update usage for grounding model + # self._update_grounding_usage(grounding_response.usage.prompt_tokens, grounding_response.usage.completion_tokens) + except Exception as e: + logger.error(f"Error calling Grounding LLM: {e}") + rendered_response = ( + f"**Thought (Vision)**: {natural_language_action}\n- **Error**: Error contacting Grounding LLM: {e}" + ) + return AgentActResult( + actions=actions, + action_results=action_results, + rendered_response=rendered_response, + ) + + def add_action_results( + self, env_steps: list[EnvStepResult], agent_action: AgentActResult, summarize_prompt: str = None + ) -> None: + """ + Adds the results of executed actions back into the message history, + formatted for the next OpenAI vision LLM call. + """ + if not agent_action.action_results and not summarize_prompt: + return + + tool_outputs = [] + for idx, env_step in enumerate(env_steps): + if idx < len(agent_action.action_results): # Ensure we don't go out of bounds + result_content = env_step.error or env_step.output or "[Action completed]" + tool_outputs.append(["Took screenshot" if env_step.type == "image" else json.dumps(result_content)]) + else: + logger.warning( + f"Mismatch between env_steps ({len(env_steps)}) and action_results ({len(agent_action.action_results)})" + ) + + # Append tool results message to history + if tool_outputs: + tool_output_strs = "\n".join([f" - {idx}: {str(item)}" for idx, item in enumerate(tool_outputs)]) + tool_output_content = [ + { + "type": "text", + "text": f"**Action Results**:\n{tool_output_strs}", + }, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{convert_image_to_png(env_step.screenshot_base64)}"}, + }, + ] + self.messages.append(ChatMessage(role="environment", content=tool_output_content)) + + # Append summarize prompt if provided + if summarize_prompt: + self.messages.append(ChatMessage(role="user", content=summarize_prompt)) + + async def summarize(self, query: str, env_state: EnvState) -> str: + # Construct vision LLM input following OpenAI format + trigger_summary = ChatMessage(role="user", content=query) + vision_messages_for_api = self._format_message_for_api(self.messages + [trigger_summary]) + try: + summary_response: ChatCompletion = await self.vision_client.chat.completions.create( + model=self.vision_chat_model.name, + messages=vision_messages_for_api, + # max_tokens=250, # Allow for more detailed description + temperature=1.0, + ) + logger.debug(f"Vision LLM summary response: {summary_response.model_dump_json()}") + summary = summary_response.choices[0].message.content + + # Return last action message if no summary + if not summary: + return self.compile_response(self.messages[-1].content) # Compile the last action message + + # Append summary messages to history + summary_message = ChatMessage(role="assistant", content=summary) + self.messages.extend([trigger_summary, summary_message]) + + return summary + except Exception as e: + logger.error(f"Error calling Vision LLM for summary: {e}") + return f"Error generating summary: {e}" + + def compile_response(self, response_content: Union[str, List, dict]) -> str: + """Compile response content into a string, handling OpenAI message structures.""" + if isinstance(response_content, str): + return response_content # Simple text (e.g., initial user query, vision response) + + if isinstance(response_content, dict) and response_content.get("role") == "assistant": + # Grounding LLM response message (might contain tool calls) + text_content = response_content.get("content") + tool_calls = response_content.get("tool_calls") + compiled = [] + if text_content: + compiled.append(text_content) + if tool_calls: + for tc in tool_calls: + compiled.append( + f"**Action ({tc.get('function', {}).get('name')})**: {tc.get('function', {}).get('arguments')}" + ) + return "\n- ".join(filter(None, compiled)) or "[Assistant Message]" + + if isinstance(response_content, list): # Tool results list + compiled = ["**Tool Results**:"] + for item in response_content: + if isinstance(item, dict) and item.get("role") == "tool": + compiled.append(f" - ID {item.get('tool_call_id')}: {item.get('content')}") + else: + compiled.append(f" - {str(item)}") # Fallback + return "\n".join(compiled) + + # Fallback for unexpected types + return str(response_content) + + def _render_response(self, response: List, screenshot: Optional[str]) -> Optional[str]: + """Render response for display. Currently uses compile_response.""" + # TODO: Could potentially enhance rendering, e.g., showing vision thought + grounding actions distinctly. + # For now, rely on the structure built during the 'act' phase. + return response # The rendered_response is already built in act() + + def _format_message_for_api(self, messages: list[ChatMessage]) -> List[dict]: + """Format message history for OpenAI API calls.""" + formatted_messages = [] + for message in messages: + role = message.role + content = message.content + + if role == "environment": # Handle action results + formatted_messages.append({"role": "user", "content": content}) + else: + formatted_messages.append({"role": role, "content": content}) + return formatted_messages + + def _update_vision_usage(self, input_tokens: int, output_tokens: int): + self.vision_usage = get_chat_usage_metrics( + self.vision_chat_model.name, input_tokens, output_tokens, usage=self.vision_usage + ) + self._combine_usage() + + def _update_grounding_usage(self, input_tokens: int, output_tokens: int): + self.grounding_usage = get_chat_usage_metrics( + self.grounding_chat_model.name, input_tokens, output_tokens, usage=self.grounding_usage + ) + self._combine_usage() + + def _combine_usage(self): + """Combine usage from both models into the main tracer.""" + combined = {} + for usage_dict in [self.vision_usage, self.grounding_usage]: + for model, metrics in usage_dict.items(): + if model not in combined: + combined[model] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + combined[model]["input_tokens"] += metrics.get("input_tokens", 0) + combined[model]["output_tokens"] += metrics.get("output_tokens", 0) + combined[model]["total_tokens"] += metrics.get("total_tokens", 0) + self.tracer["usage"] = combined + logger.debug(f"Combined Operator usage: {self.tracer['usage']}") + + def reset(self): + """Reset the agent state.""" + super().reset() + self.vision_usage = {} + self.grounding_usage = {} + + # --- Main Operator Function --- async def operate_browser( query: str, @@ -1195,23 +1799,35 @@ async def operate_browser( # Get the agent chat model agent_chat_model = await AgentAdapters.aget_agent_chat_model(agent, user) if agent else None - chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model) - supported_operator_model_types = [ChatModel.ModelType.OPENAI, ChatModel.ModelType.ANTHROPIC] + default_chat_model: ChatModel = await ConversationAdapters.aget_default_chat_model(user, agent_chat_model) + vision_chat_model = await ConversationAdapters.aget_vision_enabled_config() + chat_model = default_chat_model or vision_chat_model - if not chat_model or chat_model.model_type not in supported_operator_model_types: - raise ValueError( - f"Unsupported AI model. Configure and use chat model of type {supported_operator_model_types} to enable Browser use." - ) + if not chat_model: + raise ValueError(f"Unsupported AI model. Configure and use a vision chat model to enable Browser use.") # Initialize Agent max_iterations = 40 # TODO: Configurable? operator_agent: OperatorAgent - if chat_model.model_type == ChatModel.ModelType.OPENAI: + if chat_model.name.startswith("gpt-"): operator_agent = OpenAIOperatorAgent(chat_model, max_iterations, tracer) - elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: + elif chat_model.name.startswith("claude-"): operator_agent = AnthropicOperatorAgent(chat_model, max_iterations, tracer) - else: # Should not happen due to check above, but satisfy type checker - raise ValueError("Invalid model type for operator agent.") + else: + grounding_model_name = "ui-tars-1.5-7b" + vision_model = await ConversationAdapters.aget_chat_model_by_name(chat_model.name) + grounding_model = await ConversationAdapters.aget_chat_model_by_name( + grounding_model_name + ) # Fetch grounding model + if ( + not grounding_model + or grounding_model.model_type != ChatModel.ModelType.OPENAI + or not grounding_model.vision_enabled + ): + raise ValueError("Grounding model for MultiLLMOperatorAgent not found or supported.") + if not vision_model or vision_model.model_type != ChatModel.ModelType.OPENAI or not vision_model.vision_enabled: + raise ValueError("Vision model for MultiLLMOperatorAgent not found or supported.") + operator_agent = BinaryOperatorAgent(vision_model, grounding_model, max_iterations, tracer) # Initialize Environment if send_status_func: diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 17a48121..2b4fcb83 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -586,6 +586,20 @@ def convert_image_data_uri(image_data_uri: str, target_format: str = "png") -> s return output_data_uri +def convert_image_to_png(image_base64: str) -> str: + """Convert base64 image to png format for wider support""" + image_bytes = base64.b64decode(image_base64) + image_io = io.BytesIO(image_bytes) + with Image.open(image_io) as original_image: + output_image_io = io.BytesIO() + original_image.save(output_image_io, "PNG") + + # Encode the WebP image back to base64 + output_image_bytes = output_image_io.getvalue() + output_image_io.close() + return base64.b64encode(output_image_bytes).decode("utf-8") + + def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]: """ Truncate large output files and drop image file data from code results.