mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Widen vision support for chat models served via openai compatible api
Send image as png to non-openai models served via an openai compatible api. As more models support png than webp. Continue storing images as webp on server for efficiency. Convert to png at the openai api layer and only for non-openai models served via an openai compatible api. Enable using vision models like ui-tars (via llama.cpp server), grok.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from time import perf_counter
|
||||
from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union
|
||||
@@ -7,6 +8,7 @@ from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from openai.lib.streaming.chat import (
|
||||
ChatCompletionStream,
|
||||
ChatCompletionStreamEvent,
|
||||
@@ -32,6 +34,7 @@ from khoj.processor.conversation.utils import (
|
||||
commit_conversation_trace,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
convert_image_data_uri,
|
||||
get_chat_usage_metrics,
|
||||
get_openai_async_client,
|
||||
get_openai_client,
|
||||
@@ -58,7 +61,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
|
||||
reraise=True,
|
||||
)
|
||||
def completion_with_backoff(
|
||||
messages,
|
||||
messages: List[ChatMessage],
|
||||
model_name: str,
|
||||
temperature=0.8,
|
||||
openai_api_key=None,
|
||||
@@ -74,7 +77,7 @@ def completion_with_backoff(
|
||||
openai_clients[client_key] = client
|
||||
|
||||
stream_processor = default_stream_processor
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||
|
||||
# Tune reasoning models arguments
|
||||
if is_openai_reasoning_model(model_name, api_base_url):
|
||||
@@ -173,7 +176,7 @@ async def chat_completion_with_backoff(
|
||||
openai_async_clients[client_key] = client
|
||||
|
||||
stream_processor = adefault_stream_processor
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||
|
||||
# Configure thinking for openai reasoning models
|
||||
if is_openai_reasoning_model(model_name, api_base_url):
|
||||
@@ -293,11 +296,34 @@ def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> Js
|
||||
return JsonSupport.SCHEMA
|
||||
|
||||
|
||||
def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> List[dict]:
|
||||
"""
|
||||
Format messages to send to chat model served over OpenAI (compatible) API.
|
||||
"""
|
||||
formatted_messages = []
|
||||
for message in deepcopy(messages):
|
||||
# Convert images to PNG format if message to be sent to non OpenAI API
|
||||
if isinstance(message.content, list) and not is_openai_api(api_base_url):
|
||||
for part in message.content:
|
||||
if part.get("type") == "image_url":
|
||||
part["image_url"]["url"] = convert_image_data_uri(part["image_url"]["url"], target_format="png")
|
||||
formatted_messages.append({"role": message.role, "content": message.content})
|
||||
|
||||
return formatted_messages
|
||||
|
||||
|
||||
def is_openai_api(api_base_url: str = None) -> bool:
|
||||
"""
|
||||
Check if the model is served over the official OpenAI API
|
||||
"""
|
||||
return api_base_url is None or api_base_url.startswith("https://api.openai.com/v1")
|
||||
|
||||
|
||||
def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
|
||||
"""
|
||||
Check if the model is an OpenAI reasoning model
|
||||
"""
|
||||
return model_name.startswith("o") and (api_base_url is None or api_base_url.startswith("https://api.openai.com/v1"))
|
||||
return model_name.startswith("o") and is_openai_api(api_base_url)
|
||||
|
||||
|
||||
def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
|
||||
|
||||
@@ -555,6 +555,32 @@ def convert_image_to_webp(image_bytes):
|
||||
return webp_image_bytes
|
||||
|
||||
|
||||
def convert_image_data_uri(image_data_uri: str, target_format: str = "png") -> str:
|
||||
"""
|
||||
Convert image (in data URI) to target format.
|
||||
|
||||
Target format can be png, jpg, webp etc.
|
||||
Returns the converted image as a data URI.
|
||||
"""
|
||||
base64_data = image_data_uri.split(",", 1)[1]
|
||||
image_type = image_data_uri.split(";")[0].split(":")[1].split("/")[1]
|
||||
if image_type.lower() == target_format.lower():
|
||||
return image_data_uri
|
||||
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
image_io = io.BytesIO(image_bytes)
|
||||
with Image.open(image_io) as original_image:
|
||||
output_image_io = io.BytesIO()
|
||||
original_image.save(output_image_io, target_format.upper())
|
||||
|
||||
# Encode the image back to base64
|
||||
output_image_bytes = output_image_io.getvalue()
|
||||
output_image_io.close()
|
||||
output_base64_data = base64.b64encode(output_image_bytes).decode("utf-8")
|
||||
output_data_uri = f"data:image/{target_format};base64,{output_base64_data}"
|
||||
return output_data_uri
|
||||
|
||||
|
||||
def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]:
|
||||
"""
|
||||
Truncate large output files and drop image file data from code results.
|
||||
|
||||
Reference in New Issue
Block a user