mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Enable non-streaming response via openai api to support o3 models
This commit is contained in:
@@ -14,6 +14,7 @@ from openai.lib.streaming.chat import (
|
|||||||
ChatCompletionStreamEvent,
|
ChatCompletionStreamEvent,
|
||||||
ContentDeltaEvent,
|
ContentDeltaEvent,
|
||||||
)
|
)
|
||||||
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
from openai.types.chat.chat_completion_chunk import (
|
from openai.types.chat.chat_completion_chunk import (
|
||||||
ChatCompletionChunk,
|
ChatCompletionChunk,
|
||||||
Choice,
|
Choice,
|
||||||
@@ -78,7 +79,11 @@ def completion_with_backoff(
|
|||||||
client = get_openai_client(openai_api_key, api_base_url)
|
client = get_openai_client(openai_api_key, api_base_url)
|
||||||
openai_clients[client_key] = client
|
openai_clients[client_key] = client
|
||||||
|
|
||||||
|
stream = not is_non_streaming_model(model_name, api_base_url)
|
||||||
stream_processor = default_stream_processor
|
stream_processor = default_stream_processor
|
||||||
|
if stream:
|
||||||
|
model_kwargs["stream_options"] = {"include_usage": True}
|
||||||
|
|
||||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||||
|
|
||||||
# Tune reasoning models arguments
|
# Tune reasoning models arguments
|
||||||
@@ -109,23 +114,33 @@ def completion_with_backoff(
|
|||||||
add_qwen_no_think_tag(formatted_messages)
|
add_qwen_no_think_tag(formatted_messages)
|
||||||
|
|
||||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||||
model_kwargs["stream_options"] = {"include_usage": True}
|
|
||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
|
|
||||||
aggregated_response = ""
|
aggregated_response = ""
|
||||||
with client.beta.chat.completions.stream(
|
if stream:
|
||||||
messages=formatted_messages, # type: ignore
|
with client.beta.chat.completions.stream(
|
||||||
model=model_name,
|
messages=formatted_messages, # type: ignore
|
||||||
temperature=temperature,
|
model=model_name,
|
||||||
timeout=httpx.Timeout(30, read=read_timeout),
|
temperature=temperature,
|
||||||
**model_kwargs,
|
timeout=httpx.Timeout(30, read=read_timeout),
|
||||||
) as chat:
|
**model_kwargs,
|
||||||
for chunk in stream_processor(chat):
|
) as chat:
|
||||||
if chunk.type == "content.delta":
|
for chunk in stream_processor(chat):
|
||||||
aggregated_response += chunk.delta
|
if chunk.type == "content.delta":
|
||||||
elif chunk.type == "thought.delta":
|
aggregated_response += chunk.delta
|
||||||
pass
|
elif chunk.type == "thought.delta":
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Non-streaming chat completion
|
||||||
|
chunk = client.beta.chat.completions.parse(
|
||||||
|
messages=formatted_messages, # type: ignore
|
||||||
|
model=model_name,
|
||||||
|
temperature=temperature,
|
||||||
|
timeout=httpx.Timeout(30, read=read_timeout),
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
aggregated_response = chunk.choices[0].message.content
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
|
||||||
@@ -182,7 +197,11 @@ async def chat_completion_with_backoff(
|
|||||||
client = get_openai_async_client(openai_api_key, api_base_url)
|
client = get_openai_async_client(openai_api_key, api_base_url)
|
||||||
openai_async_clients[client_key] = client
|
openai_async_clients[client_key] = client
|
||||||
|
|
||||||
|
stream = not is_non_streaming_model(model_name, api_base_url)
|
||||||
stream_processor = adefault_stream_processor
|
stream_processor = adefault_stream_processor
|
||||||
|
if stream:
|
||||||
|
model_kwargs["stream_options"] = {"include_usage": True}
|
||||||
|
|
||||||
formatted_messages = format_message_for_api(messages, api_base_url)
|
formatted_messages = format_message_for_api(messages, api_base_url)
|
||||||
|
|
||||||
# Configure thinking for openai reasoning models
|
# Configure thinking for openai reasoning models
|
||||||
@@ -228,9 +247,7 @@ async def chat_completion_with_backoff(
|
|||||||
if not deepthought:
|
if not deepthought:
|
||||||
add_qwen_no_think_tag(formatted_messages)
|
add_qwen_no_think_tag(formatted_messages)
|
||||||
|
|
||||||
stream = True
|
|
||||||
read_timeout = 300 if is_local_api(api_base_url) else 60
|
read_timeout = 300 if is_local_api(api_base_url) else 60
|
||||||
model_kwargs["stream_options"] = {"include_usage": True}
|
|
||||||
if os.getenv("KHOJ_LLM_SEED"):
|
if os.getenv("KHOJ_LLM_SEED"):
|
||||||
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
|
||||||
|
|
||||||
@@ -238,7 +255,7 @@ async def chat_completion_with_backoff(
|
|||||||
final_chunk = None
|
final_chunk = None
|
||||||
response_started = False
|
response_started = False
|
||||||
start_time = perf_counter()
|
start_time = perf_counter()
|
||||||
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
response: openai.AsyncStream[ChatCompletionChunk] | ChatCompletion = await client.chat.completions.create(
|
||||||
messages=formatted_messages, # type: ignore
|
messages=formatted_messages, # type: ignore
|
||||||
model=model_name,
|
model=model_name,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
@@ -246,26 +263,34 @@ async def chat_completion_with_backoff(
|
|||||||
timeout=httpx.Timeout(30, read=read_timeout),
|
timeout=httpx.Timeout(30, read=read_timeout),
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
async for chunk in stream_processor(chat_stream):
|
if not stream:
|
||||||
# Log the time taken to start response
|
# If not streaming, we can return the response directly
|
||||||
if not response_started:
|
if len(response.choices) == 0 or not response.choices[0].message:
|
||||||
response_started = True
|
raise ValueError("No response by model.")
|
||||||
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
aggregated_response = response.choices[0].message.content
|
||||||
# Keep track of the last chunk for usage data
|
final_chunk = response
|
||||||
final_chunk = chunk
|
yield ResponseWithThought(response=aggregated_response)
|
||||||
# Skip empty chunks
|
else:
|
||||||
if len(chunk.choices) == 0:
|
async for chunk in stream_processor(response):
|
||||||
continue
|
# Log the time taken to start response
|
||||||
# Handle streamed response chunk
|
if not response_started:
|
||||||
response_chunk: ResponseWithThought = None
|
response_started = True
|
||||||
response_delta = chunk.choices[0].delta
|
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
|
||||||
if response_delta.content:
|
# Keep track of the last chunk for usage data
|
||||||
response_chunk = ResponseWithThought(response=response_delta.content)
|
final_chunk = chunk
|
||||||
aggregated_response += response_chunk.response
|
# Skip empty chunks
|
||||||
elif response_delta.thought:
|
if len(chunk.choices) == 0:
|
||||||
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
continue
|
||||||
if response_chunk:
|
# Handle streamed response chunk
|
||||||
yield response_chunk
|
response_chunk: ResponseWithThought = None
|
||||||
|
response_delta = chunk.choices[0].delta
|
||||||
|
if response_delta.content:
|
||||||
|
response_chunk = ResponseWithThought(response=response_delta.content)
|
||||||
|
aggregated_response += response_chunk.response
|
||||||
|
elif response_delta.thought:
|
||||||
|
response_chunk = ResponseWithThought(thought=response_delta.thought)
|
||||||
|
if response_chunk:
|
||||||
|
yield response_chunk
|
||||||
|
|
||||||
# Calculate cost of chat after stream finishes
|
# Calculate cost of chat after stream finishes
|
||||||
input_tokens, output_tokens, cost = 0, 0, 0
|
input_tokens, output_tokens, cost = 0, 0, 0
|
||||||
@@ -354,6 +379,14 @@ def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool
|
|||||||
return model_name.startswith("o") and is_openai_api(api_base_url)
|
return model_name.startswith("o") and is_openai_api(api_base_url)
|
||||||
|
|
||||||
|
|
||||||
|
def is_non_streaming_model(model_name: str, api_base_url: str = None) -> bool:
|
||||||
|
"""
|
||||||
|
Check if model response should not be streamed.
|
||||||
|
"""
|
||||||
|
# Some OpenAI models requires biometrics to stream. Avoid streaming their responses.
|
||||||
|
return model_name in ["o3", "o3-pro"] and is_openai_api(api_base_url)
|
||||||
|
|
||||||
|
|
||||||
def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
|
def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the model is a Twitter reasoning model
|
Check if the model is a Twitter reasoning model
|
||||||
|
|||||||
Reference in New Issue
Block a user