Enable non-streaming response via openai api to support o3 models

Debanjum
2025-06-10 19:02:29 -07:00
parent 5110a06085
commit 753972997f

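Context for the diff below: per this commit's comment, OpenAI's o3 and o3-pro models require biometric verification to stream, so both the sync and async completion helpers now consult is_non_streaming_model() and fall back to a single blocking request for those models. A minimal sketch of the dispatch pattern being introduced (the helper name and model list come from this diff; the bare OpenAI() client setup and the complete() wrapper are illustrative, not the commit's implementation):

# Minimal sketch, not the commit's implementation: route o3/o3-pro to a
# blocking completion and stream everything else.
from openai import OpenAI

client = OpenAI()  # illustrative client setup


def is_non_streaming_model(model_name: str) -> bool:
    # Model list taken from this commit
    return model_name in ["o3", "o3-pro"]


def complete(messages: list, model_name: str) -> str:
    if not is_non_streaming_model(model_name):
        # Stream and aggregate content deltas as they arrive
        aggregated = ""
        for chunk in client.chat.completions.create(messages=messages, model=model_name, stream=True):
            if chunk.choices and chunk.choices[0].delta.content:
                aggregated += chunk.choices[0].delta.content
        return aggregated
    # Blocking request returns a complete ChatCompletion in one shot
    response = client.chat.completions.create(messages=messages, model=model_name)
    return response.choices[0].message.content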

@@ -14,6 +14,7 @@ from openai.lib.streaming.chat import (
     ChatCompletionStreamEvent,
     ContentDeltaEvent,
 )
+from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.chat.chat_completion_chunk import (
     ChatCompletionChunk,
     Choice,
@@ -78,7 +79,11 @@ def completion_with_backoff(
         client = get_openai_client(openai_api_key, api_base_url)
         openai_clients[client_key] = client
 
+    stream = not is_non_streaming_model(model_name, api_base_url)
+    stream_processor = default_stream_processor
+    if stream:
+        model_kwargs["stream_options"] = {"include_usage": True}
 
     formatted_messages = format_message_for_api(messages, api_base_url)
 
     # Tune reasoning models arguments
@@ -109,23 +114,33 @@ def completion_with_backoff(
             add_qwen_no_think_tag(formatted_messages)
 
     read_timeout = 300 if is_local_api(api_base_url) else 60
-    model_kwargs["stream_options"] = {"include_usage": True}
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
 
     aggregated_response = ""
-    with client.beta.chat.completions.stream(
-        messages=formatted_messages,  # type: ignore
-        model=model_name,
-        temperature=temperature,
-        timeout=httpx.Timeout(30, read=read_timeout),
-        **model_kwargs,
-    ) as chat:
-        for chunk in stream_processor(chat):
-            if chunk.type == "content.delta":
-                aggregated_response += chunk.delta
-            elif chunk.type == "thought.delta":
-                pass
+    if stream:
+        with client.beta.chat.completions.stream(
+            messages=formatted_messages,  # type: ignore
+            model=model_name,
+            temperature=temperature,
+            timeout=httpx.Timeout(30, read=read_timeout),
+            **model_kwargs,
+        ) as chat:
+            for chunk in stream_processor(chat):
+                if chunk.type == "content.delta":
+                    aggregated_response += chunk.delta
+                elif chunk.type == "thought.delta":
+                    pass
+    else:
+        # Non-streaming chat completion
+        chunk = client.beta.chat.completions.parse(
+            messages=formatted_messages,  # type: ignore
+            model=model_name,
+            temperature=temperature,
+            timeout=httpx.Timeout(30, read=read_timeout),
+            **model_kwargs,
+        )
+        aggregated_response = chunk.choices[0].message.content
 
     # Calculate cost of chat
     input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
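A note on why stream_options now lives inside the if stream: branch: {"include_usage": True} only applies to streaming requests, where the OpenAI API delivers token usage as one final chunk whose choices list is empty; non-streaming completions carry usage on the response object directly. That is also why the cost calculation above guards with hasattr(chunk, "usage") and works unchanged for both branches. A short sketch of the final-chunk behavior (standard openai SDK semantics; client, messages, and model_name are assumed defined):

# Sketch: with include_usage, the last streamed chunk has an empty
# choices list but carries aggregate token usage for the response.
usage = None
stream = client.chat.completions.create(
    messages=messages,
    model=model_name,
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.usage:  # final, usage-only chunk
        usage = chunk.usage
    elif chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
if usage:
    print(f"\nprompt={usage.prompt_tokens} completion={usage.completion_tokens}")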
@@ -182,7 +197,11 @@ async def chat_completion_with_backoff(
         client = get_openai_async_client(openai_api_key, api_base_url)
         openai_async_clients[client_key] = client
 
+    stream = not is_non_streaming_model(model_name, api_base_url)
+    stream_processor = adefault_stream_processor
+    if stream:
+        model_kwargs["stream_options"] = {"include_usage": True}
 
     formatted_messages = format_message_for_api(messages, api_base_url)
 
     # Configure thinking for openai reasoning models
@@ -228,9 +247,7 @@ async def chat_completion_with_backoff(
         if not deepthought:
             add_qwen_no_think_tag(formatted_messages)
 
-    stream = True
     read_timeout = 300 if is_local_api(api_base_url) else 60
-    model_kwargs["stream_options"] = {"include_usage": True}
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
@@ -238,7 +255,7 @@ async def chat_completion_with_backoff(
     final_chunk = None
     response_started = False
     start_time = perf_counter()
-    chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
+    response: openai.AsyncStream[ChatCompletionChunk] | ChatCompletion = await client.chat.completions.create(
         messages=formatted_messages,  # type: ignore
         model=model_name,
         stream=stream,
@@ -246,26 +263,34 @@ async def chat_completion_with_backoff(
         timeout=httpx.Timeout(30, read=read_timeout),
         **model_kwargs,
     )
-    async for chunk in stream_processor(chat_stream):
-        # Log the time taken to start response
-        if not response_started:
-            response_started = True
-            logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
-        # Keep track of the last chunk for usage data
-        final_chunk = chunk
-        # Skip empty chunks
-        if len(chunk.choices) == 0:
-            continue
-        # Handle streamed response chunk
-        response_chunk: ResponseWithThought = None
-        response_delta = chunk.choices[0].delta
-        if response_delta.content:
-            response_chunk = ResponseWithThought(response=response_delta.content)
-            aggregated_response += response_chunk.response
-        elif response_delta.thought:
-            response_chunk = ResponseWithThought(thought=response_delta.thought)
-        if response_chunk:
-            yield response_chunk
+    if not stream:
+        # If not streaming, we can return the response directly
+        if len(response.choices) == 0 or not response.choices[0].message:
+            raise ValueError("No response by model.")
+        aggregated_response = response.choices[0].message.content
+        final_chunk = response
+        yield ResponseWithThought(response=aggregated_response)
+    else:
+        async for chunk in stream_processor(response):
+            # Log the time taken to start response
+            if not response_started:
+                response_started = True
+                logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+            # Keep track of the last chunk for usage data
+            final_chunk = chunk
+            # Skip empty chunks
+            if len(chunk.choices) == 0:
+                continue
+            # Handle streamed response chunk
+            response_chunk: ResponseWithThought = None
+            response_delta = chunk.choices[0].delta
+            if response_delta.content:
+                response_chunk = ResponseWithThought(response=response_delta.content)
+                aggregated_response += response_chunk.response
+            elif response_delta.thought:
+                response_chunk = ResponseWithThought(thought=response_delta.thought)
+            if response_chunk:
+                yield response_chunk
 
     # Calculate cost of chat after stream finishes
     input_tokens, output_tokens, cost = 0, 0, 0
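Design note on the branch above: yielding the whole non-streamed reply as a single ResponseWithThought keeps the async generator's interface identical in both modes, so callers never need to know whether the model streamed. A hypothetical consumer (the call signature is illustrative, not the function's actual one):

# Hypothetical caller: identical whether the model streamed or not.
async def collect_reply(messages, model_name) -> str:
    reply = ""
    async for part in chat_completion_with_backoff(messages, model_name):
        if part.response:  # arrives as one part in non-streaming mode
            reply += part.response
    return reply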
@@ -354,6 +379,14 @@ def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool
     return model_name.startswith("o") and is_openai_api(api_base_url)
 
 
+def is_non_streaming_model(model_name: str, api_base_url: str = None) -> bool:
+    """
+    Check if model response should not be streamed.
+    """
+    # Some OpenAI models require biometric verification to stream. Avoid streaming their responses.
+    return model_name in ["o3", "o3-pro"] and is_openai_api(api_base_url)
+
+
 def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
     """
     Check if the model is a Twitter reasoning model