Parse thoughts from thinking models served via OpenAI compatible API

OpenAI API doesn't support thoughts via chat completion by default.
But there are thinking models served via OpenAI compatible APIs like
deepseek and qwen3.

Add stream handlers and modified response types that can contain
thoughts in addition to the content returned by a model.

This can be used to instantiate stream handlers for different model
types like deepseek, qwen3 etc served over an OpenAI compatible API.
This commit is contained in:
Debanjum
2025-05-02 10:29:36 -06:00
parent 6843db1647
commit 7b9f2c21c7

View File

@@ -1,12 +1,21 @@
import logging import logging
import os import os
from functools import partial
from time import perf_counter from time import perf_counter
from typing import AsyncGenerator, Dict, List from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import openai import openai
from openai.types.chat.chat_completion import ChatCompletion from openai.lib.streaming.chat import (
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk ChatCompletionStream,
ChatCompletionStreamEvent,
ContentDeltaEvent,
)
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
Choice,
ChoiceDelta,
)
from tenacity import ( from tenacity import (
before_sleep_log, before_sleep_log,
retry, retry,
@@ -59,6 +68,7 @@ def completion_with_backoff(
client = get_openai_client(openai_api_key, api_base_url) client = get_openai_client(openai_api_key, api_base_url)
openai_clients[client_key] = client openai_clients[client_key] = client
stream_processor = default_stream_processor
formatted_messages = [{"role": message.role, "content": message.content} for message in messages] formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
# Tune reasoning models arguments # Tune reasoning models arguments
@@ -82,12 +92,14 @@ def completion_with_backoff(
timeout=20, timeout=20,
**model_kwargs, **model_kwargs,
) as chat: ) as chat:
for chunk in chat: for chunk in stream_processor(chat):
if chunk.type == "error": if chunk.type == "error":
logger.error(f"Openai api response error: {chunk.error}", exc_info=True) logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
continue continue
elif chunk.type == "content.delta": elif chunk.type == "content.delta":
aggregated_response += chunk.delta aggregated_response += chunk.delta
elif chunk.type == "thought.delta":
pass
# Calculate cost of chat # Calculate cost of chat
input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
@@ -124,7 +136,7 @@ def completion_with_backoff(
) )
async def chat_completion_with_backoff( async def chat_completion_with_backoff(
messages, messages,
model_name, model_name: str,
temperature, temperature,
openai_api_key=None, openai_api_key=None,
api_base_url=None, api_base_url=None,
@@ -139,6 +151,7 @@ async def chat_completion_with_backoff(
client = get_openai_async_client(openai_api_key, api_base_url) client = get_openai_async_client(openai_api_key, api_base_url)
openai_async_clients[client_key] = client openai_async_clients[client_key] = client
stream_processor = adefault_stream_processor
formatted_messages = [{"role": message.role, "content": message.content} for message in messages] formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
# Configure thinking for openai reasoning models # Configure thinking for openai reasoning models
@@ -193,7 +206,7 @@ async def chat_completion_with_backoff(
timeout=20, timeout=20,
**model_kwargs, **model_kwargs,
) )
async for chunk in chat_stream: async for chunk in stream_processor(chat_stream):
# Log the time taken to start response # Log the time taken to start response
if final_chunk is None: if final_chunk is None:
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
@@ -264,3 +277,52 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo
and api_base_url is not None and api_base_url is not None
and api_base_url.startswith("https://api.x.ai/v1") and api_base_url.startswith("https://api.x.ai/v1")
) )
class ThoughtDeltaEvent(ContentDeltaEvent):
    """
    Stream event carrying a delta of the model's thought/reasoning text.

    Mirrors the openai ContentDeltaEvent but is tagged "thought.delta",
    so stream consumers can distinguish reasoning output from content.
    """
    type: Literal["thought.delta"]
    """The thought or reasoning generated by the model."""
# Union of the standard openai stream events and the thought delta event,
# yielded by stream processors that surface model reasoning alongside content.
ChatCompletionStreamWithThoughtEvent = Union[ChatCompletionStreamEvent, ThoughtDeltaEvent]
class ChoiceDeltaWithThoughts(ChoiceDelta):
    """
    Choice delta extended with an optional field for the model's
    thought/reasoning text.
    """
    thought: Optional[str] = None
    """The thought or reasoning generated by the model."""
class ChoiceWithThoughts(Choice):
    """Streaming choice whose delta may carry a thought field."""
    delta: ChoiceDeltaWithThoughts
class ChatCompletionWithThoughtsChunk(ChatCompletionChunk):
    """ChatCompletionChunk whose choices may include thought deltas."""
    choices: List[ChoiceWithThoughts]  # Override the choices type
def default_stream_processor(
    chat_stream: ChatCompletionStream,
) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]:
    """
    Generator to pass through events from the standard openai chat
    completions stream unchanged.

    This is the default (synchronous) processor for models that do not
    emit a separate thought/reasoning stream; it never yields a
    "thought.delta" event itself.
    """
    # No thought extraction needed: delegate iteration to the stream directly.
    yield from chat_stream
async def adefault_stream_processor(
    chat_stream: openai.AsyncStream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]:
    """
    Async generator that re-casts each standard openai chat completion
    chunk into the thought-aware chunk type.

    No thought extraction happens here; this is the default processor for
    models whose responses carry no separate reasoning stream.
    """
    async for chunk in chat_stream:
        # Round-trip through model_dump/model_validate to upcast the plain
        # chunk into ChatCompletionWithThoughtsChunk (choices gain the
        # optional thought field, left as None here).
        yield ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump())