mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Add support for the Cerebras AI model API
- The Cerebras API does not support strict mode for JSON schema or tool use
- It expects text content to be a plain string, not nested in a dictionary
- Verified to work with gpt-oss models on Cerebras
This commit is contained in:
@@ -8,6 +8,7 @@ from khoj.processor.conversation.openai.utils import (
|
|||||||
clean_response_schema,
|
clean_response_schema,
|
||||||
completion_with_backoff,
|
completion_with_backoff,
|
||||||
get_structured_output_support,
|
get_structured_output_support,
|
||||||
|
is_cerebras_api,
|
||||||
responses_chat_completion_with_backoff,
|
responses_chat_completion_with_backoff,
|
||||||
responses_completion_with_backoff,
|
responses_completion_with_backoff,
|
||||||
supports_responses_api,
|
supports_responses_api,
|
||||||
@@ -40,8 +41,11 @@ def send_message_to_model(
|
|||||||
|
|
||||||
model_kwargs: Dict[str, Any] = {}
|
model_kwargs: Dict[str, Any] = {}
|
||||||
json_support = get_structured_output_support(model, api_base_url)
|
json_support = get_structured_output_support(model, api_base_url)
|
||||||
|
strict = not is_cerebras_api(api_base_url)
|
||||||
if tools and json_support == StructuredOutputSupport.TOOL:
|
if tools and json_support == StructuredOutputSupport.TOOL:
|
||||||
model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=supports_responses_api(model, api_base_url))
|
model_kwargs["tools"] = to_openai_tools(
|
||||||
|
tools, use_responses_api=supports_responses_api(model, api_base_url), strict=strict
|
||||||
|
)
|
||||||
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
|
elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
|
||||||
# Drop unsupported fields from schema passed to OpenAI APi
|
# Drop unsupported fields from schema passed to OpenAI APi
|
||||||
cleaned_response_schema = clean_response_schema(response_schema)
|
cleaned_response_schema = clean_response_schema(response_schema)
|
||||||
@@ -49,7 +53,7 @@ def send_message_to_model(
|
|||||||
model_kwargs["text"] = {
|
model_kwargs["text"] = {
|
||||||
"format": {
|
"format": {
|
||||||
"type": "json_schema",
|
"type": "json_schema",
|
||||||
"strict": True,
|
"strict": strict,
|
||||||
"name": response_schema.__name__,
|
"name": response_schema.__name__,
|
||||||
"schema": cleaned_response_schema,
|
"schema": cleaned_response_schema,
|
||||||
}
|
}
|
||||||
@@ -60,7 +64,7 @@ def send_message_to_model(
|
|||||||
"json_schema": {
|
"json_schema": {
|
||||||
"schema": cleaned_response_schema,
|
"schema": cleaned_response_schema,
|
||||||
"name": response_schema.__name__,
|
"name": response_schema.__name__,
|
||||||
"strict": True,
|
"strict": strict,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:
|
elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:
|
||||||
|
|||||||
@@ -795,7 +795,7 @@ def format_message_for_api(raw_messages: List[ChatMessage], model_name: str, api
|
|||||||
if (
|
if (
|
||||||
part.get("type") == "text"
|
part.get("type") == "text"
|
||||||
and message.role == "assistant"
|
and message.role == "assistant"
|
||||||
and api_base_url.startswith("https://api.deepinfra.com/v1")
|
and (api_base_url.startswith("https://api.deepinfra.com/v1") or is_cerebras_api(api_base_url))
|
||||||
):
|
):
|
||||||
assistant_texts += [part["text"]]
|
assistant_texts += [part["text"]]
|
||||||
message.content.pop(idx)
|
message.content.pop(idx)
|
||||||
@@ -876,6 +876,13 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_cerebras_api(api_base_url: str = None) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the model is served over the Cerebras API
|
||||||
|
"""
|
||||||
|
return api_base_url is not None and api_base_url.startswith("https://api.cerebras.ai/v1")
|
||||||
|
|
||||||
|
|
||||||
def is_groq_api(api_base_url: str = None) -> bool:
|
def is_groq_api(api_base_url: str = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the model is served over the Groq API
|
Check if the model is served over the Groq API
|
||||||
@@ -1212,7 +1219,7 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None:
|
|||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> List[Dict] | None:
|
def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool, strict: bool) -> List[Dict] | None:
|
||||||
"Transform tool definitions from standard format to OpenAI format."
|
"Transform tool definitions from standard format to OpenAI format."
|
||||||
if use_responses_api:
|
if use_responses_api:
|
||||||
openai_tools = [
|
openai_tools = [
|
||||||
@@ -1221,7 +1228,7 @@ def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> Lis
|
|||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"parameters": clean_response_schema(tool.schema),
|
"parameters": clean_response_schema(tool.schema),
|
||||||
"strict": True,
|
"strict": strict,
|
||||||
}
|
}
|
||||||
for tool in tools
|
for tool in tools
|
||||||
]
|
]
|
||||||
@@ -1233,7 +1240,7 @@ def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> Lis
|
|||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"parameters": clean_response_schema(tool.schema),
|
"parameters": clean_response_schema(tool.schema),
|
||||||
"strict": True,
|
"strict": strict,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for tool in tools
|
for tool in tools
|
||||||
|
|||||||
Reference in New Issue
Block a user