From b335f8cf796ecc0cff89fc5d5c325009adab1f8f Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Thu, 31 Jul 2025 15:31:19 -0700
Subject: [PATCH] Support grok 4 reasoning model

---
 src/khoj/processor/conversation/openai/utils.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index f5a909d6..3b9e05a1 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -104,6 +104,9 @@ def completion_with_backoff(
         model_kwargs.pop("temperature", None)
         reasoning_effort = "high" if deepthought else "low"
         model_kwargs["reasoning_effort"] = reasoning_effort
+        if model_name.startswith("grok-4"):
+            # Grok-4 models do not support reasoning_effort parameter
+            model_kwargs.pop("reasoning_effort", None)
     elif model_name.startswith("deepseek-reasoner"):
         stream_processor = in_stream_thought_processor
         # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
@@ -280,7 +283,9 @@ async def chat_completion_with_backoff(
                 ] = f"{first_system_message_content}\nFormatting re-enabled"
     elif is_twitter_reasoning_model(model_name, api_base_url):
         reasoning_effort = "high" if deepthought else "low"
-        model_kwargs["reasoning_effort"] = reasoning_effort
+        # Grok-4 models do not support reasoning_effort parameter
+        if not model_name.startswith("grok-4"):
+            model_kwargs["reasoning_effort"] = reasoning_effort
     elif model_name.startswith("deepseek-reasoner") or "deepseek-r1" in model_name:
         # Official Deepseek reasoner model and some inference APIs like vLLM return structured thinking output.
         # Others like DeepInfra return it in response stream.
@@ -503,7 +508,7 @@ def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool
     """
     Check if the model is an OpenAI reasoning model
     """
-    return model_name.startswith("o") and is_openai_api(api_base_url)
+    return model_name.lower().startswith("o") and is_openai_api(api_base_url)
 
 
 def is_non_streaming_model(model_name: str, api_base_url: str = None) -> bool:
@@ -518,8 +523,9 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
     """
     Check if the model is a Twitter reasoning model
     """
+    reasoning_models = "grok-3-mini", "grok-4"
     return (
-        model_name.startswith("grok-3-mini")
+        any(prefix in model_name.lower() for prefix in reasoning_models)
         and api_base_url is not None
         and api_base_url.startswith("https://api.x.ai/v1")
     )