mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Add support for the Cerebras AI model API
- It does not support strict mode for JSON schema or tool use
- It expects text content to be a plain string, not nested in a dictionary
- Verified to work with gpt-oss models on Cerebras
This commit is contained in:
@@ -8,6 +8,7 @@ from khoj.processor.conversation.openai.utils import (
|
||||
clean_response_schema,
|
||||
completion_with_backoff,
|
||||
get_structured_output_support,
|
||||
is_cerebras_api,
|
||||
responses_chat_completion_with_backoff,
|
||||
responses_completion_with_backoff,
|
||||
supports_responses_api,
|
||||
@@ -40,8 +41,11 @@ def send_message_to_model(
|
||||
|
||||
model_kwargs: Dict[str, Any] = {}
|
||||
json_support = get_structured_output_support(model, api_base_url)
|
||||
strict = not is_cerebras_api(api_base_url)
|
||||
if tools and json_support == StructuredOutputSupport.TOOL:
|
||||
model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=supports_responses_api(model, api_base_url))
|
||||
model_kwargs["tools"] = to_openai_tools(
|
||||
tools, use_responses_api=supports_responses_api(model, api_base_url), strict=strict
|
||||
)
|
||||
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
|
||||
# Drop unsupported fields from schema passed to OpenAI API
|
||||
cleaned_response_schema = clean_response_schema(response_schema)
|
||||
@@ -49,7 +53,7 @@ def send_message_to_model(
|
||||
model_kwargs["text"] = {
|
||||
"format": {
|
||||
"type": "json_schema",
|
||||
"strict": True,
|
||||
"strict": strict,
|
||||
"name": response_schema.__name__,
|
||||
"schema": cleaned_response_schema,
|
||||
}
|
||||
@@ -60,7 +64,7 @@ def send_message_to_model(
|
||||
"json_schema": {
|
||||
"schema": cleaned_response_schema,
|
||||
"name": response_schema.__name__,
|
||||
"strict": True,
|
||||
"strict": strict,
|
||||
},
|
||||
}
|
||||
elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:
|
||||
|
||||
@@ -795,7 +795,7 @@ def format_message_for_api(raw_messages: List[ChatMessage], model_name: str, api
|
||||
if (
|
||||
part.get("type") == "text"
|
||||
and message.role == "assistant"
|
||||
and api_base_url.startswith("https://api.deepinfra.com/v1")
|
||||
and (api_base_url.startswith("https://api.deepinfra.com/v1") or is_cerebras_api(api_base_url))
|
||||
):
|
||||
assistant_texts += [part["text"]]
|
||||
message.content.pop(idx)
|
||||
@@ -876,6 +876,13 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
|
||||
)
|
||||
|
||||
|
||||
def is_cerebras_api(api_base_url: str = None) -> bool:
|
||||
"""
|
||||
Check if the model is served over the Cerebras API
|
||||
"""
|
||||
return api_base_url is not None and api_base_url.startswith("https://api.cerebras.ai/v1")
|
||||
|
||||
|
||||
def is_groq_api(api_base_url: str = None) -> bool:
|
||||
"""
|
||||
Check if the model is served over the Groq API
|
||||
@@ -1212,7 +1219,7 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None:
|
||||
break
|
||||
|
||||
|
||||
def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> List[Dict] | None:
|
||||
def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool, strict: bool) -> List[Dict] | None:
|
||||
"Transform tool definitions from standard format to OpenAI format."
|
||||
if use_responses_api:
|
||||
openai_tools = [
|
||||
@@ -1221,7 +1228,7 @@ def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> Lis
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": clean_response_schema(tool.schema),
|
||||
"strict": True,
|
||||
"strict": strict,
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
@@ -1233,7 +1240,7 @@ def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> Lis
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": clean_response_schema(tool.schema),
|
||||
"strict": True,
|
||||
"strict": strict,
|
||||
},
|
||||
}
|
||||
for tool in tools
|
||||
|
||||
Reference in New Issue
Block a user