Only enforce json output in supported AI model APIs

Deepseek reasoner does not support JSON object or schema response formats via the Deepseek API
Azure AI API does not support JSON schema

Resolves #1126
This commit is contained in:
Debanjum
2025-03-19 19:28:54 +05:30
parent d74c3a1db4
commit 4a4d225455
3 changed files with 22 additions and 1 deletions

View File

@@ -10,8 +10,10 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import ( from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff, chat_completion_with_backoff,
completion_with_backoff, completion_with_backoff,
get_openai_api_json_support,
) )
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
JsonSupport,
clean_json, clean_json,
construct_structured_message, construct_structured_message,
generate_chatml_messages_with_context, generate_chatml_messages_with_context,
@@ -126,13 +128,14 @@ def send_message_to_model(
""" """
# Get Response from GPT # Get Response from GPT
json_support = get_openai_api_json_support(model, api_base_url)
return completion_with_backoff( return completion_with_backoff(
messages=messages, messages=messages,
model_name=model, model_name=model,
openai_api_key=api_key, openai_api_key=api_key,
temperature=temperature, temperature=temperature,
api_base_url=api_base_url, api_base_url=api_base_url,
model_kwargs={"response_format": {"type": response_type}}, model_kwargs={"response_format": {"type": response_type}} if json_support >= JsonSupport.OBJECT else {},
tracer=tracer, tracer=tracer,
) )

View File

@@ -2,6 +2,7 @@ import logging
import os import os
from threading import Thread from threading import Thread
from typing import Dict, List from typing import Dict, List
from urllib.parse import urlparse
import openai import openai
from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion import ChatCompletion
@@ -16,6 +17,7 @@ from tenacity import (
) )
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
JsonSupport,
ThreadedGenerator, ThreadedGenerator,
commit_conversation_trace, commit_conversation_trace,
) )
@@ -245,3 +247,13 @@ def llm_thread(
logger.error(f"Error in llm_thread: {e}", exc_info=True) logger.error(f"Error in llm_thread: {e}", exc_info=True)
finally: finally:
g.close() g.close()
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
    """Determine the level of JSON output enforcement supported by the target API.

    Deepseek's reasoner model accepts no JSON response format at all, and
    Azure AI hosted endpoints accept json_object but not json_schema.
    All other OpenAI-compatible APIs are assumed to support full schema enforcement.
    """
    # Deepseek reasoner rejects any JSON response_format via the Deepseek API.
    if model_name.startswith("deepseek-reasoner"):
        return JsonSupport.NONE
    # Azure AI endpoints (*.ai.azure.com) only support json_object, not json_schema.
    host = urlparse(api_base_url).hostname if api_base_url else None
    if host and host.endswith(".ai.azure.com"):
        return JsonSupport.OBJECT
    return JsonSupport.SCHEMA

View File

@@ -878,3 +878,9 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
return str(content) return str(content)
return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages]) return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
class JsonSupport(int, Enum):
    """Ordered levels of JSON output enforcement an AI model API supports.

    Int-valued so levels can be compared with ordering operators,
    e.g. ``json_support >= JsonSupport.OBJECT``.
    """

    NONE = 0  # API supports no JSON response format for this model
    OBJECT = 1  # API supports {"type": "json_object"} response format only
    SCHEMA = 2  # API additionally supports JSON schema enforcement