Use OpenAI Responses API to interact with official OpenAI models

What
- Get the reasoning of OpenAI reasoning models from the Responses API for
  display.
- Improve cache hits and reasoning reuse for iterative agents like
  research mode.

This should improve the speed, quality, cost, and transparency of using
OpenAI reasoning models.

Cache hits increase and reasoning quality improves because reasoning
blocks are carried along while the model is researching (reasoning
interspersed with tool calls) when using the Responses API.
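
For context, a minimal sketch of the Responses API call shape this change
adopts (the `openai` Python SDK is assumed; the model name, prompt, and
effort values are illustrative, while the `reasoning`, `store`, and
`include` arguments mirror the diff below):

    import openai

    client = openai.OpenAI()  # assumes OPENAI_API_KEY is set in the environment
    response = client.responses.create(
        model="o4-mini",  # any OpenAI reasoning model
        instructions="You are a research assistant.",
        input=[{"role": "user", "content": "Summarize the findings so far."}],
        reasoning={"effort": "low", "summary": "auto"},  # surface reasoning summaries
        store=False,  # stateless mode, so reasoning must round-trip with the client
        include=["reasoning.encrypted_content"],  # reusable, encrypted reasoning blocks
    )
    print(response.output_text)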
Debanjum
2025-08-09 00:10:34 -07:00
parent 564adb24a7
commit b2d26088dc
4 changed files with 423 additions and 45 deletions


@@ -9,6 +9,9 @@ from khoj.processor.conversation.openai.utils import (
     clean_response_schema,
     completion_with_backoff,
     get_structured_output_support,
+    is_openai_api,
+    responses_chat_completion_with_backoff,
+    responses_completion_with_backoff,
     to_openai_tools,
 )
 from khoj.processor.conversation.utils import (
from khoj.processor.conversation.utils import (
@@ -43,31 +46,52 @@ def send_message_to_model(
     model_kwargs: Dict[str, Any] = {}
     json_support = get_structured_output_support(model, api_base_url)
     if tools and json_support == StructuredOutputSupport.TOOL:
-        model_kwargs["tools"] = to_openai_tools(tools)
+        model_kwargs["tools"] = to_openai_tools(tools, use_responses_api=is_openai_api(api_base_url))
     elif response_schema and json_support >= StructuredOutputSupport.SCHEMA:
         # Drop unsupported fields from schema passed to OpenAI APi
         cleaned_response_schema = clean_response_schema(response_schema)
-        model_kwargs["response_format"] = {
-            "type": "json_schema",
-            "json_schema": {
-                "schema": cleaned_response_schema,
-                "name": response_schema.__name__,
-                "strict": True,
-            },
-        }
+        if is_openai_api(api_base_url):
+            model_kwargs["text"] = {
+                "format": {
+                    "type": "json_schema",
+                    "strict": True,
+                    "name": response_schema.__name__,
+                    "schema": cleaned_response_schema,
+                }
+            }
+        else:
+            model_kwargs["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "schema": cleaned_response_schema,
+                    "name": response_schema.__name__,
+                    "strict": True,
+                },
+            }
     elif response_type == "json_object" and json_support == StructuredOutputSupport.OBJECT:
         model_kwargs["response_format"] = {"type": response_type}
 
     # Get Response from GPT
-    return completion_with_backoff(
-        messages=messages,
-        model_name=model,
-        openai_api_key=api_key,
-        api_base_url=api_base_url,
-        deepthought=deepthought,
-        model_kwargs=model_kwargs,
-        tracer=tracer,
-    )
+    if is_openai_api(api_base_url):
+        return responses_completion_with_backoff(
+            messages=messages,
+            model_name=model,
+            openai_api_key=api_key,
+            api_base_url=api_base_url,
+            deepthought=deepthought,
+            model_kwargs=model_kwargs,
+            tracer=tracer,
+        )
+    else:
+        return completion_with_backoff(
+            messages=messages,
+            model_name=model,
+            openai_api_key=api_key,
+            api_base_url=api_base_url,
+            deepthought=deepthought,
+            model_kwargs=model_kwargs,
+            tracer=tracer,
+        )
 
 
 async def converse_openai(
@@ -163,13 +187,26 @@ async def converse_openai(
     logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
 
     # Get Response from GPT
-    async for chunk in chat_completion_with_backoff(
-        messages=messages,
-        model_name=model,
-        temperature=temperature,
-        openai_api_key=api_key,
-        api_base_url=api_base_url,
-        deepthought=deepthought,
-        tracer=tracer,
-    ):
-        yield chunk
+    if is_openai_api(api_base_url):
+        async for chunk in responses_chat_completion_with_backoff(
+            messages=messages,
+            model_name=model,
+            temperature=temperature,
+            openai_api_key=api_key,
+            api_base_url=api_base_url,
+            deepthought=deepthought,
+            tracer=tracer,
+        ):
+            yield chunk
+    else:
+        # For non-OpenAI APIs, use the chat completion method
+        async for chunk in chat_completion_with_backoff(
+            messages=messages,
+            model_name=model,
+            temperature=temperature,
+            openai_api_key=api_key,
+            api_base_url=api_base_url,
+            deepthought=deepthought,
+            tracer=tracer,
+        ):
+            yield chunk


@@ -21,6 +21,8 @@ from openai.types.chat.chat_completion_chunk import (
     Choice,
     ChoiceDelta,
 )
+from openai.types.responses import Response as OpenAIResponse
+from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
 from pydantic import BaseModel
 from tenacity import (
     before_sleep_log,
@@ -53,6 +55,26 @@ openai_clients: Dict[str, openai.OpenAI] = {}
 openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
 
 
+def _extract_text_for_instructions(content: Union[str, List, Dict, None]) -> str:
+    """Extract plain text from a message content suitable for Responses API instructions."""
+    if content is None:
+        return ""
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        texts: List[str] = []
+        for part in content:
+            if isinstance(part, dict) and part.get("type") == "input_text" and part.get("text"):
+                texts.append(str(part.get("text")))
+        return "\n\n".join(texts)
+    if isinstance(content, dict):
+        # If a single part dict was passed
+        if content.get("type") == "input_text" and content.get("text"):
+            return str(content.get("text"))
+    # Fallback to string conversion
+    return str(content)
+
+
 @retry(
     retry=(
         retry_if_exception_type(openai._exceptions.APITimeoutError)
@@ -390,6 +412,287 @@ async def chat_completion_with_backoff(
         commit_conversation_trace(messages, aggregated_response, tracer)
 
 
+@retry(
+    retry=(
+        retry_if_exception_type(openai._exceptions.APITimeoutError)
+        | retry_if_exception_type(openai._exceptions.APIError)
+        | retry_if_exception_type(openai._exceptions.APIConnectionError)
+        | retry_if_exception_type(openai._exceptions.RateLimitError)
+        | retry_if_exception_type(openai._exceptions.APIStatusError)
+        | retry_if_exception_type(ValueError)
+    ),
+    wait=wait_random_exponential(min=1, max=10),
+    stop=stop_after_attempt(3),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def responses_completion_with_backoff(
+    messages: List[ChatMessage],
+    model_name: str,
+    temperature=0.6,
+    openai_api_key=None,
+    api_base_url=None,
+    deepthought: bool = False,
+    model_kwargs: dict = {},
+    tracer: dict = {},
+) -> ResponseWithThought:
"""
Synchronous helper using the OpenAI Responses API in streaming mode under the hood.
Aggregates streamed deltas and returns a ResponseWithThought.
"""
+    client_key = f"{openai_api_key}--{api_base_url}"
+    client = openai_clients.get(client_key)
+    if not client:
+        client = get_openai_client(openai_api_key, api_base_url)
+        openai_clients[client_key] = client
+
+    formatted_messages = format_message_for_api(messages, api_base_url)
+    # Move the first system message to Responses API instructions
+    instructions: Optional[str] = None
+    if formatted_messages and formatted_messages[0].get("role") == "system":
+        instructions = _extract_text_for_instructions(formatted_messages[0].get("content")) or None
+        formatted_messages = formatted_messages[1:]
+
+    model_kwargs = deepcopy(model_kwargs)
+    model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
+    # Configure thinking for openai reasoning models
+    if is_openai_reasoning_model(model_name, api_base_url):
+        temperature = 1
+        reasoning_effort = "medium" if deepthought else "low"
+        model_kwargs["reasoning"] = {"effort": reasoning_effort, "summary": "auto"}
+        # Remove unsupported params for reasoning models
+        model_kwargs.pop("top_p", None)
+        model_kwargs.pop("stop", None)
+
+    read_timeout = 300 if is_local_api(api_base_url) else 60
+
+    # Request response from the Responses API and aggregate the output
+    model_response: OpenAIResponse = client.responses.create(
+        input=formatted_messages,
+        instructions=instructions,
+        model=model_name,
+        temperature=temperature,
+        timeout=httpx.Timeout(30, read=read_timeout),  # type: ignore
+        store=False,
+        include=["reasoning.encrypted_content"],
+        **model_kwargs,
+    )
+    if not model_response or not isinstance(model_response, OpenAIResponse) or not model_response.output:
+        raise ValueError(f"Empty response returned by {model_name}.")
+
+    raw_content = [item.model_dump() for item in model_response.output]
+    aggregated_text = model_response.output_text
+    thoughts = ""
+    tool_calls: List[ToolCall] = []
+    for item in model_response.output:
+        if isinstance(item, ResponseFunctionToolCall):
+            tool_calls.append(ToolCall(name=item.name, args=json.loads(item.arguments), id=item.call_id))
+        elif isinstance(item, ResponseReasoningItem):
+            thoughts = "\n\n".join([summary.text for summary in item.summary])
+
+    if tool_calls:
+        if thoughts and aggregated_text:
+            # If there are tool calls, aggregate thoughts and responses into thoughts
+            thoughts = "\n".join([f"*{line.strip()}*" for line in thoughts.splitlines() if line.strip()])
+            thoughts = f"{thoughts}\n\n{aggregated_text}"
+        else:
+            thoughts = thoughts or aggregated_text
+        # Json dump tool calls into aggregated response
+        aggregated_text = json.dumps([tool_call.__dict__ for tool_call in tool_calls])
+
+    # Usage/cost tracking
+    input_tokens = model_response.usage.input_tokens if model_response and model_response.usage else 0
+    output_tokens = model_response.usage.output_tokens if model_response and model_response.usage else 0
+    cost = 0
+    cache_read_tokens = 0
+    if model_response and model_response.usage and model_response.usage.input_tokens_details:
+        cache_read_tokens = model_response.usage.input_tokens_details.cached_tokens
+        input_tokens -= cache_read_tokens
+    tracer["usage"] = get_chat_usage_metrics(
+        model_name, input_tokens, output_tokens, cache_read_tokens, usage=tracer.get("usage"), cost=cost
+    )
+
+    # Validate final aggregated text (either message or tool-calls JSON)
+    if is_none_or_empty(aggregated_text):
+        logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
+        raise ValueError(f"Empty or no response by {model_name} over Responses API. Retry if needed.")
+
+    # Trace
+    tracer["chat_model"] = model_name
+    tracer["temperature"] = temperature
+    if is_promptrace_enabled():
+        commit_conversation_trace(messages, aggregated_text, tracer)
+
+    return ResponseWithThought(text=aggregated_text, thought=thoughts, raw_content=raw_content)
+
+
+@retry(
+    retry=(
+        retry_if_exception_type(openai._exceptions.APITimeoutError)
+        | retry_if_exception_type(openai._exceptions.APIError)
+        | retry_if_exception_type(openai._exceptions.APIConnectionError)
+        | retry_if_exception_type(openai._exceptions.RateLimitError)
+        | retry_if_exception_type(openai._exceptions.APIStatusError)
+        | retry_if_exception_type(ValueError)
+    ),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    stop=stop_after_attempt(3),
+    before_sleep=before_sleep_log(logger, logging.WARNING),
+    reraise=False,
+)
+async def responses_chat_completion_with_backoff(
+    messages: list[ChatMessage],
+    model_name: str,
+    temperature,
+    openai_api_key=None,
+    api_base_url=None,
+    deepthought=False,
+    tracer: dict = {},
+) -> AsyncGenerator[ResponseWithThought, None]:
+    """
+    Async streaming helper using the OpenAI Responses API.
+    Yields ResponseWithThought chunks as text/think deltas arrive.
+    """
+    client_key = f"{openai_api_key}--{api_base_url}"
+    client = openai_async_clients.get(client_key)
+    if not client:
+        client = get_openai_async_client(openai_api_key, api_base_url)
+        openai_async_clients[client_key] = client
+
+    formatted_messages = format_message_for_api(messages, api_base_url)
+    # Move the first system message to Responses API instructions
+    instructions: Optional[str] = None
+    if formatted_messages and formatted_messages[0].get("role") == "system":
+        instructions = _extract_text_for_instructions(formatted_messages[0].get("content")) or None
+        formatted_messages = formatted_messages[1:]
+
+    model_kwargs: dict = {}
+    model_kwargs["top_p"] = model_kwargs.get("top_p", 0.95)
+    # Configure thinking for openai reasoning models
+    if is_openai_reasoning_model(model_name, api_base_url):
+        temperature = 1
+        reasoning_effort = "medium" if deepthought else "low"
+        model_kwargs["reasoning"] = {"effort": reasoning_effort, "summary": "auto"}
+        # Remove unsupported params for reasoning models
+        model_kwargs.pop("top_p", None)
+        model_kwargs.pop("stop", None)
+
+    read_timeout = 300 if is_local_api(api_base_url) else 60
+
+    aggregated_text = ""
+    last_final: Optional[OpenAIResponse] = None
+    # Tool call assembly buffers
+    tool_calls_args: Dict[str, str] = {}
+    tool_calls_name: Dict[str, str] = {}
+    tool_call_order: List[str] = []
+
+    async with client.responses.stream(
+        input=formatted_messages,
+        instructions=instructions,
+        model=model_name,
+        temperature=temperature,
+        timeout=httpx.Timeout(30, read=read_timeout),
+        **model_kwargs,
+    ) as stream:  # type: ignore
+        async for event in stream:  # type: ignore
+            et = getattr(event, "type", "")
+            if et == "response.output_text.delta":
+                delta = getattr(event, "delta", "") or getattr(event, "output_text", "")
+                if delta:
+                    aggregated_text += delta
+                    yield ResponseWithThought(text=delta)
+            elif et == "response.reasoning.delta":
+                delta = getattr(event, "delta", "")
+                if delta:
+                    yield ResponseWithThought(thought=delta)
+            elif et == "response.tool_call.created":
+                item = getattr(event, "item", None)
+                tool_id = (
+                    getattr(event, "id", None)
+                    or getattr(event, "tool_call_id", None)
+                    or (getattr(item, "id", None) if item is not None else None)
+                )
+                name = (
+                    getattr(event, "name", None)
+                    or (getattr(item, "name", None) if item is not None else None)
+                    or getattr(event, "tool_name", None)
+                )
+                if tool_id:
+                    if tool_id not in tool_calls_args:
+                        tool_calls_args[tool_id] = ""
+                        tool_call_order.append(tool_id)
+                    if name:
+                        tool_calls_name[tool_id] = name
+            elif et == "response.tool_call.delta":
+                tool_id = getattr(event, "id", None) or getattr(event, "tool_call_id", None)
+                delta = getattr(event, "delta", None)
+                if hasattr(delta, "arguments"):
+                    arg_delta = getattr(delta, "arguments", "")
+                else:
+                    arg_delta = delta if isinstance(delta, str) else getattr(event, "arguments", "")
+                if tool_id and arg_delta:
+                    tool_calls_args[tool_id] = tool_calls_args.get(tool_id, "") + arg_delta
+                    if tool_id not in tool_call_order:
+                        tool_call_order.append(tool_id)
+            elif et == "response.tool_call.completed":
+                item = getattr(event, "item", None)
+                tool_id = (
+                    getattr(event, "id", None)
+                    or getattr(event, "tool_call_id", None)
+                    or (getattr(item, "id", None) if item is not None else None)
+                )
+                args_final = None
+                if item is not None:
+                    args_final = getattr(item, "arguments", None) or getattr(item, "args", None)
+                if tool_id and args_final:
+                    tool_calls_args[tool_id] = args_final if isinstance(args_final, str) else json.dumps(args_final)
+                    if tool_id not in tool_call_order:
+                        tool_call_order.append(tool_id)
+            # ignore other events for now
+        last_final = await stream.get_final_response()
+
+    # Usage/cost tracking after stream ends
+    input_tokens = last_final.usage.input_tokens if last_final and last_final.usage else 0
+    output_tokens = last_final.usage.output_tokens if last_final and last_final.usage else 0
+    cost = 0
+    tracer["usage"] = get_chat_usage_metrics(
+        model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
+    )
+
+    # If there are tool calls, package them into aggregated text for tracing parity
+    if tool_call_order:
+        packaged_tool_calls: List[ToolCall] = []
+        for tool_id in tool_call_order:
+            name = tool_calls_name.get(tool_id) or ""
+            args_str = tool_calls_args.get(tool_id, "")
+            try:
+                args = json.loads(args_str) if isinstance(args_str, str) else args_str
+            except Exception:
+                logger.warning(f"Failed to parse tool call arguments for {tool_id}: {args_str}")
+                args = {}
+            packaged_tool_calls.append(ToolCall(name=name, args=args, id=tool_id))
+        # Move any text into trace thought
+        tracer_text = aggregated_text
+        aggregated_text = json.dumps([tc.__dict__ for tc in packaged_tool_calls])
+        # Save for trace below
+        if tracer_text:
+            tracer.setdefault("_responses_stream_text", tracer_text)
+
+    if is_none_or_empty(aggregated_text):
+        logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
+        raise ValueError(f"Empty or no response by {model_name} over Responses API. Retry if needed.")
+
+    tracer["chat_model"] = model_name
+    tracer["temperature"] = temperature
+    if is_promptrace_enabled():
+        # If tool-calls were present, include any streamed text in the trace thought
+        trace_payload = aggregated_text
+        if tracer.get("_responses_stream_text"):
+            thoughts = tracer.pop("_responses_stream_text")
+            trace_payload = thoughts
+        commit_conversation_trace(messages, trace_payload, tracer)
+
+
 def get_structured_output_support(model_name: str, api_base_url: str = None) -> StructuredOutputSupport:
     if model_name.startswith("deepseek-reasoner"):
         return StructuredOutputSupport.NONE
@@ -412,6 +715,12 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) -
         # Handle tool call and tool result message types
         message_type = message.additional_kwargs.get("message_type")
         if message_type == "tool_call":
+            if is_openai_api(api_base_url):
+                for part in message.content:
+                    if "status" in part:
+                        part.pop("status")  # Drop unsupported tool call status field
+                formatted_messages.extend(message.content)
+                continue
             # Convert tool_call to OpenAI function call format
             content = []
             for part in message.content:
@@ -450,14 +759,23 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) -
                 if not tool_call_id:
                     logger.warning(f"Dropping tool result without valid tool_call_id: {part.get('name')}")
                     continue
-                formatted_messages.append(
-                    {
-                        "role": "tool",
-                        "tool_call_id": tool_call_id,
-                        "name": part.get("name"),
-                        "content": part.get("content"),
-                    }
-                )
+                if is_openai_api(api_base_url):
+                    formatted_messages.append(
+                        {
+                            "type": "function_call_output",
+                            "call_id": tool_call_id,
+                            "output": part.get("content"),
+                        }
+                    )
+                else:
+                    formatted_messages.append(
+                        {
+                            "role": "tool",
+                            "tool_call_id": tool_call_id,
+                            "name": part.get("name"),
+                            "content": part.get("content"),
+                        }
+                    )
             continue
         if isinstance(message.content, list) and not is_openai_api(api_base_url):
             assistant_texts = []
@@ -489,6 +807,11 @@ def format_message_for_api(raw_messages: List[ChatMessage], api_base_url: str) -
                     message.content.remove(part)
                 elif part["type"] == "image_url" and not part.get("image_url"):
                     message.content.remove(part)
+                # OpenAI models use the Responses API which uses slightly different content types
+                if part["type"] == "text":
+                    part["type"] = "output_text" if message.role == "assistant" else "input_text"
+                if part["type"] == "image":
+                    part["type"] = "output_image" if message.role == "assistant" else "input_image"
         # If no valid content parts left, remove the message
         if is_none_or_empty(message.content):
             messages.remove(message)
@@ -852,20 +1175,32 @@ def add_qwen_no_think_tag(formatted_messages: List[dict]) -> None:
             break
 
 
-def to_openai_tools(tools: List[ToolDefinition]) -> List[Dict] | None:
+def to_openai_tools(tools: List[ToolDefinition], use_responses_api: bool) -> List[Dict] | None:
     "Transform tool definitions from standard format to OpenAI format."
-    openai_tools = [
-        {
-            "type": "function",
-            "function": {
+    if use_responses_api:
+        openai_tools = [
+            {
+                "type": "function",
                 "name": tool.name,
                 "description": tool.description,
                 "parameters": clean_response_schema(tool.schema),
                 "strict": True,
-            },
-        }
-        for tool in tools
-    ]
+            }
+            for tool in tools
+        ]
+    else:
+        openai_tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": clean_response_schema(tool.schema),
+                    "strict": True,
+                },
+            }
+            for tool in tools
+        ]
 
     return openai_tools or None


@@ -68,6 +68,9 @@ model_to_prompt_size = {
     "o3": 60000,
     "o3-pro": 30000,
    "o4-mini": 90000,
+    "gpt-5-2025-08-07": 120000,
+    "gpt-5-mini-2025-08-07": 120000,
+    "gpt-5-nano-2025-08-07": 120000,
     # Google Models
     "gemini-2.5-flash": 120000,
     "gemini-2.5-pro": 60000,


@@ -40,6 +40,9 @@ model_to_cost: Dict[str, Dict[str, float]] = {
     "o3": {"input": 2.0, "output": 8.00},
     "o3-pro": {"input": 20.0, "output": 80.00},
     "o4-mini": {"input": 1.10, "output": 4.40},
+    "gpt-5-2025-08-07": {"input": 1.25, "output": 10.00, "cache_read": 0.125},
+    "gpt-5-mini-2025-08-07": {"input": 0.25, "output": 2.00, "cache_read": 0.025},
+    "gpt-5-nano-2025-08-07": {"input": 0.05, "output": 0.40, "cache_read": 0.005},
     # Gemini Pricing: https://ai.google.dev/pricing
     "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
     "gemini-1.5-flash-002": {"input": 0.075, "output": 0.30},