mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Check if OpenAI-compatible AI API supports the Responses API endpoint
The Responses API is starting to be supported by other AI APIs as well. This change makes preparatory improvements to ease the move to using the Responses API with other AI APIs. Use the new, better-named `supports_responses_api` method. The method currently just maps to `is_openai_api`. Other AI APIs will be added to it once support for using the Responses API with them is added.
This commit is contained in:
@@ -8,9 +8,9 @@ from khoj.processor.conversation.openai.utils import (
|
|||||||
clean_response_schema,
|
clean_response_schema,
|
||||||
completion_with_backoff,
|
completion_with_backoff,
|
||||||
get_structured_output_support,
|
get_structured_output_support,
|
||||||
is_openai_api,
|
|
||||||
responses_chat_completion_with_backoff,
|
responses_chat_completion_with_backoff,
|
||||||
responses_completion_with_backoff,
|
responses_completion_with_backoff,
|
||||||
|
supports_responses_api,
|
||||||
to_openai_tools,
|
to_openai_tools,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
@@ -41,11 +41,11 @@ def send_message_to_model(
|
|||||||
model_kwargs: Dict[str, Any] = {}
|
model_kwargs: Dict[str, Any] = {}
|
||||||
json_support = get_structured_output_support(model, api_base_url)
|
json_support = get_structured_output_support(model, api_base_url)
|
||||||
if tools and json_support == StructuredOutputSupport.TOOL:
|
if tools and json_support == StructuredOutputSupport.TOOL:
|
||||||
model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=is_openai_api(api_base_url))
|
model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=supports_responses_api(model, api_base_url))
|
||||||
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
|
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
|
||||||
# Drop unsupported fields from schema passed to OpenAI APi
|
# Drop unsupported fields from schema passed to OpenAI APi
|
||||||
cleaned_response_schema = clean_response_schema(response_schema)
|
cleaned_response_schema = clean_response_schema(response_schema)
|
||||||
if is_openai_api(api_base_url):
|
if supports_responses_api(model, api_base_url):
|
||||||
model_kwargs["text"] = {
|
model_kwargs["text"] = {
|
||||||
"format": {
|
"format": {
|
||||||
"type": "json_schema",
|
"type": "json_schema",
|
||||||
@@ -67,7 +67,7 @@ def send_message_to_model(
|
|||||||
model_kwargs["response_format"] = {"type": response_type}
|
model_kwargs["response_format"] = {"type": response_type}
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
if is_openai_api(api_base_url):
|
if supports_responses_api(model, api_base_url):
|
||||||
return responses_completion_with_backoff(
|
return responses_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
@@ -106,7 +106,7 @@ async def converse_openai(
|
|||||||
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
if is_openai_api(api_base_url):
|
if supports_responses_api(model, api_base_url):
|
||||||
async for chunk in responses_chat_completion_with_backoff(
|
async for chunk in responses_chat_completion_with_backoff(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ def completion_with_backoff(
|
|||||||
model_kwargs["temperature"] = temperature
|
model_kwargs["temperature"] = temperature
|
||||||
model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
|
model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
|
||||||
|
|
||||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
formatted_messages = format_message_for_api(messages, model_name, api_base_url)
|
||||||
|
|
||||||
# Tune reasoning models arguments
|
# Tune reasoning models arguments
|
||||||
if is_openai_reasoning_model(model_name, api_base_url):
|
if is_openai_reasoning_model(model_name, api_base_url):
|
||||||
@@ -296,7 +296,7 @@ async def chat_completion_with_backoff(
|
|||||||
|
|
||||||
model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
|
model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
|
||||||
|
|
||||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
formatted_messages = format_message_for_api(messages, model_name, api_base_url)
|
||||||
|
|
||||||
# Configure thinking for openai reasoning models
|
# Configure thinking for openai reasoning models
|
||||||
if is_openai_reasoning_model(model_name, api_base_url):
|
if is_openai_reasoning_model(model_name, api_base_url):
|
||||||
@@ -448,7 +448,7 @@ def responses_completion_with_backoff(
|
|||||||
client = get_openai_client(openai_api_key, api_base_url)
|
client = get_openai_client(openai_api_key, api_base_url)
|
||||||
openai_clients[client_key] = client
|
openai_clients[client_key] = client
|
||||||
|
|
||||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
formatted_messages = format_message_for_api(messages, model_name, api_base_url)
|
||||||
# Move the first system message to Responses API instructions
|
# Move the first system message to Responses API instructions
|
||||||
instructions: Optional[str] = None
|
instructions: Optional[str] = None
|
||||||
if formatted_messages and formatted_messages[0].get("role") == "system":
|
if formatted_messages and formatted_messages[0].get("role") == "system":
|
||||||
@@ -461,8 +461,10 @@ def responses_completion_with_backoff(
|
|||||||
if is_openai_reasoning_model(model_name, api_base_url):
|
if is_openai_reasoning_model(model_name, api_base_url):
|
||||||
temperature = 1
|
temperature = 1
|
||||||
reasoning_effort = "medium" if deepthought else "low"
|
reasoning_effort = "medium" if deepthought else "low"
|
||||||
model_kwargs["reasoning"] = {"effort": reasoning_effort, "summary": "auto"}
|
model_kwargs["reasoning"] = {"effort": reasoning_effort}
|
||||||
model_kwargs["include"] = ["reasoning.encrypted_content"]
|
if is_openai_api(api_base_url):
|
||||||
|
model_kwargs["reasoning"]["summary"] = "auto"
|
||||||
|
model_kwargs["include"] = ["reasoning.encrypted_content"]
|
||||||
# Remove unsupported params for reasoning models
|
# Remove unsupported params for reasoning models
|
||||||
model_kwargs.pop("top_p", None)
|
model_kwargs.pop("top_p", None)
|
||||||
model_kwargs.pop("stop", None)
|
model_kwargs.pop("stop", None)
|
||||||
@@ -559,7 +561,7 @@ async def responses_chat_completion_with_backoff(
|
|||||||
client = get_openai_async_client(openai_api_key, api_base_url)
|
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||||
openai_async_clients[client_key] = client
|
openai_async_clients[client_key] = client
|
||||||
|
|
||||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
formatted_messages = format_message_for_api(messages, model_name, api_base_url)
|
||||||
# Move the first system message to Responses API instructions
|
# Move the first system message to Responses API instructions
|
||||||
instructions: Optional[str] = None
|
instructions: Optional[str] = None
|
||||||
if formatted_messages and formatted_messages[0].get("role") == "system":
|
if formatted_messages and formatted_messages[0].get("role") == "system":
|
||||||
@@ -572,7 +574,10 @@ async def responses_chat_completion_with_backoff(
|
|||||||
if is_openai_reasoning_model(model_name, api_base_url):
|
if is_openai_reasoning_model(model_name, api_base_url):
|
||||||
temperature = 1
|
temperature = 1
|
||||||
reasoning_effort = "medium" if deepthought else "low"
|
reasoning_effort = "medium" if deepthought else "low"
|
||||||
model_kwargs["reasoning"] = {"effort": reasoning_effort, "summary": "auto"}
|
model_kwargs["reasoning"] = {"effort": reasoning_effort}
|
||||||
|
if is_openai_api(api_base_url):
|
||||||
|
model_kwargs["reasoning"]["summary"] = "auto"
|
||||||
|
model_kwargs["include"] = ["reasoning.encrypted_content"]
|
||||||
# Remove unsupported params for reasoning models
|
# Remove unsupported params for reasoning models
|
||||||
model_kwargs.pop("top_p", None)
|
model_kwargs.pop("top_p", None)
|
||||||
model_kwargs.pop("stop", None)
|
model_kwargs.pop("stop", None)
|
||||||
@@ -705,7 +710,7 @@ def get_structured_output_support(model_name: str, api_base_url: str = None) ->
|
|||||||
return StructuredOutputSupport.TOOL
|
return StructuredOutputSupport.TOOL
|
||||||
|
|
||||||
|
|
||||||
def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) -> List[dict]:
|
def format_message_for_api(raw_messages: List[ChatMessage], model_name: str, api_base_url: str) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Format messages to send to chat model served over OpenAI (compatible) API.
|
Format messages to send to chat model served over OpenAI (compatible) API.
|
||||||
"""
|
"""
|
||||||
@@ -715,7 +720,7 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) -
|
|||||||
# Handle tool call and tool result message types
|
# Handle tool call and tool result message types
|
||||||
message_type = message.additional_kwargs.get("message_type")
|
message_type = message.additional_kwargs.get("message_type")
|
||||||
if message_type == "tool_call":
|
if message_type == "tool_call":
|
||||||
if is_openai_api(api_base_url):
|
if supports_responses_api(model_name, api_base_url):
|
||||||
for part in message.content:
|
for part in message.content:
|
||||||
if "status" in part:
|
if "status" in part:
|
||||||
part.pop("status") # Drop unsupported tool call status field
|
part.pop("status") # Drop unsupported tool call status field
|
||||||
@@ -759,7 +764,7 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) -
|
|||||||
if not tool_call_id:
|
if not tool_call_id:
|
||||||
logger.warning(f"Dropping tool result without valid tool_call_id: {part.get('name')}")
|
logger.warning(f"Dropping tool result without valid tool_call_id: {part.get('name')}")
|
||||||
continue
|
continue
|
||||||
if is_openai_api(api_base_url):
|
if supports_responses_api(model_name, api_base_url):
|
||||||
formatted_messages.append(
|
formatted_messages.append(
|
||||||
{
|
{
|
||||||
"type": "function_call_output",
|
"type": "function_call_output",
|
||||||
@@ -777,7 +782,7 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) -
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
if isinstance(message.content, list) and not is_openai_api(api_base_url):
|
if isinstance(message.content, list) and not supports_responses_api(model_name, api_base_url):
|
||||||
assistant_texts = []
|
assistant_texts = []
|
||||||
has_images = False
|
has_images = False
|
||||||
for idx, part in enumerate(message.content):
|
for idx, part in enumerate(message.content):
|
||||||
@@ -833,6 +838,13 @@ def is_openai_api(api_base_url: str = None) -> bool:
|
|||||||
return api_base_url is None or api_base_url.startswith("https://api.openai.com/v1")
|
return api_base_url is None or api_base_url.startswith("https://api.openai.com/v1")
|
||||||
|
|
||||||
|
|
||||||
|
def supports_responses_api(model_name: str, api_base_url: str = None) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the model, ai api supports the OpenAI Responses API
|
||||||
|
"""
|
||||||
|
return is_openai_api(api_base_url)
|
||||||
|
|
||||||
|
|
||||||
def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
|
def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the model is an OpenAI reasoning model
|
Check if the model is an OpenAI reasoning model
|
||||||
|
|||||||
Reference in New Issue
Block a user