From 4a4d225455f9b0962df422d557019bb4d7c3d34a Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 19 Mar 2025 19:28:54 +0530 Subject: [PATCH] Only enforce json output in supported AI model APIs Deepseek reasoner does not support json object or schema via deepseek API Azure Ai API does not support json schema Resolves #1126 --- src/khoj/processor/conversation/openai/gpt.py | 5 ++++- src/khoj/processor/conversation/openai/utils.py | 12 ++++++++++++ src/khoj/processor/conversation/utils.py | 6 ++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 389f52ab..18eaea47 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -10,8 +10,10 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, completion_with_backoff, + get_openai_api_json_support, ) from khoj.processor.conversation.utils import ( + JsonSupport, clean_json, construct_structured_message, generate_chatml_messages_with_context, @@ -126,13 +128,14 @@ def send_message_to_model( """ # Get Response from GPT + json_support = get_openai_api_json_support(model, api_base_url) return completion_with_backoff( messages=messages, model_name=model, openai_api_key=api_key, temperature=temperature, api_base_url=api_base_url, - model_kwargs={"response_format": {"type": response_type}}, + model_kwargs={"response_format": {"type": response_type}} if json_support >= JsonSupport.OBJECT else {}, tracer=tracer, ) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 88d75763..f80c446a 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -2,6 +2,7 @@ import logging import os from threading import Thread from typing import Dict, List +from urllib.parse import urlparse import openai from openai.types.chat.chat_completion import ChatCompletion @@ -16,6 +17,7 @@ from tenacity import ( ) from khoj.processor.conversation.utils import ( + JsonSupport, ThreadedGenerator, commit_conversation_trace, ) @@ -245,3 +247,13 @@ def llm_thread( logger.error(f"Error in llm_thread: {e}", exc_info=True) finally: g.close() + + +def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: + if model_name.startswith("deepseek-reasoner"): + return JsonSupport.NONE + if api_base_url: + host = urlparse(api_base_url).hostname + if host and host.endswith(".ai.azure.com"): + return JsonSupport.OBJECT + return JsonSupport.SCHEMA diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index a7e6e694..de82f067 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -878,3 +878,9 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str: return str(content) return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages]) + + +class JsonSupport(int, Enum): + NONE = 0 + OBJECT = 1 + SCHEMA = 2