diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index fd3d0fd7..20e632d0 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -14,6 +14,7 @@ from openai.lib.streaming.chat import (
     ChatCompletionStreamEvent,
     ContentDeltaEvent,
 )
+from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.chat.chat_completion_chunk import (
     ChatCompletionChunk,
     Choice,
@@ -78,7 +79,11 @@ def completion_with_backoff(
         client = get_openai_client(openai_api_key, api_base_url)
         openai_clients[client_key] = client
 
+    stream = not is_non_streaming_model(model_name, api_base_url)
     stream_processor = default_stream_processor
+    if stream:
+        model_kwargs["stream_options"] = {"include_usage": True}
+
     formatted_messages = format_message_for_api(messages, api_base_url)
 
     # Tune reasoning models arguments
@@ -109,23 +114,33 @@ def completion_with_backoff(
             add_qwen_no_think_tag(formatted_messages)
 
     read_timeout = 300 if is_local_api(api_base_url) else 60
-    model_kwargs["stream_options"] = {"include_usage": True}
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
 
     aggregated_response = ""
-    with client.beta.chat.completions.stream(
-        messages=formatted_messages,  # type: ignore
-        model=model_name,
-        temperature=temperature,
-        timeout=httpx.Timeout(30, read=read_timeout),
-        **model_kwargs,
-    ) as chat:
-        for chunk in stream_processor(chat):
-            if chunk.type == "content.delta":
-                aggregated_response += chunk.delta
-            elif chunk.type == "thought.delta":
-                pass
+    if stream:
+        with client.beta.chat.completions.stream(
+            messages=formatted_messages,  # type: ignore
+            model=model_name,
+            temperature=temperature,
+            timeout=httpx.Timeout(30, read=read_timeout),
+            **model_kwargs,
+        ) as chat:
+            for chunk in stream_processor(chat):
+                if chunk.type == "content.delta":
+                    aggregated_response += chunk.delta
+                elif chunk.type == "thought.delta":
+                    pass
+    else:
+        # Non-streaming chat completion
+        chunk = client.beta.chat.completions.parse(
+            messages=formatted_messages,  # type: ignore
+            model=model_name,
+            temperature=temperature,
+            timeout=httpx.Timeout(30, read=read_timeout),
+            **model_kwargs,
+        )
+        aggregated_response = chunk.choices[0].message.content
 
     # Calculate cost of chat
     input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
@@ -182,7 +197,11 @@ async def chat_completion_with_backoff(
         client = get_openai_async_client(openai_api_key, api_base_url)
         openai_async_clients[client_key] = client
 
+    stream = not is_non_streaming_model(model_name, api_base_url)
     stream_processor = adefault_stream_processor
+    if stream:
+        model_kwargs["stream_options"] = {"include_usage": True}
+
     formatted_messages = format_message_for_api(messages, api_base_url)
 
     # Configure thinking for openai reasoning models
@@ -228,9 +247,7 @@ async def chat_completion_with_backoff(
         if not deepthought:
             add_qwen_no_think_tag(formatted_messages)
 
-    stream = True
     read_timeout = 300 if is_local_api(api_base_url) else 60
-    model_kwargs["stream_options"] = {"include_usage": True}
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
 
@@ -238,7 +255,7 @@
     final_chunk = None
     response_started = False
     start_time = perf_counter()
-    chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
+    response: openai.AsyncStream[ChatCompletionChunk] | ChatCompletion = await client.chat.completions.create(
         messages=formatted_messages,  # type: ignore
         model=model_name,
         stream=stream,
@@ -246,26 +263,34 @@
         timeout=httpx.Timeout(30, read=read_timeout),
         **model_kwargs,
     )
-    async for chunk in stream_processor(chat_stream):
-        # Log the time taken to start response
-        if not response_started:
-            response_started = True
-            logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
-        # Keep track of the last chunk for usage data
-        final_chunk = chunk
-        # Skip empty chunks
-        if len(chunk.choices) == 0:
-            continue
-        # Handle streamed response chunk
-        response_chunk: ResponseWithThought = None
-        response_delta = chunk.choices[0].delta
-        if response_delta.content:
-            response_chunk = ResponseWithThought(response=response_delta.content)
-            aggregated_response += response_chunk.response
-        elif response_delta.thought:
-            response_chunk = ResponseWithThought(thought=response_delta.thought)
-        if response_chunk:
-            yield response_chunk
+    if not stream:
+        # If not streaming, we can return the response directly
+        if len(response.choices) == 0 or not response.choices[0].message:
+            raise ValueError("No response by model.")
+        aggregated_response = response.choices[0].message.content
+        final_chunk = response
+        yield ResponseWithThought(response=aggregated_response)
+    else:
+        async for chunk in stream_processor(response):
+            # Log the time taken to start response
+            if not response_started:
+                response_started = True
+                logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+            # Keep track of the last chunk for usage data
+            final_chunk = chunk
+            # Skip empty chunks
+            if len(chunk.choices) == 0:
+                continue
+            # Handle streamed response chunk
+            response_chunk: ResponseWithThought = None
+            response_delta = chunk.choices[0].delta
+            if response_delta.content:
+                response_chunk = ResponseWithThought(response=response_delta.content)
+                aggregated_response += response_chunk.response
+            elif response_delta.thought:
+                response_chunk = ResponseWithThought(thought=response_delta.thought)
+            if response_chunk:
+                yield response_chunk
 
     # Calculate cost of chat after stream finishes
     input_tokens, output_tokens, cost = 0, 0, 0
@@ -354,6 +379,14 @@ def is_openai_reasoning_model(model_name: str, api_base_url: str = None) -> bool
     return model_name.startswith("o") and is_openai_api(api_base_url)
 
 
+def is_non_streaming_model(model_name: str, api_base_url: str = None) -> bool:
+    """
+    Check if model response should not be streamed.
+    """
+    # Some OpenAI models require biometrics to stream. Avoid streaming their responses.
+    return model_name in ["o3", "o3-pro"] and is_openai_api(api_base_url)
+
+
 def is_twitter_reasoning_model(model_name: str, api_base_url: str = None) -> bool:
     """
     Check if the model is a Twitter reasoning model
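
Reviewer note: for anyone skimming the patch, the sketch below is an illustrative, standalone rendering of the streaming/non-streaming split it introduces; it is not code from this repository. The `is_non_streaming_model` gate and the `o3`/`o3-pro` model list mirror the helper added in the diff, while the `OpenAI()` client setup, the `ask` function name, and the example messages are assumptions made for the sketch.

```python
# Illustrative sketch of the patch's control flow, assuming the openai v1 SDK
# and an OPENAI_API_KEY in the environment. Not the Khoj implementation.
from openai import OpenAI

client = OpenAI()


def is_non_streaming_model(model_name: str) -> bool:
    # Mirrors the helper added in this patch: o3 / o3-pro responses are not streamed.
    return model_name in ["o3", "o3-pro"]


def ask(model_name: str, messages: list[dict]) -> str:
    stream = not is_non_streaming_model(model_name)
    if stream:
        # Streaming path: request usage stats on the final chunk via stream_options,
        # as the patch does, and concatenate the content deltas.
        chunks = client.chat.completions.create(
            model=model_name,
            messages=messages,
            stream=True,
            stream_options={"include_usage": True},
        )
        # The final usage-only chunk has no choices, so guard before indexing.
        return "".join(c.choices[0].delta.content or "" for c in chunks if c.choices)
    # Non-streaming path: a single ChatCompletion carrying the full message body.
    completion = client.chat.completions.create(model=model_name, messages=messages, stream=False)
    return completion.choices[0].message.content or ""


if __name__ == "__main__":
    print(ask("o3", [{"role": "user", "content": "Hello"}]))
```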