Only enforce JSON output in supported AI model APIs

The Deepseek reasoner model does not support JSON object or JSON schema response formats via the Deepseek API
The Azure AI API does not support the JSON schema response format

Resolves #1126
This commit is contained in:
Debanjum
2025-03-19 19:28:54 +05:30
parent d74c3a1db4
commit 4a4d225455
3 changed files with 22 additions and 1 deletions

View File

@@ -10,8 +10,10 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
completion_with_backoff,
get_openai_api_json_support,
)
from khoj.processor.conversation.utils import (
JsonSupport,
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
@@ -126,13 +128,14 @@ def send_message_to_model(
"""
# Get Response from GPT
json_support = get_openai_api_json_support(model, api_base_url)
return completion_with_backoff(
messages=messages,
model_name=model,
openai_api_key=api_key,
temperature=temperature,
api_base_url=api_base_url,
model_kwargs={"response_format": {"type": response_type}},
model_kwargs={"response_format": {"type": response_type}} if json_support >= JsonSupport.OBJECT else {},
tracer=tracer,
)

View File

@@ -2,6 +2,7 @@ import logging
import os
from threading import Thread
from typing import Dict, List
from urllib.parse import urlparse
import openai
from openai.types.chat.chat_completion import ChatCompletion
@@ -16,6 +17,7 @@ from tenacity import (
)
from khoj.processor.conversation.utils import (
JsonSupport,
ThreadedGenerator,
commit_conversation_trace,
)
@@ -245,3 +247,13 @@ def llm_thread(
logger.error(f"Error in llm_thread: {e}", exc_info=True)
finally:
g.close()
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
    """Determine which level of JSON response enforcement the target OpenAI-compatible API supports.

    Returns SCHEMA by default, downgrading to OBJECT for Azure AI endpoints
    and NONE for Deepseek reasoner models.
    """
    # NOTE(review): default of None suggests the annotation should be Optional[str] — confirm imports allow it.
    # Deepseek reasoner models cannot emit enforced JSON via the Deepseek API.
    if model_name.startswith("deepseek-reasoner"):
        return JsonSupport.NONE
    # Azure AI endpoints accept json_object but not json_schema response formats.
    host = urlparse(api_base_url).hostname if api_base_url else None
    if host is not None and host.endswith(".ai.azure.com"):
        return JsonSupport.OBJECT
    return JsonSupport.SCHEMA

View File

@@ -878,3 +878,9 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
return str(content)
return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
class JsonSupport(int, Enum):
    """Levels of JSON response enforcement an AI model API can honor.

    Values are ordered from weakest to strongest guarantee, so callers may
    compare levels directly (e.g. ``level >= JsonSupport.OBJECT``).
    """

    NONE = 0  # API cannot enforce JSON output at all
    OBJECT = 1  # API can enforce a generic JSON object response
    SCHEMA = 2  # API can enforce a caller-supplied JSON schema