mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Only enforce json output in supported AI model APIs
Deepseek reasoner does not support JSON object or JSON schema output via the Deepseek API. The Azure AI API does not support JSON schema output. Resolves #1126
This commit is contained in:
@@ -10,8 +10,10 @@ from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.openai.utils import (
|
||||
chat_completion_with_backoff,
|
||||
completion_with_backoff,
|
||||
get_openai_api_json_support,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
clean_json,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
@@ -126,13 +128,14 @@ def send_message_to_model(
|
||||
"""
|
||||
|
||||
# Get Response from GPT
|
||||
json_support = get_openai_api_json_support(model, api_base_url)
|
||||
return completion_with_backoff(
|
||||
messages=messages,
|
||||
model_name=model,
|
||||
openai_api_key=api_key,
|
||||
temperature=temperature,
|
||||
api_base_url=api_base_url,
|
||||
model_kwargs={"response_format": {"type": response_type}},
|
||||
model_kwargs={"response_format": {"type": response_type}} if json_support >= JsonSupport.OBJECT else {},
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import os
|
||||
from threading import Thread
|
||||
from typing import Dict, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import openai
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
@@ -16,6 +17,7 @@ from tenacity import (
|
||||
)
|
||||
|
||||
from khoj.processor.conversation.utils import (
|
||||
JsonSupport,
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
)
|
||||
@@ -245,3 +247,13 @@ def llm_thread(
|
||||
logger.error(f"Error in llm_thread: {e}", exc_info=True)
|
||||
finally:
|
||||
g.close()
|
||||
|
||||
|
||||
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
    """Return the level of JSON output enforcement to use for a model and API endpoint.

    Args:
        model_name: Name of the chat model the request targets.
        api_base_url: Optional base URL of the OpenAI-compatible API serving the model.

    Returns:
        JsonSupport.NONE for model names starting with "deepseek-reasoner",
        JsonSupport.OBJECT when the API host ends with ".ai.azure.com",
        JsonSupport.SCHEMA in every other case.
    """
    # Deepseek reasoner accepts neither JSON object nor JSON schema output
    if model_name.startswith("deepseek-reasoner"):
        return JsonSupport.NONE

    # Only look at the hostname, so paths and query strings cannot cause a false match
    api_host = urlparse(api_base_url).hostname if api_base_url else None

    # Azure AI endpoints do not accept JSON schema, so cap them at JSON object
    if api_host and api_host.endswith(".ai.azure.com"):
        return JsonSupport.OBJECT

    return JsonSupport.SCHEMA
|
||||
|
||||
@@ -878,3 +878,9 @@ def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
|
||||
return str(content)
|
||||
|
||||
return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
|
||||
|
||||
|
||||
class JsonSupport(int, Enum):
    """Ordered levels of JSON output enforcement that an AI model API accepts.

    Members mix in int so levels can be compared by rank,
    e.g. `json_support >= JsonSupport.OBJECT`.
    """

    # JSON output cannot be enforced
    NONE = 0
    # JSON object output can be enforced
    OBJECT = 1
    # JSON schema output can be enforced
    SCHEMA = 2
|
||||
|
||||
Reference in New Issue
Block a user