From 7827d317b421c16fa07e447a1c73319412876cba Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sun, 11 May 2025 11:13:27 -0600 Subject: [PATCH] Widen vision support for chat models served via openai compatible api Send images as png to non-openai models served via an openai compatible api, as more models support png than webp. Continue storing images as webp on server for efficiency. Convert to png at the openai api layer and only for non-openai models served via an openai compatible api. Enable using vision models like ui-tars (via llama.cpp server), grok. --- .../processor/conversation/openai/utils.py | 34 ++++++++++++++++--- src/khoj/utils/helpers.py | 26 ++++++++++++++ 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 77dee0c4..bc3741ef 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -1,5 +1,6 @@ import logging import os +from copy import deepcopy from functools import partial from time import perf_counter from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union @@ -7,6 +8,7 @@ from urllib.parse import urlparse import httpx import openai +from langchain_core.messages.chat import ChatMessage from openai.lib.streaming.chat import ( ChatCompletionStream, ChatCompletionStreamEvent, @@ -32,6 +34,7 @@ from khoj.processor.conversation.utils import ( commit_conversation_trace, ) from khoj.utils.helpers import ( + convert_image_data_uri, get_chat_usage_metrics, get_openai_async_client, get_openai_client, @@ -58,7 +61,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {} reraise=True, ) def completion_with_backoff( - messages, + messages: List[ChatMessage], model_name: str, temperature=0.8, openai_api_key=None, @@ -74,7 +77,7 @@ def completion_with_backoff( openai_clients[client_key] = client stream_processor = default_stream_processor - formatted_messages = 
[{"role": message.role, "content": message.content} for message in messages] + formatted_messages = format_message_for_api(messages, api_base_url) # Tune reasoning models arguments if is_openai_reasoning_model(model_name, api_base_url): @@ -173,7 +176,7 @@ async def chat_completion_with_backoff( openai_async_clients[client_key] = client stream_processor = adefault_stream_processor - formatted_messages = [{"role": message.role, "content": message.content} for message in messages] + formatted_messages = format_message_for_api(messages, api_base_url) # Configure thinking for openai reasoning models if is_openai_reasoning_model(model_name, api_base_url): @@ -293,11 +296,34 @@ def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> Js return JsonSupport.SCHEMA +def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> List[dict]: + """ + Format messages to send to chat model served over OpenAI (compatible) API. + """ + formatted_messages = [] + for message in deepcopy(messages): + # Convert images to PNG format if message to be sent to non OpenAI API + if isinstance(message.content, list) and not is_openai_api(api_base_url): + for part in message.content: + if part.get("type") == "image_url": + part["image_url"]["url"] = convert_image_data_uri(part["image_url"]["url"], target_format="png") + formatted_messages.append({"role": message.role, "content": message.content}) + + return formatted_messages + + +def is_openai_api(api_base_url: str = None) -> bool: + """ + Check if the model is served over the official OpenAI API + """ + return api_base_url is None or api_base_url.startswith("https://api.openai.com/v1") + + def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool: """ Check if the model is an OpenAI reasoning model """ - return model_name.startswith("o") and (api_base_url is None or api_base_url.startswith("https://api.openai.com/v1")) + return model_name.startswith("o") and 
is_openai_api(api_base_url) def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> bool: diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 2530b4fb..0f986709 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -555,6 +555,32 @@ def convert_image_to_webp(image_bytes): return webp_image_bytes +def convert_image_data_uri(image_data_uri: str, target_format: str = "png") -> str: + """ + Convert image (in data URI) to target format. + + Target format can be png, jpg, webp etc. + Returns the converted image as a data URI. + """ + base64_data = image_data_uri.split(",", 1)[1] + image_type = image_data_uri.split(";")[0].split(":")[1].split("/")[1] + if image_type.lower() == target_format.lower(): + return image_data_uri + + image_bytes = base64.b64decode(base64_data) + image_io = io.BytesIO(image_bytes) + with Image.open(image_io) as original_image: + output_image_io = io.BytesIO() + original_image.save(output_image_io, target_format.upper()) + + # Encode the image back to base64 + output_image_bytes = output_image_io.getvalue() + output_image_io.close() + output_base64_data = base64.b64encode(output_image_bytes).decode("utf-8") + output_data_uri = f"data:image/{target_format};base64,{output_base64_data}" + return output_data_uri + + def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]: """ Truncate large output files and drop image file data from code results.