Widen vision support for chat models served via openai compatible api

Send image as png to non-openai models served via an openai compatible
api. As more models support png than webp.

Continue storing images as webp on server for efficiency.

Convert to png at the openai api layer and only for non-openai models
served via an openai compatible api.

Enable using vision models like ui-tars (via llama.cpp server), grok.
This commit is contained in:
Debanjum
2025-05-11 11:13:27 -06:00
parent 4f3fdaf19d
commit 7827d317b4
2 changed files with 56 additions and 4 deletions

View File

@@ -1,5 +1,6 @@
import logging
import os
from copy import deepcopy
from functools import partial
from time import perf_counter
from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union
@@ -7,6 +8,7 @@ from urllib.parse import urlparse
import httpx
import openai
from langchain_core.messages.chat import ChatMessage
from openai.lib.streaming.chat import (
ChatCompletionStream,
ChatCompletionStreamEvent,
@@ -32,6 +34,7 @@ from khoj.processor.conversation.utils import (
commit_conversation_trace,
)
from khoj.utils.helpers import (
convert_image_data_uri,
get_chat_usage_metrics,
get_openai_async_client,
get_openai_client,
@@ -58,7 +61,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
reraise=True,
)
def completion_with_backoff(
messages,
messages: List[ChatMessage],
model_name: str,
temperature=0.8,
openai_api_key=None,
@@ -74,7 +77,7 @@ def completion_with_backoff(
openai_clients[client_key] = client
stream_processor = default_stream_processor
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
formatted_messages = format_message_for_api(messages, api_base_url)
# Tune reasoning models arguments
if is_openai_reasoning_model(model_name, api_base_url):
@@ -173,7 +176,7 @@ async def chat_completion_with_backoff(
openai_async_clients[client_key] = client
stream_processor = adefault_stream_processor
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
formatted_messages = format_message_for_api(messages, api_base_url)
# Configure thinking for openai reasoning models
if is_openai_reasoning_model(model_name, api_base_url):
@@ -293,11 +296,34 @@ def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> Js
return JsonSupport.SCHEMA
def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> List[dict]:
"""
Format messages to send to chat model served over OpenAI (compatible) API.
"""
formatted_messages = []
for message in deepcopy(messages):
# Convert images to PNG format if message to be sent to non OpenAI API
if isinstance(message.content, list) and not is_openai_api(api_base_url):
for part in message.content:
if part.get("type") == "image_url":
part["image_url"]["url"] = convert_image_data_uri(part["image_url"]["url"], target_format="png")
formatted_messages.append({"role": message.role, "content": message.content})
return formatted_messages
def is_openai_api(api_base_url: str = None) -> bool:
"""
Check if the model is served over the official OpenAI API
"""
return api_base_url is None or api_base_url.startswith("https://api.openai.com/v1")
def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
"""
Check if the model is an OpenAI reasoning model
"""
return model_name.startswith("o") and (api_base_url is None or api_base_url.startswith("https://api.openai.com/v1"))
return model_name.startswith("o") and is_openai_api(api_base_url)
def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> bool:

View File

@@ -555,6 +555,32 @@ def convert_image_to_webp(image_bytes):
return webp_image_bytes
def convert_image_data_uri(image_data_uri: str, target_format: str = "png") -> str:
"""
Convert image (in data URI) to target format.
Target format can be png, jpg, webp etc.
Returns the converted image as a data URI.
"""
base64_data = image_data_uri.split(",", 1)[1]
image_type = image_data_uri.split(";")[0].split(":")[1].split("/")[1]
if image_type.lower() == target_format.lower():
return image_data_uri
image_bytes = base64.b64decode(base64_data)
image_io = io.BytesIO(image_bytes)
with Image.open(image_io) as original_image:
output_image_io = io.BytesIO()
original_image.save(output_image_io, target_format.upper())
# Encode the image back to base64
output_image_bytes = output_image_io.getvalue()
output_image_io.close()
output_base64_data = base64.b64encode(output_image_bytes).decode("utf-8")
output_data_uri = f"data:image/{target_format};base64,{output_base64_data}"
return output_data_uri
def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]:
"""
Truncate large output files and drop image file data from code results.