From 7b9f2c21c76ff2ad977bd0146e1884ca87704850 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 2 May 2025 10:29:36 -0600 Subject: [PATCH] Parse thoughts from thinking models served via OpenAI compatible API The OpenAI API doesn't support thoughts via chat completion by default, but there are thinking models served via OpenAI-compatible APIs, like deepseek and qwen3. Add stream handlers and modified response types that can contain thoughts in addition to the content returned by a model. This can be used to instantiate stream handlers for different model types (deepseek, qwen3, etc.) served over an OpenAI-compatible API. --- .../processor/conversation/openai/utils.py | 74 +++++++++++++++++-- 1 file changed, 68 insertions(+), 6 deletions(-) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 7fab44aa..019c785a 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -1,12 +1,21 @@ import logging import os +from functools import partial from time import perf_counter -from typing import AsyncGenerator, Dict, List +from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union from urllib.parse import urlparse import openai -from openai.types.chat.chat_completion import ChatCompletion -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.lib.streaming.chat import ( + ChatCompletionStream, + ChatCompletionStreamEvent, + ContentDeltaEvent, +) +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, +) from tenacity import ( before_sleep_log, retry, @@ -59,6 +68,7 @@ def completion_with_backoff( client = get_openai_client(openai_api_key, api_base_url) openai_clients[client_key] = client + stream_processor = default_stream_processor formatted_messages = [{"role": message.role, "content": message.content} for message in messages] # Tune reasoning models 
arguments @@ -82,12 +92,14 @@ def completion_with_backoff( timeout=20, **model_kwargs, ) as chat: - for chunk in chat: + for chunk in stream_processor(chat): if chunk.type == "error": logger.error(f"Openai api response error: {chunk.error}", exc_info=True) continue elif chunk.type == "content.delta": aggregated_response += chunk.delta + elif chunk.type == "thought.delta": + pass # Calculate cost of chat input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 @@ -124,7 +136,7 @@ def completion_with_backoff( ) async def chat_completion_with_backoff( messages, - model_name, + model_name: str, temperature, openai_api_key=None, api_base_url=None, @@ -139,6 +151,7 @@ async def chat_completion_with_backoff( client = get_openai_async_client(openai_api_key, api_base_url) openai_async_clients[client_key] = client + stream_processor = adefault_stream_processor formatted_messages = [{"role": message.role, "content": message.content} for message in messages] # Configure thinking for openai reasoning models @@ -193,7 +206,7 @@ async def chat_completion_with_backoff( timeout=20, **model_kwargs, ) - async for chunk in chat_stream: + async for chunk in stream_processor(chat_stream): # Log the time taken to start response if final_chunk is None: logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") @@ -264,3 +277,52 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo and api_base_url is not None and api_base_url.startswith("https://api.x.ai/v1") ) + + +class ThoughtDeltaEvent(ContentDeltaEvent): + """ + Chat completion chunk with thoughts, reasoning support. + """ + + type: Literal["thought.delta"] + """The thought or reasoning generated by the model.""" + + +ChatCompletionStreamWithThoughtEvent = Union[ChatCompletionStreamEvent, ThoughtDeltaEvent] + + +class ChoiceDeltaWithThoughts(ChoiceDelta): + """ + Chat completion chunk with thoughts, reasoning support. 
+ """ + + thought: Optional[str] = None + """The thought or reasoning generated by the model.""" + + +class ChoiceWithThoughts(Choice): + delta: ChoiceDeltaWithThoughts + + +class ChatCompletionWithThoughtsChunk(ChatCompletionChunk): + choices: List[ChoiceWithThoughts] # Override the choices type + + +def default_stream_processor( + chat_stream: ChatCompletionStream, +) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]: + """ + Generator to return chunks, unmodified, from the standard openai chat completions stream. + """ + for chunk in chat_stream: + yield chunk + + +async def adefault_stream_processor( + chat_stream: openai.AsyncStream[ChatCompletionChunk], +) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]: + """ + Async generator to cast and return chunks from the standard openai chat completions stream. + """ + async for chunk in chat_stream: + yield ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump())