diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 389f52ab..18eaea47 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -10,8 +10,10 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, completion_with_backoff, + get_openai_api_json_support, ) from khoj.processor.conversation.utils import ( + JsonSupport, clean_json, construct_structured_message, generate_chatml_messages_with_context, @@ -126,13 +128,14 @@ def send_message_to_model( """ # Get Response from GPT + json_support = get_openai_api_json_support(model, api_base_url) return completion_with_backoff( messages=messages, model_name=model, openai_api_key=api_key, temperature=temperature, api_base_url=api_base_url, - model_kwargs={"response_format": {"type": response_type}}, + model_kwargs={"response_format": {"type": response_type}} if json_support >= JsonSupport.OBJECT else {}, tracer=tracer, ) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 88d75763..f80c446a 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -2,6 +2,7 @@ import logging import os from threading import Thread from typing import Dict, List +from urllib.parse import urlparse import openai from openai.types.chat.chat_completion import ChatCompletion @@ -16,6 +17,7 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( + JsonSupport, ThreadedGenerator, commit_conversation_trace, ) @@ -245,3 +247,13 @@ def llm_thread( logger.error(f"Error in llm_thread: {e}", exc_info=True) finally: g.close() + + +def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: + if model_name.startswith("deepseek-reasoner"): + return JsonSupport.NONE + if api_base_url: + host = urlparse(api_base_url).hostname + if host and host.endswith(".ai.azure.com"): + return JsonSupport.OBJECT + return JsonSupport.SCHEMA diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index a7e6e694..de82f067 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -878,3 +878,9 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str: return str(content) return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages]) + + +class JsonSupport(int, Enum): + NONE = 0 + OBJECT = 1 + SCHEMA = 2