mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Enable non-streaming responses via the OpenAI API to support o3 models
This commit is contained in:
@@ -14,6 +14,7 @@ from openai.lib.streaming.chat import (
|
||||
ChatCompletionStreamEvent,
|
||||
ContentDeltaEvent,
|
||||
)
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk,
|
||||
Choice,
|
||||
@@ -78,7 +79,11 @@ def completion_with_backoff(
|
||||
client = get_openai_client(openai_api_key, api_base_url)
|
||||
openai_clients[client_key] = client
|
||||
|
||||
stream = not is_non_streaming_model(model_name, api_base_url)
|
||||
stream_processor = default_stream_processor
|
||||
if stream:
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
|
||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||
|
||||
# Tune reasoning models arguments
|
||||
@@ -109,23 +114,33 @@ def completion_with_backoff(
|
||||
add_qwen_no_think_tag(formatted_messages)
|
||||
|
||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
|
||||
aggregated_response = ""
|
||||
with client.beta.chat.completions.stream(
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model_name,
|
||||
temperature=temperature,
|
||||
timeout=httpx.Timeout(30, read=read_timeout),
|
||||
**model_kwargs,
|
||||
) as chat:
|
||||
for chunk in stream_processor(chat):
|
||||
if chunk.type == "content.delta":
|
||||
aggregated_response += chunk.delta
|
||||
elif chunk.type == "thought.delta":
|
||||
pass
|
||||
if stream:
|
||||
with client.beta.chat.completions.stream(
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model_name,
|
||||
temperature=temperature,
|
||||
timeout=httpx.Timeout(30, read=read_timeout),
|
||||
**model_kwargs,
|
||||
) as chat:
|
||||
for chunk in stream_processor(chat):
|
||||
if chunk.type == "content.delta":
|
||||
aggregated_response += chunk.delta
|
||||
elif chunk.type == "thought.delta":
|
||||
pass
|
||||
else:
|
||||
# Non-streaming chat completion
|
||||
chunk = client.beta.chat.completions.parse(
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model_name,
|
||||
temperature=temperature,
|
||||
timeout=httpx.Timeout(30, read=read_timeout),
|
||||
**model_kwargs,
|
||||
)
|
||||
aggregated_response = chunk.choices[0].message.content
|
||||
|
||||
# Calculate cost of chat
|
||||
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||
@@ -182,7 +197,11 @@ async def chat_completion_with_backoff(
|
||||
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||
openai_async_clients[client_key] = client
|
||||
|
||||
stream = not is_non_streaming_model(model_name, api_base_url)
|
||||
stream_processor = adefault_stream_processor
|
||||
if stream:
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
|
||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||
|
||||
# Configure thinking for openai reasoning models
|
||||
@@ -228,9 +247,7 @@ async def chat_completion_with_backoff(
|
||||
if not deepthought:
|
||||
add_qwen_no_think_tag(formatted_messages)
|
||||
|
||||
stream = True
|
||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||
model_kwargs["stream_options"] = {"include_usage": True}
|
||||
if os.getenv("KHOJ_LLM_SEED"):
|
||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||
|
||||
@@ -238,7 +255,7 @@ async def chat_completion_with_backoff(
|
||||
final_chunk = None
|
||||
response_started = False
|
||||
start_time = perf_counter()
|
||||
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
response: openai.AsyncStream[ChatCompletionChunk] | ChatCompletion = await client.chat.completions.create(
|
||||
messages=formatted_messages, # type: ignore
|
||||
model=model_name,
|
||||
stream=stream,
|
||||
@@ -246,26 +263,34 @@ async def chat_completion_with_backoff(
|
||||
timeout=httpx.Timeout(30, read=read_timeout),
|
||||
**model_kwargs,
|
||||
)
|
||||
async for chunk in stream_processor(chat_stream):
|
||||
# Log the time taken to start response
|
||||
if not response_started:
|
||||
response_started = True
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Keep track of the last chunk for usage data
|
||||
final_chunk = chunk
|
||||
# Skip empty chunks
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
# Handle streamed response chunk
|
||||
response_chunk: ResponseWithThought = None
|
||||
response_delta = chunk.choices[0].delta
|
||||
if response_delta.content:
|
||||
response_chunk = ResponseWithThought(response=response_delta.content)
|
||||
aggregated_response += response_chunk.response
|
||||
elif response_delta.thought:
|
||||
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
||||
if response_chunk:
|
||||
yield response_chunk
|
||||
if not stream:
|
||||
# If not streaming, we can return the response directly
|
||||
if len(response.choices) == 0 or not response.choices[0].message:
|
||||
raise ValueError("No response by model.")
|
||||
aggregated_response = response.choices[0].message.content
|
||||
final_chunk = response
|
||||
yield ResponseWithThought(response=aggregated_response)
|
||||
else:
|
||||
async for chunk in stream_processor(response):
|
||||
# Log the time taken to start response
|
||||
if not response_started:
|
||||
response_started = True
|
||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||
# Keep track of the last chunk for usage data
|
||||
final_chunk = chunk
|
||||
# Skip empty chunks
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
# Handle streamed response chunk
|
||||
response_chunk: ResponseWithThought = None
|
||||
response_delta = chunk.choices[0].delta
|
||||
if response_delta.content:
|
||||
response_chunk = ResponseWithThought(response=response_delta.content)
|
||||
aggregated_response += response_chunk.response
|
||||
elif response_delta.thought:
|
||||
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
||||
if response_chunk:
|
||||
yield response_chunk
|
||||
|
||||
# Calculate cost of chat after stream finishes
|
||||
input_tokens, output_tokens, cost = 0, 0, 0
|
||||
@@ -354,6 +379,14 @@ def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool
|
||||
return model_name.startswith("o") and is_openai_api(api_base_url)
|
||||
|
||||
|
||||
def is_non_streaming_model(model_name: str, api_base_url: str = None) -> bool:
    """
    Decide whether this model's response should be fetched without streaming.

    Only the exact model names "o3" and "o3-pro" are candidates. For those,
    the answer is whatever is_openai_api(api_base_url) reports; every other
    model name yields False without consulting the API base URL.
    """
    # Some OpenAI models require biometrics to stream. Avoid streaming their responses.
    if model_name not in ("o3", "o3-pro"):
        return False
    return is_openai_api(api_base_url)
|
||||
|
||||
|
||||
def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
|
||||
"""
|
||||
Check if the model is a Twitter reasoning model
|
||||
|
||||
Reference in New Issue
Block a user