diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 915c082b..7c56b8ac 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -1,6 +1,6 @@ import logging from time import perf_counter -from typing import Dict, List +from typing import AsyncGenerator, Dict, List import anthropic from langchain_core.messages.chat import ChatMessage @@ -100,6 +100,11 @@ def anthropic_completion_with_backoff( model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage") ) + # Validate the response. If empty, raise an error to retry. + if is_none_or_empty(aggregated_response): + logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.") + raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.") + # Save conversation trace tracer["chat_model"] = model_name tracer["temperature"] = temperature @@ -112,8 +117,8 @@ def anthropic_completion_with_backoff( @retry( wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(2), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=False, ) async def anthropic_chat_completion_with_backoff( messages: list[ChatMessage], @@ -126,75 +131,77 @@ async def anthropic_chat_completion_with_backoff( deepthought=False, model_kwargs=None, tracer={}, -): - try: - client = anthropic_async_clients.get(api_key) - if not client: - client = get_anthropic_async_client(api_key, api_base_url) - anthropic_async_clients[api_key] = client +) -> AsyncGenerator[ResponseWithThought, None]: + client = anthropic_async_clients.get(api_key) + if not client: + client = get_anthropic_async_client(api_key, api_base_url) + anthropic_async_clients[api_key] = client - model_kwargs = model_kwargs or dict() - max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC - if deepthought and model_name.startswith("claude-3-7"): - model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC} - max_tokens += MAX_REASONING_TOKENS_ANTHROPIC - # Temperature control not supported when using extended thinking - temperature = 1.0 + model_kwargs = model_kwargs or dict() + max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC + if deepthought and model_name.startswith("claude-3-7"): + model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC} + max_tokens += MAX_REASONING_TOKENS_ANTHROPIC + # Temperature control not supported when using extended thinking + temperature = 1.0 - formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) + formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) - aggregated_response = "" - response_started = False - final_message = None - start_time = perf_counter() - async with client.messages.stream( - messages=formatted_messages, - model=model_name, # type: ignore - temperature=temperature, - system=system_prompt, - timeout=20, - max_tokens=max_tokens, - **model_kwargs, - ) as stream: - async for chunk in stream: - # Log the time taken to start response - if not response_started: - response_started = True - logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") - # Skip empty chunks - if chunk.type != "content_block_delta": - continue - # Handle streamed response chunk - response_chunk: ResponseWithThought = None - if chunk.delta.type == "text_delta": - response_chunk = ResponseWithThought(response=chunk.delta.text) - aggregated_response += chunk.delta.text - if chunk.delta.type == "thinking_delta": - response_chunk = ResponseWithThought(thought=chunk.delta.thinking) - # Handle streamed response chunk - if response_chunk: - yield response_chunk - final_message = await stream.get_final_message() + aggregated_response = "" + response_started = False + final_message = None + start_time = perf_counter() + async with client.messages.stream( + messages=formatted_messages, + model=model_name, # type: ignore + temperature=temperature, + system=system_prompt, + timeout=20, + max_tokens=max_tokens, + **model_kwargs, + ) as stream: + async for chunk in stream: + # Log the time taken to start response + if not response_started: + response_started = True + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Skip empty chunks + if chunk.type != "content_block_delta": + continue + # Handle streamed response chunk + response_chunk: ResponseWithThought = None + if chunk.delta.type == "text_delta": + response_chunk = ResponseWithThought(response=chunk.delta.text) + aggregated_response += chunk.delta.text + if chunk.delta.type == "thinking_delta": + response_chunk = ResponseWithThought(thought=chunk.delta.thinking) + # Handle streamed response chunk + if response_chunk: + yield response_chunk + final_message = await stream.get_final_message() - # Log the time taken to stream the entire response - logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") + # Calculate cost of chat + input_tokens = final_message.usage.input_tokens + output_tokens = final_message.usage.output_tokens + cache_read_tokens = final_message.usage.cache_read_input_tokens + cache_write_tokens = final_message.usage.cache_creation_input_tokens + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage") + ) - # Calculate cost of chat - input_tokens = final_message.usage.input_tokens - output_tokens = final_message.usage.output_tokens - cache_read_tokens = final_message.usage.cache_read_input_tokens - cache_write_tokens = final_message.usage.cache_creation_input_tokens - tracer["usage"] = get_chat_usage_metrics( - model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage") - ) + # Validate the response. If empty, raise an error to retry. + if is_none_or_empty(aggregated_response): + logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.") + raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.") - # Save conversation trace - tracer["chat_model"] = model_name - tracer["temperature"] = temperature - if is_promptrace_enabled(): - commit_conversation_trace(messages, aggregated_response, tracer) - except Exception as e: - logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True) + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") + + # Save conversation trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if is_promptrace_enabled(): + commit_conversation_trace(messages, aggregated_response, tracer) def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None): diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 22cd8d0f..ede54ba3 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -73,6 +73,9 @@ def _is_retryable_error(exception: BaseException) -> bool: # client errors if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError): return True + # validation errors + if isinstance(exception, ValueError): + return True return False @@ -84,8 +87,8 @@ def _is_retryable_error(exception: BaseException) -> bool: reraise=True, ) def gemini_completion_with_backoff( - messages, - system_prompt, + messages: list[ChatMessage], + system_prompt: str, model_name: str, temperature=1.0, api_key=None, @@ -144,6 +147,11 @@ def gemini_completion_with_backoff( model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") ) + # Validate the response. If empty, raise an error to retry. + if is_none_or_empty(response_text): + logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.") + raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.") + # Save conversation trace tracer["chat_model"] = model_name tracer["temperature"] = temperature @@ -157,89 +165,90 @@ def gemini_completion_with_backoff( retry=retry_if_exception(_is_retryable_error), wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=False, ) async def gemini_chat_completion_with_backoff( - messages, - model_name, - temperature, - api_key, - api_base_url, - system_prompt, + messages: list[ChatMessage], + model_name: str, + temperature: float, + api_key: str, + api_base_url: str, + system_prompt: str, model_kwargs=None, deepthought=False, tracer: dict = {}, ) -> AsyncGenerator[str, None]: - try: - client = gemini_clients.get(api_key) - if not client: - client = get_gemini_client(api_key, api_base_url) - gemini_clients[api_key] = client + client = gemini_clients.get(api_key) + if not client: + client = get_gemini_client(api_key, api_base_url) + gemini_clients[api_key] = client - formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt) + formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt) - thinking_config = None - if deepthought and model_name.startswith("gemini-2-5"): - thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI) + thinking_config = None + if deepthought and model_name.startswith("gemini-2-5"): + thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI) - seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None - config = gtypes.GenerateContentConfig( - system_instruction=system_prompt, - temperature=temperature, - thinking_config=thinking_config, - max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI, - stop_sequences=["Notes:\n["], - safety_settings=SAFETY_SETTINGS, - seed=seed, - http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}), - ) + seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None + config = gtypes.GenerateContentConfig( + system_instruction=system_prompt, + temperature=temperature, + thinking_config=thinking_config, + max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI, + stop_sequences=["Notes:\n["], + safety_settings=SAFETY_SETTINGS, + seed=seed, + http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}), + ) - aggregated_response = "" - final_chunk = None - response_started = False - start_time = perf_counter() - chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream( - model=model_name, config=config, contents=formatted_messages - ) - async for chunk in chat_stream: - # Log the time taken to start response - if not response_started: - response_started = True - logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") - # Keep track of the last chunk for usage data - final_chunk = chunk - # Handle streamed response chunk - message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback) - message = message or chunk.text - aggregated_response += message - yield message - if stopped: - raise ValueError(message) + aggregated_response = "" + final_chunk = None + response_started = False + start_time = perf_counter() + chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream( + model=model_name, config=config, contents=formatted_messages + ) + async for chunk in chat_stream: + # Log the time taken to start response + if not response_started: + response_started = True + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Keep track of the last chunk for usage data + final_chunk = chunk + # Handle streamed response chunk + stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback) + message = stop_message or chunk.text + aggregated_response += message + yield message + if stopped: + logger.warning( + f"LLM Response Prevented for {model_name}: {stop_message}.\n" + + f"Last Message by {messages[-1].role}: {messages[-1].content}" + ) + break - # Log the time taken to stream the entire response - logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") + # Calculate cost of chat + input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0 + output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0 + thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0 + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") + ) - # Calculate cost of chat - input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0 - output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0 - thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0 - tracer["usage"] = get_chat_usage_metrics( - model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") - ) + # Validate the response. If empty, raise an error to retry. + if is_none_or_empty(aggregated_response): + logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.") + raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.") - # Save conversation trace - tracer["chat_model"] = model_name - tracer["temperature"] = temperature - if is_promptrace_enabled(): - commit_conversation_trace(messages, aggregated_response, tracer) - except ValueError as e: - logger.warning( - f"LLM Response Prevented for {model_name}: {e.args[0]}.\n" - + f"Last Message by {messages[-1].role}: {messages[-1].content}" - ) - except Exception as e: - logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True) + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") + + # Save conversation trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if is_promptrace_enabled(): + commit_conversation_trace(messages, aggregated_response, tracer) def handle_gemini_response( diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index bc3741ef..fe0a2686 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -38,6 +38,7 @@ from khoj.utils.helpers import ( get_chat_usage_metrics, get_openai_async_client, get_openai_client, + is_none_or_empty, is_promptrace_enabled, ) @@ -54,6 +55,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {} | retry_if_exception_type(openai._exceptions.APIConnectionError) | retry_if_exception_type(openai._exceptions.RateLimitError) | retry_if_exception_type(openai._exceptions.APIStatusError) + | retry_if_exception_type(ValueError) ), wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(3), @@ -136,6 +138,11 @@ def completion_with_backoff( model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost ) + # Validate the response. If empty, raise an error to retry. + if is_none_or_empty(aggregated_response): + logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.") + raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.") + # Save conversation trace tracer["chat_model"] = model_name tracer["temperature"] = temperature @@ -152,14 +159,15 @@ def completion_with_backoff( | retry_if_exception_type(openai._exceptions.APIConnectionError) | retry_if_exception_type(openai._exceptions.RateLimitError) | retry_if_exception_type(openai._exceptions.APIStatusError) + | retry_if_exception_type(ValueError) ), wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=False, ) async def chat_completion_with_backoff( - messages, + messages: list[ChatMessage], model_name: str, temperature, openai_api_key=None, @@ -168,120 +176,122 @@ async def chat_completion_with_backoff( model_kwargs: dict = {}, tracer: dict = {}, ) -> AsyncGenerator[ResponseWithThought, None]: - try: - client_key = f"{openai_api_key}--{api_base_url}" - client = openai_async_clients.get(client_key) - if not client: - client = get_openai_async_client(openai_api_key, api_base_url) - openai_async_clients[client_key] = client + client_key = f"{openai_api_key}--{api_base_url}" + client = openai_async_clients.get(client_key) + if not client: + client = get_openai_async_client(openai_api_key, api_base_url) + openai_async_clients[client_key] = client - stream_processor = adefault_stream_processor - formatted_messages = format_message_for_api(messages, api_base_url) + stream_processor = adefault_stream_processor + formatted_messages = format_message_for_api(messages, api_base_url) - # Configure thinking for openai reasoning models - if is_openai_reasoning_model(model_name, api_base_url): - temperature = 1 - reasoning_effort = "medium" if deepthought else "low" - model_kwargs["reasoning_effort"] = reasoning_effort - model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models + # Configure thinking for openai reasoning models + if is_openai_reasoning_model(model_name, api_base_url): + temperature = 1 + reasoning_effort = "medium" if deepthought else "low" + model_kwargs["reasoning_effort"] = reasoning_effort + model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models - # Get the first system message and add the string `Formatting re-enabled` to it. - # See https://platform.openai.com/docs/guides/reasoning-best-practices - if len(formatted_messages) > 0: - system_messages = [ - (i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system" - ] - if len(system_messages) > 0: - first_system_message_index, first_system_message = system_messages[0] - first_system_message_content = first_system_message["content"] - formatted_messages[first_system_message_index][ - "content" - ] = f"{first_system_message_content}\nFormatting re-enabled" - elif is_twitter_reasoning_model(model_name, api_base_url): - stream_processor = adeepseek_stream_processor - reasoning_effort = "high" if deepthought else "low" - model_kwargs["reasoning_effort"] = reasoning_effort - elif model_name.startswith("deepseek-reasoner"): - stream_processor = adeepseek_stream_processor - # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role. - # The first message should always be a user message (except system message). - updated_messages: List[dict] = [] - for i, message in enumerate(formatted_messages): - if i > 0 and message["role"] == formatted_messages[i - 1]["role"]: - updated_messages[-1]["content"] += " " + message["content"] - elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant": - updated_messages[-1]["content"] += " " + message["content"] - else: - updated_messages.append(message) - formatted_messages = updated_messages - elif is_qwen_reasoning_model(model_name, api_base_url): - stream_processor = partial(ain_stream_thought_processor, thought_tag="think") - # Reasoning is enabled by default. Disable when deepthought is False. - # See https://qwenlm.github.io/blog/qwen3/#advanced-usages - if not deepthought and len(formatted_messages) > 0: - formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think" + # Get the first system message and add the string `Formatting re-enabled` to it. + # See https://platform.openai.com/docs/guides/reasoning-best-practices + if len(formatted_messages) > 0: + system_messages = [ + (i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system" + ] + if len(system_messages) > 0: + first_system_message_index, first_system_message = system_messages[0] + first_system_message_content = first_system_message["content"] + formatted_messages[first_system_message_index][ + "content" + ] = f"{first_system_message_content}\nFormatting re-enabled" + elif is_twitter_reasoning_model(model_name, api_base_url): + stream_processor = adeepseek_stream_processor + reasoning_effort = "high" if deepthought else "low" + model_kwargs["reasoning_effort"] = reasoning_effort + elif model_name.startswith("deepseek-reasoner"): + stream_processor = adeepseek_stream_processor + # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role. + # The first message should always be a user message (except system message). + updated_messages: List[dict] = [] + for i, message in enumerate(formatted_messages): + if i > 0 and message["role"] == formatted_messages[i - 1]["role"]: + updated_messages[-1]["content"] += " " + message["content"] + elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant": + updated_messages[-1]["content"] += " " + message["content"] + else: + updated_messages.append(message) + formatted_messages = updated_messages + elif is_qwen_reasoning_model(model_name, api_base_url): + stream_processor = partial(ain_stream_thought_processor, thought_tag="think") + # Reasoning is enabled by default. Disable when deepthought is False. + # See https://qwenlm.github.io/blog/qwen3/#advanced-usages + if not deepthought and len(formatted_messages) > 0: + formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think" - stream = True - read_timeout = 300 if is_local_api(api_base_url) else 60 - model_kwargs["stream_options"] = {"include_usage": True} - if os.getenv("KHOJ_LLM_SEED"): - model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) + stream = True + read_timeout = 300 if is_local_api(api_base_url) else 60 + model_kwargs["stream_options"] = {"include_usage": True} + if os.getenv("KHOJ_LLM_SEED"): + model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) - aggregated_response = "" - final_chunk = None - response_started = False - start_time = perf_counter() - chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( - messages=formatted_messages, # type: ignore - model=model_name, - stream=stream, - temperature=temperature, - timeout=httpx.Timeout(30, read=read_timeout), - **model_kwargs, - ) - async for chunk in stream_processor(chat_stream): - # Log the time taken to start response - if not response_started: - response_started = True - logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") - # Keep track of the last chunk for usage data - final_chunk = chunk - # Skip empty chunks - if len(chunk.choices) == 0: - continue - # Handle streamed response chunk - response_chunk: ResponseWithThought = None - response_delta = chunk.choices[0].delta - if response_delta.content: - response_chunk = ResponseWithThought(response=response_delta.content) - aggregated_response += response_chunk.response - elif response_delta.thought: - response_chunk = ResponseWithThought(thought=response_delta.thought) - if response_chunk: - yield response_chunk + aggregated_response = "" + final_chunk = None + response_started = False + start_time = perf_counter() + chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( + messages=formatted_messages, # type: ignore + model=model_name, + stream=stream, + temperature=temperature, + timeout=httpx.Timeout(30, read=read_timeout), + **model_kwargs, + ) + async for chunk in stream_processor(chat_stream): + # Log the time taken to start response + if not response_started: + response_started = True + logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") + # Keep track of the last chunk for usage data + final_chunk = chunk + # Skip empty chunks + if len(chunk.choices) == 0: + continue + # Handle streamed response chunk + response_chunk: ResponseWithThought = None + response_delta = chunk.choices[0].delta + if response_delta.content: + response_chunk = ResponseWithThought(response=response_delta.content) + aggregated_response += response_chunk.response + elif response_delta.thought: + response_chunk = ResponseWithThought(thought=response_delta.thought) + if response_chunk: + yield response_chunk - # Log the time taken to stream the entire response - logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") + # Calculate cost of chat after stream finishes + input_tokens, output_tokens, cost = 0, 0, 0 + if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage: + input_tokens = final_chunk.usage.prompt_tokens + output_tokens = final_chunk.usage.completion_tokens + # Estimated costs returned by DeepInfra API + if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra: + cost = final_chunk.usage.model_extra.get("estimated_cost", 0) + tracer["usage"] = get_chat_usage_metrics( + model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost + ) - # Calculate cost of chat after stream finishes - input_tokens, output_tokens, cost = 0, 0, 0 - if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage: - input_tokens = final_chunk.usage.prompt_tokens - output_tokens = final_chunk.usage.completion_tokens - # Estimated costs returned by DeepInfra API - if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra: - cost = final_chunk.usage.model_extra.get("estimated_cost", 0) + # Validate the response. If empty, raise an error to retry. + if is_none_or_empty(aggregated_response): + logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.") + raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.") - # Save conversation trace - tracer["chat_model"] = model_name - tracer["temperature"] = temperature - tracer["usage"] = get_chat_usage_metrics( - model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost - ) - if is_promptrace_enabled(): - commit_conversation_trace(messages, aggregated_response, tracer) - except Exception as e: - logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True) + # Log the time taken to stream the entire response + logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") + + # Save conversation trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if is_promptrace_enabled(): + commit_conversation_trace(messages, aggregated_response, tracer) def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: