diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 42586f76..7a892bdf 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -8,6 +8,7 @@ from khoj.processor.conversation.openai.utils import (
     clean_response_schema,
     completion_with_backoff,
     get_structured_output_support,
+    is_cerebras_api,
     responses_chat_completion_with_backoff,
     responses_completion_with_backoff,
     supports_responses_api,
@@ -40,8 +41,11 @@ def send_message_to_model(

     model_kwargs: Dict[str, Any] = {}
     json_support = get_structured_output_support(model, api_base_url)
+    strict = not is_cerebras_api(api_base_url)
     if tools and json_support == StructuredOutputSupport.TOOL:
-        model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=supports_responses_api(model, api_base_url))
+        model_kwargs["tools"] = to_openai_tools(
+            tools, use_responses_api=supports_responses_api(model, api_base_url), strict=strict
+        )
     elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
         # Drop unsupported fields from schema passed to OpenAI APi
         cleaned_response_schema = clean_response_schema(response_schema)
@@ -49,7 +53,7 @@ def send_message_to_model(
             model_kwargs["text"] = {
                 "format": {
                     "type": "json_schema",
-                    "strict": True,
+                    "strict": strict,
                     "name": response_schema.__name__,
                     "schema": cleaned_response_schema,
                 }
@@ -60,7 +64,7 @@ def send_message_to_model(
                 "json_schema": {
                     "schema": cleaned_response_schema,
                     "name": response_schema.__name__,
-                    "strict": True,
+                    "strict": strict,
                 },
             }
     elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 71840328..623b8fe9 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -795,7 +795,7 @@ def format_message_for_api(raw_messages: List[ChatMessage], model_name: str, api
                 if (
                     part.get("type") == "text"
                     and message.role == "assistant"
-                    and api_base_url.startswith("https://api.deepinfra.com/v1")
+                    and (api_base_url.startswith("https://api.deepinfra.com/v1") or is_cerebras_api(api_base_url))
                 ):
                     assistant_texts += [part["text"]]
                     message.content.pop(idx)
@@ -876,6 +876,13 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
     )


+def is_cerebras_api(api_base_url: str = None) -> bool:
+    """
+    Check if the model is served over the Cerebras API
+    """
+    return api_base_url is not None and api_base_url.startswith("https://api.cerebras.ai/v1")
+
+
 def is_groq_api(api_base_url: str = None) -> bool:
     """
     Check if the model is served over the Groq API
@@ -1212,7 +1219,7 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None:
             break


-def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> List[Dict] | None:
+def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool, strict: bool) -> List[Dict] | None:
     "Transform tool definitions from standard format to OpenAI format."
     if use_responses_api:
         openai_tools = [
@@ -1221,7 +1228,7 @@ def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> Lis
                 "name": tool.name,
                 "description": tool.description,
                 "parameters": clean_response_schema(tool.schema),
-                "strict": True,
+                "strict": strict,
             }
             for tool in tools
         ]
@@ -1233,7 +1240,7 @@ def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> Lis
                     "name": tool.name,
                     "description": tool.description,
                     "parameters": clean_response_schema(tool.schema),
-                    "strict": True,
+                    "strict": strict,
                 },
             }
             for tool in tools