From b2d26088dcfd42b9696de863456cec45c1498ca2 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Sat, 9 Aug 2025 00:10:34 -0700 Subject: [PATCH] Use openai responses api to interact with official openai models What - Get reasoning of openai reasoning models from responses api for sho - Improves cache hits and reasoning reuse for iterative agents like research mode. This should improve speed, quality, cost and transparency of using openai reasoning models. More cache hits and better reasoning as reasoning blocks are included while model is researching (reasoning intersperse with tool calls) when using the responses api. --- src/khoj/processor/conversation/openai/gpt.py | 93 +++-- .../processor/conversation/openai/utils.py | 369 +++++++++++++++++- src/khoj/processor/conversation/utils.py | 3 + src/khoj/utils/constants.py | 3 + 4 files changed, 423 insertions(+), 45 deletions(-) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index e68ebd20..5127a027 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -9,6 +9,9 @@ from khoj.processor.conversation.openai.utils import ( clean_response_schema, completion_with_backoff, get_structured_output_support, + is_openai_api, + responses_chat_completion_with_backoff, + responses_completion_with_backoff, to_openai_tools, ) from khoj.processor.conversation.utils import ( @@ -43,31 +46,52 @@ def send_message_to_model( model_kwargs: Dict[str, Any] = {} json_support = get_structured_output_support(model, api_base_url) if tools and json_support == StructuredOutputSupport.TOOL: - model_kwargs["tools"] = to_openai_tools(tools) + model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=is_openai_api(api_base_url)) elif response_schema and json_support >= StructuredOutputSupport.SCHEMA: # Drop unsupported fields from schema passed to OpenAI APi cleaned_response_schema = clean_response_schema(response_schema) - model_kwargs["response_format"] = { - "type": "json_schema", - "json_schema": { - "schema": cleaned_response_schema, - "name": response_schema.__name__, - "strict": True, - }, - } + if is_openai_api(api_base_url): + model_kwargs["text"] = { + "format": { + "type": "json_schema", + "strict": True, + "name": response_schema.__name__, + "schema": cleaned_response_schema, + } + } + else: + model_kwargs["response_format"] = { + "type": "json_schema", + "json_schema": { + "schema": cleaned_response_schema, + "name": response_schema.__name__, + "strict": True, + }, + } elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT: model_kwargs["response_format"] = {"type": response_type} # Get Response from GPT - return completion_with_backoff( - messages=messages, - model_name=model, - openai_api_key=api_key, - api_base_url=api_base_url, - deepthought=deepthought, - model_kwargs=model_kwargs, - tracer=tracer, - ) + if is_openai_api(api_base_url): + return responses_completion_with_backoff( + messages=messages, + model_name=model, + openai_api_key=api_key, + api_base_url=api_base_url, + deepthought=deepthought, + model_kwargs=model_kwargs, + tracer=tracer, + ) + else: + return completion_with_backoff( + messages=messages, + model_name=model, + openai_api_key=api_key, + api_base_url=api_base_url, + deepthought=deepthought, + model_kwargs=model_kwargs, + tracer=tracer, + ) async def converse_openai( @@ -163,13 +187,26 @@ async def converse_openai( logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}") # Get Response from GPT - async for chunk in chat_completion_with_backoff( - messages=messages, - model_name=model, - temperature=temperature, - openai_api_key=api_key, - api_base_url=api_base_url, - deepthought=deepthought, - tracer=tracer, - ): - yield chunk + if is_openai_api(api_base_url): + async for chunk in responses_chat_completion_with_backoff( + messages=messages, + model_name=model, + temperature=temperature, + openai_api_key=api_key, + api_base_url=api_base_url, + deepthought=deepthought, + tracer=tracer, + ): + yield chunk + else: + # For non-OpenAI APIs, use the chat completion method + async for chunk in chat_completion_with_backoff( + messages=messages, + model_name=model, + temperature=temperature, + openai_api_key=api_key, + api_base_url=api_base_url, + deepthought=deepthought, + tracer=tracer, + ): + yield chunk diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 6340d476..27b5d9cd 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -21,6 +21,8 @@ from openai.types.chat.chat_completion_chunk import ( Choice, ChoiceDelta, ) +from openai.types.responses import Response as OpenAIResponse +from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem from pydantic import BaseModel from tenacity import ( before_sleep_log, @@ -53,6 +55,26 @@ openai_clients: Dict[str, openai.OpenAI] = {} openai_async_clients: Dict[str, openai.AsyncOpenAI] = {} +def _extract_text_for_instructions(content: Union[str, List, Dict, None]) -> str: + """Extract plain text from a message content suitable for Responses API instructions.""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + texts: List[str] = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "input_text" and part.get("text"): + texts.append(str(part.get("text"))) + return "\n\n".join(texts) + if isinstance(content, dict): + # If a single part dict was passed + if content.get("type") == "input_text" and content.get("text"): + return str(content.get("text")) + # Fallback to string conversion + return str(content) + + @retry( retry=( retry_if_exception_type(openai._exceptions.APITimeoutError) @@ -390,6 +412,287 @@ async def chat_completion_with_backoff( commit_conversation_trace(messages, aggregated_response, tracer) +@retry( + retry=( + retry_if_exception_type(openai._exceptions.APITimeoutError) + | retry_if_exception_type(openai._exceptions.APIError) + | retry_if_exception_type(openai._exceptions.APIConnectionError) + | retry_if_exception_type(openai._exceptions.RateLimitError) + | retry_if_exception_type(openai._exceptions.APIStatusError) + | retry_if_exception_type(ValueError) + ), + wait=wait_random_exponential(min=1, max=10), + stop=stop_after_attempt(3), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) +def responses_completion_with_backoff( + messages: List[ChatMessage], + model_name: str, + temperature=0.6, + openai_api_key=None, + api_base_url=None, + deepthought: bool = False, + model_kwargs: dict = {}, + tracer: dict = {}, +) -> ResponseWithThought: + """ + Synchronous helper using the OpenAI Responses API in streaming mode under the hood. + Aggregates streamed deltas and returns a ResponseWithThought. + """ + client_key = f"{openai_api_key}--{api_base_url}" + client = openai_clients.get(client_key) + if not client: + client = get_openai_client(openai_api_key, api_base_url) + openai_clients[client_key] = client + + formatted_messages = format_message_for_api(messages, api_base_url) + # Move the first system message to Responses API instructions + instructions: Optional[str] = None + if formatted_messages and formatted_messages[0].get("role") == "system": + instructions = _extract_text_for_instructions(formatted_messages[0].get("content")) or None + formatted_messages = formatted_messages[1:] + + model_kwargs = deepcopy(model_kwargs) + model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95) + # Configure thinking for openai reasoning models + if is_openai_reasoning_model(model_name, api_base_url): + temperature = 1 + reasoning_effort = "medium" if deepthought else "low" + model_kwargs["reasoning"] = {"effort": reasoning_effort, "summary": "auto"} + # Remove unsupported params for reasoning models + model_kwargs.pop("top_p", None) + model_kwargs.pop("stop", None) + + read_timeout = 300 if is_local_api(api_base_url) else 60 + + # Stream and aggregate + model_response: OpenAIResponse = client.responses.create( + input=formatted_messages, + instructions=instructions, + model=model_name, + temperature=temperature, + timeout=httpx.Timeout(30, read=read_timeout), # type: ignore + store=False, + include=["reasoning.encrypted_content"], + **model_kwargs, + ) + if not model_response or not isinstance(model_response, OpenAIResponse) or not model_response.output: + raise ValueError(f"Empty response returned by {model_name}.") + + raw_content = [item.model_dump() for item in model_response.output] + aggregated_text = model_response.output_text + thoughts = "" + tool_calls: List[ToolCall] = [] + for item in model_response.output: + if isinstance(item, ResponseFunctionToolCall): + tool_calls.append(ToolCall(name=item.name, args=json.loads(item.arguments), id=item.call_id)) + elif isinstance(item, ResponseReasoningItem): + thoughts = "\n\n".join([summary.text for summary in item.summary]) + + if tool_calls: + if thoughts and aggregated_text: + # If there are tool calls, aggregate thoughts and responses into thoughts + thoughts = "\n".join([f"*{line.strip()}*" for line in thoughts.splitlines() if line.strip()]) + thoughts = f"{thoughts}\n\n{aggregated_text}" + else: + thoughts = thoughts or aggregated_text + # Json dump tool calls into aggregated response + aggregated_text = json.dumps([tool_call.__dict__ for tool_call in tool_calls]) + + # Usage/cost tracking + input_tokens = model_response.usage.input_tokens if model_response and model_response.usage else 0 + output_tokens = model_response.usage.output_tokens if model_response and model_response.usage else 0 + cost = 0 + cache_read_tokens = 0 + if model_response and model_response.usage and model_response.usage.input_tokens_details: + cache_read_tokens = model_response.usage.input_tokens_details.cached_tokens + input_tokens -= cache_read_tokens + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, cache_read_tokens, usage=tracer.get("usage"), cost=cost + ) + + # Validate final aggregated text (either message or tool-calls JSON) + if is_none_or_empty(aggregated_text): + logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.") + raise ValueError(f"Empty or no response by {model_name} over Responses API. Retry if needed.") + + # Trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if is_promptrace_enabled(): + commit_conversation_trace(messages, aggregated_text, tracer) + + return ResponseWithThought(text=aggregated_text, thought=thoughts, raw_content=raw_content) + + +@retry( + retry=( + retry_if_exception_type(openai._exceptions.APITimeoutError) + | retry_if_exception_type(openai._exceptions.APIError) + | retry_if_exception_type(openai._exceptions.APIConnectionError) + | retry_if_exception_type(openai._exceptions.RateLimitError) + | retry_if_exception_type(openai._exceptions.APIStatusError) + | retry_if_exception_type(ValueError) + ), + wait=wait_exponential(multiplier=1, min=4, max=10), + stop=stop_after_attempt(3), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=False, +) +async def responses_chat_completion_with_backoff( + messages: list[ChatMessage], + model_name: str, + temperature, + openai_api_key=None, + api_base_url=None, + deepthought=False, # Unused; parity with legacy signature + tracer: dict = {}, +) -> AsyncGenerator[ResponseWithThought, None]: + """ + Async streaming helper using the OpenAI Responses API. + Yields ResponseWithThought chunks as text/think deltas arrive. + """ + client_key = f"{openai_api_key}--{api_base_url}" + client = openai_async_clients.get(client_key) + if not client: + client = get_openai_async_client(openai_api_key, api_base_url) + openai_async_clients[client_key] = client + + formatted_messages = format_message_for_api(messages, api_base_url) + # Move the first system message to Responses API instructions + instructions: Optional[str] = None + if formatted_messages and formatted_messages[0].get("role") == "system": + instructions = _extract_text_for_instructions(formatted_messages[0].get("content")) or None + formatted_messages = formatted_messages[1:] + + model_kwargs: dict = {} + model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95) + # Configure thinking for openai reasoning models + if is_openai_reasoning_model(model_name, api_base_url): + temperature = 1 + reasoning_effort = "medium" if deepthought else "low" + model_kwargs["reasoning"] = {"effort": reasoning_effort, "summary": "auto"} + # Remove unsupported params for reasoning models + model_kwargs.pop("top_p", None) + model_kwargs.pop("stop", None) + + read_timeout = 300 if is_local_api(api_base_url) else 60 + + aggregated_text = "" + last_final: Optional[OpenAIResponse] = None + # Tool call assembly buffers + tool_calls_args: Dict[str, str] = {} + tool_calls_name: Dict[str, str] = {} + tool_call_order: List[str] = [] + + async with client.responses.stream( + input=formatted_messages, + instructions=instructions, + model=model_name, + temperature=temperature, + timeout=httpx.Timeout(30, read=read_timeout), + **model_kwargs, + ) as stream: # type: ignore + async for event in stream: # type: ignore + et = getattr(event, "type", "") + if et == "response.output_text.delta": + delta = getattr(event, "delta", "") or getattr(event, "output_text", "") + if delta: + aggregated_text += delta + yield ResponseWithThought(text=delta) + elif et == "response.reasoning.delta": + delta = getattr(event, "delta", "") + if delta: + yield ResponseWithThought(thought=delta) + elif et == "response.tool_call.created": + item = getattr(event, "item", None) + tool_id = ( + getattr(event, "id", None) + or getattr(event, "tool_call_id", None) + or (getattr(item, "id", None) if item is not None else None) + ) + name = ( + getattr(event, "name", None) + or (getattr(item, "name", None) if item is not None else None) + or getattr(event, "tool_name", None) + ) + if tool_id: + if tool_id not in tool_calls_args: + tool_calls_args[tool_id] = "" + tool_call_order.append(tool_id) + if name: + tool_calls_name[tool_id] = name + elif et == "response.tool_call.delta": + tool_id = getattr(event, "id", None) or getattr(event, "tool_call_id", None) + delta = getattr(event, "delta", None) + if hasattr(delta, "arguments"): + arg_delta = getattr(delta, "arguments", "") + else: + arg_delta = delta if isinstance(delta, str) else getattr(event, "arguments", "") + if tool_id and arg_delta: + tool_calls_args[tool_id] = tool_calls_args.get(tool_id, "") + arg_delta + if tool_id not in tool_call_order: + tool_call_order.append(tool_id) + elif et == "response.tool_call.completed": + item = getattr(event, "item", None) + tool_id = ( + getattr(event, "id", None) + or getattr(event, "tool_call_id", None) + or (getattr(item, "id", None) if item is not None else None) + ) + args_final = None + if item is not None: + args_final = getattr(item, "arguments", None) or getattr(item, "args", None) + if tool_id and args_final: + tool_calls_args[tool_id] = args_final if isinstance(args_final, str) else json.dumps(args_final) + if tool_id not in tool_call_order: + tool_call_order.append(tool_id) + # ignore other events for now + last_final = await stream.get_final_response() + + # Usage/cost tracking after stream ends + input_tokens = last_final.usage.input_tokens if last_final and last_final.usage else 0 + output_tokens = last_final.usage.output_tokens if last_final and last_final.usage else 0 + cost = 0 + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost + ) + + # If there are tool calls, package them into aggregated text for tracing parity + if tool_call_order: + packaged_tool_calls: List[ToolCall] = [] + for tool_id in tool_call_order: + name = tool_calls_name.get(tool_id) or "" + args_str = tool_calls_args.get(tool_id, "") + try: + args = json.loads(args_str) if isinstance(args_str, str) else args_str + except Exception: + logger.warning(f"Failed to parse tool call arguments for {tool_id}: {args_str}") + args = {} + packaged_tool_calls.append(ToolCall(name=name, args=args, id=tool_id)) + # Move any text into trace thought + tracer_text = aggregated_text + aggregated_text = json.dumps([tc.__dict__ for tc in packaged_tool_calls]) + # Save for trace below + if tracer_text: + tracer.setdefault("_responses_stream_text", tracer_text) + + if is_none_or_empty(aggregated_text): + logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.") + raise ValueError(f"Empty or no response by {model_name} over Responses API. Retry if needed.") + + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if is_promptrace_enabled(): + # If tool-calls were present, include any streamed text in the trace thought + trace_payload = aggregated_text + if tracer.get("_responses_stream_text"): + thoughts = tracer.pop("_responses_stream_text") + trace_payload = thoughts + commit_conversation_trace(messages, trace_payload, tracer) + + def get_structured_output_support(model_name: str, api_base_url: str = None) -> StructuredOutputSupport: if model_name.startswith("deepseek-reasoner"): return StructuredOutputSupport.NONE @@ -412,6 +715,12 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) - # Handle tool call and tool result message types message_type = message.additional_kwargs.get("message_type") if message_type == "tool_call": + if is_openai_api(api_base_url): + for part in message.content: + if "status" in part: + part.pop("status") # Drop unsupported tool call status field + formatted_messages.extend(message.content) + continue # Convert tool_call to OpenAI function call format content = [] for part in message.content: @@ -450,14 +759,23 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) - if not tool_call_id: logger.warning(f"Dropping tool result without valid tool_call_id: {part.get('name')}") continue - formatted_messages.append( - { - "role": "tool", - "tool_call_id": tool_call_id, - "name": part.get("name"), - "content": part.get("content"), - } - ) + if is_openai_api(api_base_url): + formatted_messages.append( + { + "type": "function_call_output", + "call_id": tool_call_id, + "output": part.get("content"), + } + ) + else: + formatted_messages.append( + { + "role": "tool", + "tool_call_id": tool_call_id, + "name": part.get("name"), + "content": part.get("content"), + } + ) continue if isinstance(message.content, list) and not is_openai_api(api_base_url): assistant_texts = [] @@ -489,6 +807,11 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) - message.content.remove(part) elif part["type"] == "image_url" and not part.get("image_url"): message.content.remove(part) + # OpenAI models use the Responses API which uses slightly different content types + if part["type"] == "text": + part["type"] = "output_text" if message.role == "assistant" else "input_text" + if part["type"] == "image": + part["type"] = "output_image" if message.role == "assistant" else "input_image" # If no valid content parts left, remove the message if is_none_or_empty(message.content): messages.remove(message) @@ -852,20 +1175,32 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None: break -def to_openai_tools(tools: List[ToolDefinition]) -> List[Dict] | None: +def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> List[Dict] | None: "Transform tool definitions from standard format to OpenAI format." - openai_tools = [ - { - "type": "function", - "function": { + if use_responses_api: + openai_tools = [ + { + "type": "function", "name": tool.name, "description": tool.description, "parameters": clean_response_schema(tool.schema), "strict": True, - }, - } - for tool in tools - ] + } + for tool in tools + ] + else: + openai_tools = [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": clean_response_schema(tool.schema), + "strict": True, + }, + } + for tool in tools + ] return openai_tools or None diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 1b94ba28..63cf2bf5 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -68,6 +68,9 @@ model_to_prompt_size = { "o3": 60000, "o3-pro": 30000, "o4-mini": 90000, + "gpt-5-2025-08-07": 120000, + "gpt-5-mini-2025-08-07": 120000, + "gpt-5-nano-2025-08-07": 120000, # Google Models "gemini-2.5-flash": 120000, "gemini-2.5-pro": 60000, diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index b2ed49de..d0b52f8b 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -40,6 +40,9 @@ model_to_cost: Dict[str, Dict[str, float]] = { "o3": {"input": 2.0, "output": 8.00}, "o3-pro": {"input": 20.0, "output": 80.00}, "o4-mini": {"input": 1.10, "output": 4.40}, + "gpt-5-2025-08-07": {"input": 1.25, "output": 10.00, "cache_read": 0.125}, + "gpt-5-mini-2025-08-07": {"input": 0.25, "output": 2.00, "cache_read": 0.025}, + "gpt-5-nano-2025-08-07": {"input": 0.05, "output": 0.40, "cache_read": 0.005}, # Gemini Pricing: https://ai.google.dev/pricing "gemini-1.5-flash": {"input": 0.075, "output": 0.30}, "gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},