diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 0435bf7f..42586f76 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -8,9 +8,9 @@ from khoj.processor.conversation.openai.utils import ( clean_response_schema, completion_with_backoff, get_structured_output_support, - is_openai_api, responses_chat_completion_with_backoff, responses_completion_with_backoff, + supports_responses_api, to_openai_tools, ) from khoj.processor.conversation.utils import ( @@ -41,11 +41,11 @@ def send_message_to_model( model_kwargs: Dict[str, Any] = {} json_support = get_structured_output_support(model, api_base_url) if tools and json_support == StructuredOutputSupport.TOOL: - model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=is_openai_api(api_base_url)) + model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=supports_responses_api(model, api_base_url)) elif response_schema and json_support >= StructuredOutputSupport.SCHEMA: # Drop unsupported fields from schema passed to OpenAI APi cleaned_response_schema = clean_response_schema(response_schema) - if is_openai_api(api_base_url): + if supports_responses_api(model, api_base_url): model_kwargs["text"] = { "format": { "type": "json_schema", @@ -67,7 +67,7 @@ def send_message_to_model( model_kwargs["response_format"] = {"type": response_type} # Get Response from GPT - if is_openai_api(api_base_url): + if supports_responses_api(model, api_base_url): return responses_completion_with_backoff( messages=messages, model_name=model, @@ -106,7 +106,7 @@ async def converse_openai( logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") # Get Response from GPT - if is_openai_api(api_base_url): + if supports_responses_api(model, api_base_url): async for chunk in responses_chat_completion_with_backoff( messages=messages, model_name=model, diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 4a6a5292..71840328 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -111,7 +111,7 @@ def completion_with_backoff( model_kwargs["temperature"] = temperature model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95) - formatted_messages = format_message_for_api(messages, api_base_url) + formatted_messages = format_message_for_api(messages, model_name, api_base_url) # Tune reasoning models arguments if is_openai_reasoning_model(model_name, api_base_url): @@ -296,7 +296,7 @@ async def chat_completion_with_backoff( model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95) - formatted_messages = format_message_for_api(messages, api_base_url) + formatted_messages = format_message_for_api(messages, model_name, api_base_url) # Configure thinking for openai reasoning models if is_openai_reasoning_model(model_name, api_base_url): @@ -448,7 +448,7 @@ def responses_completion_with_backoff( client = get_openai_client(openai_api_key, api_base_url) openai_clients[client_key] = client - formatted_messages = format_message_for_api(messages, api_base_url) + formatted_messages = format_message_for_api(messages, model_name, api_base_url) # Move the first system message to Responses API instructions instructions: Optional[str] = None if formatted_messages and formatted_messages[0].get("role") == "system": @@ -461,8 +461,10 @@ def responses_completion_with_backoff( if is_openai_reasoning_model(model_name, api_base_url): temperature = 1 reasoning_effort = "medium" if deepthought else "low" - model_kwargs["reasoning"] = {"effort": reasoning_effort, "summary": "auto"} - model_kwargs["include"] = ["reasoning.encrypted_content"] + model_kwargs["reasoning"] = {"effort": reasoning_effort} + if is_openai_api(api_base_url): + model_kwargs["reasoning"]["summary"] = "auto" + model_kwargs["include"] = ["reasoning.encrypted_content"] # Remove unsupported params for reasoning models model_kwargs.pop("top_p", None) model_kwargs.pop("stop", None) @@ -559,7 +561,7 @@ async def responses_chat_completion_with_backoff( client = get_openai_async_client(openai_api_key, api_base_url) openai_async_clients[client_key] = client - formatted_messages = format_message_for_api(messages, api_base_url) + formatted_messages = format_message_for_api(messages, model_name, api_base_url) # Move the first system message to Responses API instructions instructions: Optional[str] = None if formatted_messages and formatted_messages[0].get("role") == "system": @@ -572,7 +574,10 @@ async def responses_chat_completion_with_backoff( if is_openai_reasoning_model(model_name, api_base_url): temperature = 1 reasoning_effort = "medium" if deepthought else "low" - model_kwargs["reasoning"] = {"effort": reasoning_effort, "summary": "auto"} + model_kwargs["reasoning"] = {"effort": reasoning_effort} + if is_openai_api(api_base_url): + model_kwargs["reasoning"]["summary"] = "auto" + model_kwargs["include"] = ["reasoning.encrypted_content"] # Remove unsupported params for reasoning models model_kwargs.pop("top_p", None) model_kwargs.pop("stop", None) @@ -705,7 +710,7 @@ def get_structured_output_support(model_name: str, api_base_url: str = None) -> return StructuredOutputSupport.TOOL -def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) -> List[dict]: +def format_message_for_api(raw_messages: List[ChatMessage], model_name: str, api_base_url: str) -> List[dict]: """ Format messages to send to chat model served over OpenAI (compatible) API. """ @@ -715,7 +720,7 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) - # Handle tool call and tool result message types message_type = message.additional_kwargs.get("message_type") if message_type == "tool_call": - if is_openai_api(api_base_url): + if supports_responses_api(model_name, api_base_url): for part in message.content: if "status" in part: part.pop("status") # Drop unsupported tool call status field @@ -759,7 +764,7 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) - if not tool_call_id: logger.warning(f"Dropping tool result without valid tool_call_id: {part.get('name')}") continue - if is_openai_api(api_base_url): + if supports_responses_api(model_name, api_base_url): formatted_messages.append( { "type": "function_call_output", @@ -777,7 +782,7 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) - } ) continue - if isinstance(message.content, list) and not is_openai_api(api_base_url): + if isinstance(message.content, list) and not supports_responses_api(model_name, api_base_url): assistant_texts = [] has_images = False for idx, part in enumerate(message.content): @@ -833,6 +838,13 @@ def is_openai_api(api_base_url: str = None) -> bool: return api_base_url is None or api_base_url.startswith("https://api.openai.com/v1") +def supports_responses_api(model_name: str, api_base_url: str = None) -> bool: + """ + Check if the model, ai api supports the OpenAI Responses API + """ + return is_openai_api(api_base_url) + + def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool: """ Check if the model is an OpenAI reasoning model