diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 7fab44aa..019c785a 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -1,12 +1,21 @@ import logging import os +from functools import partial from time import perf_counter -from typing import AsyncGenerator, Dict, List +from typing import AsyncGenerator, Dict, Generator, List, Literal, Optional, Union from urllib.parse import urlparse import openai -from openai.types.chat.chat_completion import ChatCompletion -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.lib.streaming.chat import ( + ChatCompletionStream, + ChatCompletionStreamEvent, + ContentDeltaEvent, +) +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, +) from tenacity import ( before_sleep_log, retry, @@ -59,6 +68,7 @@ def completion_with_backoff( client = get_openai_client(openai_api_key, api_base_url) openai_clients[client_key] = client + stream_processor = default_stream_processor formatted_messages = [{"role": message.role, "content": message.content} for message in messages] # Tune reasoning models arguments @@ -82,12 +92,14 @@ def completion_with_backoff( timeout=20, **model_kwargs, ) as chat: - for chunk in chat: + for chunk in stream_processor(chat): if chunk.type == "error": logger.error(f"Openai api response error: {chunk.error}", exc_info=True) continue elif chunk.type == "content.delta": aggregated_response += chunk.delta + elif chunk.type == "thought.delta": + pass # Calculate cost of chat input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0 @@ -124,7 +136,7 @@ def completion_with_backoff( ) async def chat_completion_with_backoff( messages, - model_name, + model_name: str, temperature, openai_api_key=None, api_base_url=None, @@ -139,6 +151,7 @@ async def chat_completion_with_backoff( 
client = get_openai_async_client(openai_api_key, api_base_url) openai_async_clients[client_key] = client + stream_processor = adefault_stream_processor formatted_messages = [{"role": message.role, "content": message.content} for message in messages] # Configure thinking for openai reasoning models @@ -193,7 +206,7 @@ async def chat_completion_with_backoff( timeout=20, **model_kwargs, ) - async for chunk in chat_stream: + async for chunk in stream_processor(chat_stream): # Log the time taken to start response if final_chunk is None: logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") @@ -264,3 +277,52 @@ def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> boo and api_base_url is not None and api_base_url.startswith("https://api.x.ai/v1") ) + + +class ThoughtDeltaEvent(ContentDeltaEvent): + """ + Chat completion chunk with thoughts, reasoning support. + """ + + type: Literal["thought.delta"] + """The thought or reasoning generated by the model.""" + + +ChatCompletionStreamWithThoughtEvent = Union[ChatCompletionStreamEvent, ThoughtDeltaEvent] + + +class ChoiceDeltaWithThoughts(ChoiceDelta): + """ + Chat completion chunk with thoughts, reasoning support. + """ + + thought: Optional[str] = None + """The thought or reasoning generated by the model.""" + + +class ChoiceWithThoughts(Choice): + delta: ChoiceDeltaWithThoughts + + +class ChatCompletionWithThoughtsChunk(ChatCompletionChunk): + choices: List[ChoiceWithThoughts] # Override the choices type + + +def default_stream_processor( + chat_stream: ChatCompletionStream, +) -> Generator[ChatCompletionStreamWithThoughtEvent, None, None]: + """ + Generator to pass through chunks from the standard openai chat completions stream. 
+ """ + for chunk in chat_stream: + yield chunk + + +async def adefault_stream_processor( + chat_stream: openai.AsyncStream[ChatCompletionChunk], +) -> AsyncGenerator[ChatCompletionWithThoughtsChunk, None]: + """ + Async generator to cast and return chunks from the standard openai chat completions stream. + """ + async for chunk in chat_stream: + yield ChatCompletionWithThoughtsChunk.model_validate(chunk.model_dump())