diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index e68ebd20..5127a027 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -9,6 +9,9 @@ from khoj.processor.conversation.openai.utils import (
     clean_response_schema,
     completion_with_backoff,
     get_structured_output_support,
+    is_openai_api,
+    responses_chat_completion_with_backoff,
+    responses_completion_with_backoff,
     to_openai_tools,
 )
 from khoj.processor.conversation.utils import (
@@ -43,31 +46,52 @@ def send_message_to_model(
     model_kwargs: Dict[str, Any] = {}
     json_support = get_structured_output_support(model, api_base_url)
     if tools and json_support == StructuredOutputSupport.TOOL:
-        model_kwargs["tools"] = to_openai_tools(tools)
+        model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=is_openai_api(api_base_url))
     elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
         # Drop unsupported fields from schema passed to OpenAI API
         cleaned_response_schema = clean_response_schema(response_schema)
-        model_kwargs["response_format"] = {
-            "type": "json_schema",
-            "json_schema": {
-                "schema": cleaned_response_schema,
-                "name": response_schema.__name__,
-                "strict": True,
-            },
-        }
+        if is_openai_api(api_base_url):
+            model_kwargs["text"] = {
+                "format": {
+                    "type": "json_schema",
+                    "strict": True,
+                    "name": response_schema.__name__,
+                    "schema": cleaned_response_schema,
+                }
+            }
+        else:
+            model_kwargs["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "schema": cleaned_response_schema,
+                    "name": response_schema.__name__,
+                    "strict": True,
+                },
+            }
     elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:
         model_kwargs["response_format"] = {"type": response_type}

     # Get Response from GPT
-    return completion_with_backoff(
-        messages=messages,
-        model_name=model,
-        openai_api_key=api_key,
-        api_base_url=api_base_url,
-        deepthought=deepthought,
-        model_kwargs=model_kwargs,
-        tracer=tracer,
-    )
+    if is_openai_api(api_base_url):
+        return responses_completion_with_backoff(
+            messages=messages,
+            model_name=model,
+            openai_api_key=api_key,
+            api_base_url=api_base_url,
+            deepthought=deepthought,
+            model_kwargs=model_kwargs,
+            tracer=tracer,
+        )
+    else:
+        return completion_with_backoff(
+            messages=messages,
+            model_name=model,
+            openai_api_key=api_key,
+            api_base_url=api_base_url,
+            deepthought=deepthought,
+            model_kwargs=model_kwargs,
+            tracer=tracer,
+        )


 async def converse_openai(
@@ -163,13 +187,26 @@ async def converse_openai(
     logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")

     # Get Response from GPT
-    async for chunk in chat_completion_with_backoff(
-        messages=messages,
-        model_name=model,
-        temperature=temperature,
-        openai_api_key=api_key,
-        api_base_url=api_base_url,
-        deepthought=deepthought,
-        tracer=tracer,
-    ):
-        yield chunk
+    if is_openai_api(api_base_url):
+        async for chunk in responses_chat_completion_with_backoff(
+            messages=messages,
+            model_name=model,
+            temperature=temperature,
+            openai_api_key=api_key,
+            api_base_url=api_base_url,
+            deepthought=deepthought,
+            tracer=tracer,
+        ):
+            yield chunk
+    else:
+        # For non-OpenAI APIs, use the chat completion method
+        async for chunk in chat_completion_with_backoff(
+            messages=messages,
+            model_name=model,
+            temperature=temperature,
+            openai_api_key=api_key,
+            api_base_url=api_base_url,
+            deepthought=deepthought,
+            tracer=tracer,
+        ):
+            yield chunk
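The routing above sends OpenAI-hosted models through the new Responses API helpers while other OpenAI-compatible servers stay on Chat Completions, and the two APIs wrap structured-output schemas differently. Below is a minimal sketch of that payload difference, mirroring the key names used in the patch; it is illustrative only and not part of the diff:

```python
from typing import Any, Dict


def build_schema_kwargs(schema: Dict[str, Any], name: str, use_responses_api: bool) -> Dict[str, Any]:
    """Wrap a cleaned JSON schema the way each OpenAI API variant expects it."""
    if use_responses_api:
        # Responses API: schema nested under text.format
        return {"text": {"format": {"type": "json_schema", "strict": True, "name": name, "schema": schema}}}
    # Chat Completions API: schema nested under response_format.json_schema
    return {"response_format": {"type": "json_schema", "json_schema": {"schema": schema, "name": name, "strict": True}}}


# Same schema, two payload shapes
example_schema = {"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}
print(build_schema_kwargs(example_schema, "ExampleResponse", use_responses_api=True))
print(build_schema_kwargs(example_schema, "ExampleResponse", use_responses_api=False))
```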
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 6340d476..27b5d9cd 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -21,6 +21,8 @@ from openai.types.chat.chat_completion_chunk import (
     Choice,
     ChoiceDelta,
 )
+from openai.types.responses import Response as OpenAIResponse
+from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
 from pydantic import BaseModel
 from tenacity import (
     before_sleep_log,
@@ -53,6 +55,26 @@ openai_clients: Dict[str, openai.OpenAI] = {}
 openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}


+def _extract_text_for_instructions(content: Union[str, List, Dict, None]) -> str:
+    """Extract plain text from a message content suitable for Responses API instructions."""
+    if content is None:
+        return ""
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        texts: List[str] = []
+        for part in content:
+            if isinstance(part, dict) and part.get("type") == "input_text" and part.get("text"):
+                texts.append(str(part.get("text")))
+        return "\n\n".join(texts)
+    if isinstance(content, dict):
+        # If a single part dict was passed
+        if content.get("type") == "input_text" and content.get("text"):
+            return str(content.get("text"))
+    # Fallback to string conversion
+    return str(content)
+
+
 @retry(
     retry=(
         retry_if_exception_type(openai._exceptions.APITimeoutError)
@@ -390,6 +412,287 @@ async def chat_completion_with_backoff(
     commit_conversation_trace(messages, aggregated_response, tracer)


+@retry(
+    retry=(
+        retry_if_exception_type(openai._exceptions.APITimeoutError)
+        | retry_if_exception_type(openai._exceptions.APIError)
+        | retry_if_exception_type(openai._exceptions.APIConnectionError)
+        | retry_if_exception_type(openai._exceptions.RateLimitError)
+        | retry_if_exception_type(openai._exceptions.APIStatusError)
+        | retry_if_exception_type(ValueError)
+    ),
+    wait=wait_random_exponential(min=1, max=10),
+    stop=stop_after_attempt(3),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def responses_completion_with_backoff(
+    messages: List[ChatMessage],
+    model_name: str,
+    temperature=0.6,
+    openai_api_key=None,
+    api_base_url=None,
+    deepthought: bool = False,
+    model_kwargs: dict = {},
+    tracer: dict = {},
+) -> ResponseWithThought:
+    """
+    Synchronous helper using the OpenAI Responses API.
+    Aggregates the response output and returns a ResponseWithThought.
+ """ + client_key = f"{openai_api_key}--{api_base_url}" + client = openai_clients.get(client_key) + if not client: + client = get_openai_client(openai_api_key, api_base_url) + openai_clients[client_key] = client + + formatted_messages = format_message_for_api(messages, api_base_url) + # Move the first system message to Responses API instructions + instructions: Optional[str] = None + if formatted_messages and formatted_messages[0].get("role") == "system": + instructions = _extract_text_for_instructions(formatted_messages[0].get("content")) or None + formatted_messages = formatted_messages[1:] + + model_kwargs = deepcopy(model_kwargs) + model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95) + # Configure thinking for openai reasoning models + if is_openai_reasoning_model(model_name, api_base_url): + temperature = 1 + reasoning_effort = "medium" if deepthought else "low" + model_kwargs["reasoning"] = {"effort": reasoning_effort, "summary": "auto"} + # Remove unsupported params for reasoning models + model_kwargs.pop("top_p", None) + model_kwargs.pop("stop", None) + + read_timeout = 300 if is_local_api(api_base_url) else 60 + + # Stream and aggregate + model_response: OpenAIResponse = client.responses.create( + input=formatted_messages, + instructions=instructions, + model=model_name, + temperature=temperature, + timeout=httpx.Timeout(30, read=read_timeout), # type: ignore + store=False, + include=["reasoning.encrypted_content"], + **model_kwargs, + ) + if not model_response or not isinstance(model_response, OpenAIResponse) or not model_response.output: + raise ValueError(f"Empty response returned by {model_name}.") + + raw_content = [item.model_dump() for item in model_response.output] + aggregated_text = model_response.output_text + thoughts = "" + tool_calls: List[ToolCall] = [] + for item in model_response.output: + if isinstance(item, ResponseFunctionToolCall): + tool_calls.append(ToolCall(name=item.name, args=json.loads(item.arguments), id=item.call_id)) + elif isinstance(item, ResponseReasoningItem): + thoughts = "\n\n".join([summary.text for summary in item.summary]) + + if tool_calls: + if thoughts and aggregated_text: + # If there are tool calls, aggregate thoughts and responses into thoughts + thoughts = "\n".join([f"*{line.strip()}*" for line in thoughts.splitlines() if line.strip()]) + thoughts = f"{thoughts}\n\n{aggregated_text}" + else: + thoughts = thoughts or aggregated_text + # Json dump tool calls into aggregated response + aggregated_text = json.dumps([tool_call.__dict__ for tool_call in tool_calls]) + + # Usage/cost tracking + input_tokens = model_response.usage.input_tokens if model_response and model_response.usage else 0 + output_tokens = model_response.usage.output_tokens if model_response and model_response.usage else 0 + cost = 0 + cache_read_tokens = 0 + if model_response and model_response.usage and model_response.usage.input_tokens_details: + cache_read_tokens = model_response.usage.input_tokens_details.cached_tokens + input_tokens -= cache_read_tokens + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, cache_read_tokens, usage=tracer.get("usage"), cost=cost + ) + + # Validate final aggregated text (either message or tool-calls JSON) + if is_none_or_empty(aggregated_text): + logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.") + raise ValueError(f"Empty or no response by {model_name} over Responses API. 
Retry if needed.") + + # Trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if is_promptrace_enabled(): + commit_conversation_trace(messages, aggregated_text, tracer) + + return ResponseWithThought(text=aggregated_text, thought=thoughts, raw_content=raw_content) + + +@retry( + retry=( + retry_if_exception_type(openai._exceptions.APITimeoutError) + | retry_if_exception_type(openai._exceptions.APIError) + | retry_if_exception_type(openai._exceptions.APIConnectionError) + | retry_if_exception_type(openai._exceptions.RateLimitError) + | retry_if_exception_type(openai._exceptions.APIStatusError) + | retry_if_exception_type(ValueError) + ), + wait=wait_exponential(multiplier=1, min=4, max=10), + stop=stop_after_attempt(3), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=False, +) +async def responses_chat_completion_with_backoff( + messages: list[ChatMessage], + model_name: str, + temperature, + openai_api_key=None, + api_base_url=None, + deepthought=False, # Unused; parity with legacy signature + tracer: dict = {}, +) -> AsyncGenerator[ResponseWithThought, None]: + """ + Async streaming helper using the OpenAI Responses API. + Yields ResponseWithThought chunks as text/think deltas arrive. + """ + client_key = f"{openai_api_key}--{api_base_url}" + client = openai_async_clients.get(client_key) + if not client: + client = get_openai_async_client(openai_api_key, api_base_url) + openai_async_clients[client_key] = client + + formatted_messages = format_message_for_api(messages, api_base_url) + # Move the first system message to Responses API instructions + instructions: Optional[str] = None + if formatted_messages and formatted_messages[0].get("role") == "system": + instructions = _extract_text_for_instructions(formatted_messages[0].get("content")) or None + formatted_messages = formatted_messages[1:] + + model_kwargs: dict = {} + model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95) + # Configure thinking for openai reasoning models + if is_openai_reasoning_model(model_name, api_base_url): + temperature = 1 + reasoning_effort = "medium" if deepthought else "low" + model_kwargs["reasoning"] = {"effort": reasoning_effort, "summary": "auto"} + # Remove unsupported params for reasoning models + model_kwargs.pop("top_p", None) + model_kwargs.pop("stop", None) + + read_timeout = 300 if is_local_api(api_base_url) else 60 + + aggregated_text = "" + last_final: Optional[OpenAIResponse] = None + # Tool call assembly buffers + tool_calls_args: Dict[str, str] = {} + tool_calls_name: Dict[str, str] = {} + tool_call_order: List[str] = [] + + async with client.responses.stream( + input=formatted_messages, + instructions=instructions, + model=model_name, + temperature=temperature, + timeout=httpx.Timeout(30, read=read_timeout), + **model_kwargs, + ) as stream: # type: ignore + async for event in stream: # type: ignore + et = getattr(event, "type", "") + if et == "response.output_text.delta": + delta = getattr(event, "delta", "") or getattr(event, "output_text", "") + if delta: + aggregated_text += delta + yield ResponseWithThought(text=delta) + elif et == "response.reasoning.delta": + delta = getattr(event, "delta", "") + if delta: + yield ResponseWithThought(thought=delta) + elif et == "response.tool_call.created": + item = getattr(event, "item", None) + tool_id = ( + getattr(event, "id", None) + or getattr(event, "tool_call_id", None) + or (getattr(item, "id", None) if item is not None else None) + ) + name = ( + getattr(event, "name", None) + or (getattr(item, 
"name", None) if item is not None else None) + or getattr(event, "tool_name", None) + ) + if tool_id: + if tool_id not in tool_calls_args: + tool_calls_args[tool_id] = "" + tool_call_order.append(tool_id) + if name: + tool_calls_name[tool_id] = name + elif et == "response.tool_call.delta": + tool_id = getattr(event, "id", None) or getattr(event, "tool_call_id", None) + delta = getattr(event, "delta", None) + if hasattr(delta, "arguments"): + arg_delta = getattr(delta, "arguments", "") + else: + arg_delta = delta if isinstance(delta, str) else getattr(event, "arguments", "") + if tool_id and arg_delta: + tool_calls_args[tool_id] = tool_calls_args.get(tool_id, "") + arg_delta + if tool_id not in tool_call_order: + tool_call_order.append(tool_id) + elif et == "response.tool_call.completed": + item = getattr(event, "item", None) + tool_id = ( + getattr(event, "id", None) + or getattr(event, "tool_call_id", None) + or (getattr(item, "id", None) if item is not None else None) + ) + args_final = None + if item is not None: + args_final = getattr(item, "arguments", None) or getattr(item, "args", None) + if tool_id and args_final: + tool_calls_args[tool_id] = args_final if isinstance(args_final, str) else json.dumps(args_final) + if tool_id not in tool_call_order: + tool_call_order.append(tool_id) + # ignore other events for now + last_final = await stream.get_final_response() + + # Usage/cost tracking after stream ends + input_tokens = last_final.usage.input_tokens if last_final and last_final.usage else 0 + output_tokens = last_final.usage.output_tokens if last_final and last_final.usage else 0 + cost = 0 + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost + ) + + # If there are tool calls, package them into aggregated text for tracing parity + if tool_call_order: + packaged_tool_calls: List[ToolCall] = [] + for tool_id in tool_call_order: + name = tool_calls_name.get(tool_id) or "" + args_str = tool_calls_args.get(tool_id, "") + try: + args = json.loads(args_str) if isinstance(args_str, str) else args_str + except Exception: + logger.warning(f"Failed to parse tool call arguments for {tool_id}: {args_str}") + args = {} + packaged_tool_calls.append(ToolCall(name=name, args=args, id=tool_id)) + # Move any text into trace thought + tracer_text = aggregated_text + aggregated_text = json.dumps([tc.__dict__ for tc in packaged_tool_calls]) + # Save for trace below + if tracer_text: + tracer.setdefault("_responses_stream_text", tracer_text) + + if is_none_or_empty(aggregated_text): + logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.") + raise ValueError(f"Empty or no response by {model_name} over Responses API. 
Retry if needed.") + + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if is_promptrace_enabled(): + # If tool-calls were present, include any streamed text in the trace thought + trace_payload = aggregated_text + if tracer.get("_responses_stream_text"): + thoughts = tracer.pop("_responses_stream_text") + trace_payload = thoughts + commit_conversation_trace(messages, trace_payload, tracer) + + def get_structured_output_support(model_name: str, api_base_url: str = None) -> StructuredOutputSupport: if model_name.startswith("deepseek-reasoner"): return StructuredOutputSupport.NONE @@ -412,6 +715,12 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) - # Handle tool call and tool result message types message_type = message.additional_kwargs.get("message_type") if message_type == "tool_call": + if is_openai_api(api_base_url): + for part in message.content: + if "status" in part: + part.pop("status") # Drop unsupported tool call status field + formatted_messages.extend(message.content) + continue # Convert tool_call to OpenAI function call format content = [] for part in message.content: @@ -450,14 +759,23 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) - if not tool_call_id: logger.warning(f"Dropping tool result without valid tool_call_id: {part.get('name')}") continue - formatted_messages.append( - { - "role": "tool", - "tool_call_id": tool_call_id, - "name": part.get("name"), - "content": part.get("content"), - } - ) + if is_openai_api(api_base_url): + formatted_messages.append( + { + "type": "function_call_output", + "call_id": tool_call_id, + "output": part.get("content"), + } + ) + else: + formatted_messages.append( + { + "role": "tool", + "tool_call_id": tool_call_id, + "name": part.get("name"), + "content": part.get("content"), + } + ) continue if isinstance(message.content, list) and not is_openai_api(api_base_url): assistant_texts = [] @@ -489,6 +807,11 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) - message.content.remove(part) elif part["type"] == "image_url" and not part.get("image_url"): message.content.remove(part) + # OpenAI models use the Responses API which uses slightly different content types + if part["type"] == "text": + part["type"] = "output_text" if message.role == "assistant" else "input_text" + if part["type"] == "image": + part["type"] = "output_image" if message.role == "assistant" else "input_image" # If no valid content parts left, remove the message if is_none_or_empty(message.content): messages.remove(message) @@ -852,20 +1175,32 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None: break -def to_openai_tools(tools: List[ToolDefinition]) -> List[Dict] | None: +def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> List[Dict] | None: "Transform tool definitions from standard format to OpenAI format." 
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 1b94ba28..63cf2bf5 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -68,6 +68,9 @@ model_to_prompt_size = {
     "o3": 60000,
     "o3-pro": 30000,
     "o4-mini": 90000,
+    "gpt-5-2025-08-07": 120000,
+    "gpt-5-mini-2025-08-07": 120000,
+    "gpt-5-nano-2025-08-07": 120000,
     # Google Models
     "gemini-2.5-flash": 120000,
     "gemini-2.5-pro": 60000,
diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py
index b2ed49de..d0b52f8b 100644
--- a/src/khoj/utils/constants.py
+++ b/src/khoj/utils/constants.py
@@ -40,6 +40,9 @@ model_to_cost: Dict[str, Dict[str, float]] = {
     "o3": {"input": 2.0, "output": 8.00},
     "o3-pro": {"input": 20.0, "output": 80.00},
     "o4-mini": {"input": 1.10, "output": 4.40},
+    "gpt-5-2025-08-07": {"input": 1.25, "output": 10.00, "cache_read": 0.125},
+    "gpt-5-mini-2025-08-07": {"input": 0.25, "output": 2.00, "cache_read": 0.025},
+    "gpt-5-nano-2025-08-07": {"input": 0.05, "output": 0.40, "cache_read": 0.005},
     # Gemini Pricing: https://ai.google.dev/pricing
     "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
     "gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},
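For reference, a reduced sketch of the Responses API call shape that `responses_completion_with_backoff` now issues. This is illustrative, not the patch's code: it assumes `OPENAI_API_KEY` is set in the environment and uses the `gpt-5-mini-2025-08-07` model added to the tables above. The `store=False` plus `include=["reasoning.encrypted_content"]` pairing mirrors the patch, which lets reasoning state round-trip without server-side storage; `temperature` is omitted here since the patch pins it to 1 for reasoning models anyway.

```python
import openai

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.responses.create(
    model="gpt-5-mini-2025-08-07",
    instructions="You are a helpful assistant.",  # first system message moves here in the patch
    input=[{"role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
    reasoning={"effort": "low", "summary": "auto"},  # set for reasoning models in the patch
    store=False,
    include=["reasoning.encrypted_content"],
)
print(response.output_text)  # convenience accessor the sync helper aggregates from
```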