diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 77dee0c4..bc3741ef 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -1,5 +1,6 @@ import logging import os +from copy import deepcopy from functools import partial from time import perf_counter from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union @@ -7,6 +8,7 @@ from urllib.parse import urlparse import httpx import openai +from langchain_core.messages.chat import ChatMessage from openai.lib.streaming.chat import ( ChatCompletionStream, ChatCompletionStreamEvent, @@ -32,6 +34,7 @@ from khoj.processor.conversation.utils import ( commit_conversation_trace, ) from khoj.utils.helpers import ( + convert_image_data_uri, get_chat_usage_metrics, get_openai_async_client, get_openai_client, @@ -58,7 +61,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {} reraise=True, ) def completion_with_backoff( - messages, + messages: List[ChatMessage], model_name: str, temperature=0.8, openai_api_key=None, @@ -74,7 +77,7 @@ def completion_with_backoff( openai_clients[client_key] = client stream_processor = default_stream_processor - formatted_messages = [{"role": message.role, "content": message.content} for message in messages] + formatted_messages = format_message_for_api(messages, api_base_url) # Tune reasoning models arguments if is_openai_reasoning_model(model_name, api_base_url): @@ -173,7 +176,7 @@ async def chat_completion_with_backoff( openai_async_clients[client_key] = client stream_processor = adefault_stream_processor - formatted_messages = [{"role": message.role, "content": message.content} for message in messages] + formatted_messages = format_message_for_api(messages, api_base_url) # Configure thinking for openai reasoning models if is_openai_reasoning_model(model_name, api_base_url): @@ -293,11 +296,34 @@ def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> Js return JsonSupport.SCHEMA +def format_message_for_api(messages: List[ChatMessage], api_base_url: str) -> List[dict]: + """ + Format messages to send to chat model served over OpenAI (compatible) API. + """ + formatted_messages = [] + for message in deepcopy(messages): + # Convert images to PNG format if message to be sent to non OpenAI API + if isinstance(message.content, list) and not is_openai_api(api_base_url): + for part in message.content: + if part.get("type") == "image_url": + part["image_url"]["url"] = convert_image_data_uri(part["image_url"]["url"], target_format="png") + formatted_messages.append({"role": message.role, "content": message.content}) + + return formatted_messages + + +def is_openai_api(api_base_url: str = None) -> bool: + """ + Check if the model is served over the official OpenAI API + """ + return api_base_url is None or api_base_url.startswith("https://api.openai.com/v1") + + def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool: """ Check if the model is an OpenAI reasoning model """ - return model_name.startswith("o") and (api_base_url is None or api_base_url.startswith("https://api.openai.com/v1")) + return model_name.startswith("o") and is_openai_api(api_base_url) def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> bool: diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 2530b4fb..0f986709 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -555,6 +555,32 @@ def convert_image_to_webp(image_bytes): return webp_image_bytes +def convert_image_data_uri(image_data_uri: str, target_format: str = "png") -> str: + """ + Convert image (in data URI) to target format. + + Target format can be png, jpg, webp etc. + Returns the converted image as a data URI. + """ + base64_data = image_data_uri.split(",", 1)[1] + image_type = image_data_uri.split(";")[0].split(":")[1].split("/")[1] + if image_type.lower() == target_format.lower(): + return image_data_uri + + image_bytes = base64.b64decode(base64_data) + image_io = io.BytesIO(image_bytes) + with Image.open(image_io) as original_image: + output_image_io = io.BytesIO() + original_image.save(output_image_io, target_format.upper()) + + # Encode the image back to base64 + output_image_bytes = output_image_io.getvalue() + output_image_io.close() + output_base64_data = base64.b64encode(output_image_bytes).decode("utf-8") + output_data_uri = f"data:image/{target_format};base64,{output_base64_data}" + return output_data_uri + + def truncate_code_context(original_code_results: dict[str, Any], max_chars=10000) -> dict[str, Any]: """ Truncate large output files and drop image file data from code results.