Add support for the Cerebras AI model API

- It does not support strict mode for JSON schema or tool use
- It likes text content to be plain string, not nested in a dictionary
- Verified to work with gpt oss models on cerebras
This commit is contained in:
Debanjum
2025-08-28 01:14:36 -07:00
parent 0a5a882e54
commit dd8e805cfe
2 changed files with 18 additions and 7 deletions

View File

@@ -8,6 +8,7 @@ from khoj.processor.conversation.openai.utils import (
clean_response_schema,
completion_with_backoff,
get_structured_output_support,
is_cerebras_api,
responses_chat_completion_with_backoff,
responses_completion_with_backoff,
supports_responses_api,
@@ -40,8 +41,11 @@ def send_message_to_model(
model_kwargs: Dict[str, Any] = {}
json_support = get_structured_output_support(model, api_base_url)
strict = not is_cerebras_api(api_base_url)
if tools and json_support == StructuredOutputSupport.TOOL:
model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=supports_responses_api(model, api_base_url))
model_kwargs["tools"] = to_openai_tools(
tools, use_responses_api=supports_responses_api(model, api_base_url), strict=strict
)
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
# Drop unsupported fields from schema passed to the OpenAI API
cleaned_response_schema = clean_response_schema(response_schema)
@@ -49,7 +53,7 @@ def send_message_to_model(
model_kwargs["text"] = {
"format": {
"type": "json_schema",
"strict": True,
"strict": strict,
"name": response_schema.__name__,
"schema": cleaned_response_schema,
}
@@ -60,7 +64,7 @@ def send_message_to_model(
"json_schema": {
"schema": cleaned_response_schema,
"name": response_schema.__name__,
"strict": True,
"strict": strict,
},
}
elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:

View File

@@ -795,7 +795,7 @@ def format_message_for_api(raw_messages: List[ChatMessage], model_name: str, api
if (
part.get("type") == "text"
and message.role == "assistant"
and api_base_url.startswith("https://api.deepinfra.com/v1")
and (api_base_url.startswith("https://api.deepinfra.com/v1") or is_cerebras_api(api_base_url))
):
assistant_texts += [part["text"]]
message.content.pop(idx)
@@ -876,6 +876,13 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
)
def is_cerebras_api(api_base_url: str = None) -> bool:
"""
Check if the model is served over the Cerebras API
"""
return api_base_url is not None and api_base_url.startswith("https://api.cerebras.ai/v1")
def is_groq_api(api_base_url: str = None) -> bool:
"""
Check if the model is served over the Groq API
@@ -1212,7 +1219,7 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None:
break
def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> List[Dict] | None:
def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool, strict: bool) -> List[Dict] | None:
"Transform tool definitions from standard format to OpenAI format."
if use_responses_api:
openai_tools = [
@@ -1221,7 +1228,7 @@ def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> Lis
"name": tool.name,
"description": tool.description,
"parameters": clean_response_schema(tool.schema),
"strict": True,
"strict": strict,
}
for tool in tools
]
@@ -1233,7 +1240,7 @@ def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> Lis
"name": tool.name,
"description": tool.description,
"parameters": clean_response_schema(tool.schema),
"strict": True,
"strict": strict,
},
}
for tool in tools