Retry chat completion on empty response or error from LLM over API

Previously, all exceptions were being caught, so the retry logic was
never getting triggered.

Exception catching had been added to close the LLM thread back when
threads, rather than async, were used for final response generation.

This isn't required anymore since the move to async, so we can now
re-enable retrying on failures.

Raise an error when the response is empty so the LLM completion is retried.
This commit is contained in:
Debanjum
2025-05-18 14:32:31 -07:00
parent 7827d317b4
commit cf55582852
3 changed files with 273 additions and 247 deletions

View File

@@ -1,6 +1,6 @@
import logging
from time import perf_counter
from typing import Dict, List
from typing import AsyncGenerator, Dict, List
import anthropic
from langchain_core.messages.chat import ChatMessage
@@ -100,6 +100,11 @@ def anthropic_completion_with_backoff(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
)
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
@@ -112,8 +117,8 @@ def anthropic_completion_with_backoff(
@retry(
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=False,
)
async def anthropic_chat_completion_with_backoff(
messages: list[ChatMessage],
@@ -126,75 +131,77 @@ async def anthropic_chat_completion_with_backoff(
deepthought=False,
model_kwargs=None,
tracer={},
):
try:
client = anthropic_async_clients.get(api_key)
if not client:
client = get_anthropic_async_client(api_key, api_base_url)
anthropic_async_clients[api_key] = client
) -> AsyncGenerator[ResponseWithThought, None]:
client = anthropic_async_clients.get(api_key)
if not client:
client = get_anthropic_async_client(api_key, api_base_url)
anthropic_async_clients[api_key] = client
model_kwargs = model_kwargs or dict()
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
if deepthought and model_name.startswith("claude-3-7"):
model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC}
max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
# Temperature control not supported when using extended thinking
temperature = 1.0
model_kwargs = model_kwargs or dict()
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
if deepthought and model_name.startswith("claude-3-7"):
model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC}
max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
# Temperature control not supported when using extended thinking
temperature = 1.0
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
aggregated_response = ""
response_started = False
final_message = None
start_time = perf_counter()
async with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
temperature=temperature,
system=system_prompt,
timeout=20,
max_tokens=max_tokens,
**model_kwargs,
) as stream:
async for chunk in stream:
# Log the time taken to start response
if not response_started:
response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Skip empty chunks
if chunk.type != "content_block_delta":
continue
# Handle streamed response chunk
response_chunk: ResponseWithThought = None
if chunk.delta.type == "text_delta":
response_chunk = ResponseWithThought(response=chunk.delta.text)
aggregated_response += chunk.delta.text
if chunk.delta.type == "thinking_delta":
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
# Handle streamed response chunk
if response_chunk:
yield response_chunk
final_message = await stream.get_final_message()
aggregated_response = ""
response_started = False
final_message = None
start_time = perf_counter()
async with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
temperature=temperature,
system=system_prompt,
timeout=20,
max_tokens=max_tokens,
**model_kwargs,
) as stream:
async for chunk in stream:
# Log the time taken to start response
if not response_started:
response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Skip empty chunks
if chunk.type != "content_block_delta":
continue
# Handle streamed response chunk
response_chunk: ResponseWithThought = None
if chunk.delta.type == "text_delta":
response_chunk = ResponseWithThought(response=chunk.delta.text)
aggregated_response += chunk.delta.text
if chunk.delta.type == "thinking_delta":
response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
# Handle streamed response chunk
if response_chunk:
yield response_chunk
final_message = await stream.get_final_message()
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Calculate cost of chat
input_tokens = final_message.usage.input_tokens
output_tokens = final_message.usage.output_tokens
cache_read_tokens = final_message.usage.cache_read_input_tokens
cache_write_tokens = final_message.usage.cache_creation_input_tokens
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
)
# Calculate cost of chat
input_tokens = final_message.usage.input_tokens
output_tokens = final_message.usage.output_tokens
cache_read_tokens = final_message.usage.cache_read_input_tokens
cache_write_tokens = final_message.usage.cache_creation_input_tokens
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
)
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):

View File

@@ -73,6 +73,9 @@ def _is_retryable_error(exception: BaseException) -> bool:
# client errors
if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError):
return True
# validation errors
if isinstance(exception, ValueError):
return True
return False
@@ -84,8 +87,8 @@ def _is_retryable_error(exception: BaseException) -> bool:
reraise=True,
)
def gemini_completion_with_backoff(
messages,
system_prompt,
messages: list[ChatMessage],
system_prompt: str,
model_name: str,
temperature=1.0,
api_key=None,
@@ -144,6 +147,11 @@ def gemini_completion_with_backoff(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
)
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(response_text):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
@@ -157,89 +165,90 @@ def gemini_completion_with_backoff(
retry=retry_if_exception(_is_retryable_error),
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=False,
)
async def gemini_chat_completion_with_backoff(
messages,
model_name,
temperature,
api_key,
api_base_url,
system_prompt,
messages: list[ChatMessage],
model_name: str,
temperature: float,
api_key: str,
api_base_url: str,
system_prompt: str,
model_kwargs=None,
deepthought=False,
tracer: dict = {},
) -> AsyncGenerator[str, None]:
try:
client = gemini_clients.get(api_key)
if not client:
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
client = gemini_clients.get(api_key)
if not client:
client = get_gemini_client(api_key, api_base_url)
gemini_clients[api_key] = client
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
thinking_config = None
if deepthought and model_name.startswith("gemini-2-5"):
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
thinking_config = None
if deepthought and model_name.startswith("gemini-2-5"):
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
temperature=temperature,
thinking_config=thinking_config,
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
stop_sequences=["Notes:\n["],
safety_settings=SAFETY_SETTINGS,
seed=seed,
http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig(
system_instruction=system_prompt,
temperature=temperature,
thinking_config=thinking_config,
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
stop_sequences=["Notes:\n["],
safety_settings=SAFETY_SETTINGS,
seed=seed,
http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
)
aggregated_response = ""
final_chunk = None
response_started = False
start_time = perf_counter()
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
model=model_name, config=config, contents=formatted_messages
)
async for chunk in chat_stream:
# Log the time taken to start response
if not response_started:
response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data
final_chunk = chunk
# Handle streamed response chunk
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text
aggregated_response += message
yield message
if stopped:
raise ValueError(message)
aggregated_response = ""
final_chunk = None
response_started = False
start_time = perf_counter()
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
model=model_name, config=config, contents=formatted_messages
)
async for chunk in chat_stream:
# Log the time taken to start response
if not response_started:
response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data
final_chunk = chunk
# Handle streamed response chunk
stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = stop_message or chunk.text
aggregated_response += message
yield message
if stopped:
logger.warning(
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
break
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Calculate cost of chat
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
)
# Calculate cost of chat
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
)
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
except ValueError as e:
logger.warning(
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
except Exception as e:
logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
def handle_gemini_response(

View File

@@ -38,6 +38,7 @@ from khoj.utils.helpers import (
get_chat_usage_metrics,
get_openai_async_client,
get_openai_client,
is_none_or_empty,
is_promptrace_enabled,
)
@@ -54,6 +55,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
| retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai._exceptions.APIStatusError)
| retry_if_exception_type(ValueError)
),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
@@ -136,6 +138,11 @@ def completion_with_backoff(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
@@ -152,14 +159,15 @@ def completion_with_backoff(
| retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai._exceptions.APIStatusError)
| retry_if_exception_type(ValueError)
),
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=False,
)
async def chat_completion_with_backoff(
messages,
messages: list[ChatMessage],
model_name: str,
temperature,
openai_api_key=None,
@@ -168,120 +176,122 @@ async def chat_completion_with_backoff(
model_kwargs: dict = {},
tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]:
try:
client_key = f"{openai_api_key}--{api_base_url}"
client = openai_async_clients.get(client_key)
if not client:
client = get_openai_async_client(openai_api_key, api_base_url)
openai_async_clients[client_key] = client
client_key = f"{openai_api_key}--{api_base_url}"
client = openai_async_clients.get(client_key)
if not client:
client = get_openai_async_client(openai_api_key, api_base_url)
openai_async_clients[client_key] = client
stream_processor = adefault_stream_processor
formatted_messages = format_message_for_api(messages, api_base_url)
stream_processor = adefault_stream_processor
formatted_messages = format_message_for_api(messages, api_base_url)
# Configure thinking for openai reasoning models
if is_openai_reasoning_model(model_name, api_base_url):
temperature = 1
reasoning_effort = "medium" if deepthought else "low"
model_kwargs["reasoning_effort"] = reasoning_effort
model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models
# Configure thinking for openai reasoning models
if is_openai_reasoning_model(model_name, api_base_url):
temperature = 1
reasoning_effort = "medium" if deepthought else "low"
model_kwargs["reasoning_effort"] = reasoning_effort
model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models
# Get the first system message and add the string `Formatting re-enabled` to it.
# See https://platform.openai.com/docs/guides/reasoning-best-practices
if len(formatted_messages) > 0:
system_messages = [
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
]
if len(system_messages) > 0:
first_system_message_index, first_system_message = system_messages[0]
first_system_message_content = first_system_message["content"]
formatted_messages[first_system_message_index][
"content"
] = f"{first_system_message_content}\nFormatting re-enabled"
elif is_twitter_reasoning_model(model_name, api_base_url):
stream_processor = adeepseek_stream_processor
reasoning_effort = "high" if deepthought else "low"
model_kwargs["reasoning_effort"] = reasoning_effort
elif model_name.startswith("deepseek-reasoner"):
stream_processor = adeepseek_stream_processor
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
# The first message should always be a user message (except system message).
updated_messages: List[dict] = []
for i, message in enumerate(formatted_messages):
if i > 0 and message["role"] == formatted_messages[i - 1]["role"]:
updated_messages[-1]["content"] += " " + message["content"]
elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant":
updated_messages[-1]["content"] += " " + message["content"]
else:
updated_messages.append(message)
formatted_messages = updated_messages
elif is_qwen_reasoning_model(model_name, api_base_url):
stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
# Reasoning is enabled by default. Disable when deepthought is False.
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages
if not deepthought and len(formatted_messages) > 0:
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
# Get the first system message and add the string `Formatting re-enabled` to it.
# See https://platform.openai.com/docs/guides/reasoning-best-practices
if len(formatted_messages) > 0:
system_messages = [
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
]
if len(system_messages) > 0:
first_system_message_index, first_system_message = system_messages[0]
first_system_message_content = first_system_message["content"]
formatted_messages[first_system_message_index][
"content"
] = f"{first_system_message_content}\nFormatting re-enabled"
elif is_twitter_reasoning_model(model_name, api_base_url):
stream_processor = adeepseek_stream_processor
reasoning_effort = "high" if deepthought else "low"
model_kwargs["reasoning_effort"] = reasoning_effort
elif model_name.startswith("deepseek-reasoner"):
stream_processor = adeepseek_stream_processor
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
# The first message should always be a user message (except system message).
updated_messages: List[dict] = []
for i, message in enumerate(formatted_messages):
if i > 0 and message["role"] == formatted_messages[i - 1]["role"]:
updated_messages[-1]["content"] += " " + message["content"]
elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant":
updated_messages[-1]["content"] += " " + message["content"]
else:
updated_messages.append(message)
formatted_messages = updated_messages
elif is_qwen_reasoning_model(model_name, api_base_url):
stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
# Reasoning is enabled by default. Disable when deepthought is False.
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages
if not deepthought and len(formatted_messages) > 0:
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
stream = True
read_timeout = 300 if is_local_api(api_base_url) else 60
model_kwargs["stream_options"] = {"include_usage": True}
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
stream = True
read_timeout = 300 if is_local_api(api_base_url) else 60
model_kwargs["stream_options"] = {"include_usage": True}
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
aggregated_response = ""
final_chunk = None
response_started = False
start_time = perf_counter()
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
messages=formatted_messages, # type: ignore
model=model_name,
stream=stream,
temperature=temperature,
timeout=httpx.Timeout(30, read=read_timeout),
**model_kwargs,
)
async for chunk in stream_processor(chat_stream):
# Log the time taken to start response
if not response_started:
response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data
final_chunk = chunk
# Skip empty chunks
if len(chunk.choices) == 0:
continue
# Handle streamed response chunk
response_chunk: ResponseWithThought = None
response_delta = chunk.choices[0].delta
if response_delta.content:
response_chunk = ResponseWithThought(response=response_delta.content)
aggregated_response += response_chunk.response
elif response_delta.thought:
response_chunk = ResponseWithThought(thought=response_delta.thought)
if response_chunk:
yield response_chunk
aggregated_response = ""
final_chunk = None
response_started = False
start_time = perf_counter()
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
messages=formatted_messages, # type: ignore
model=model_name,
stream=stream,
temperature=temperature,
timeout=httpx.Timeout(30, read=read_timeout),
**model_kwargs,
)
async for chunk in stream_processor(chat_stream):
# Log the time taken to start response
if not response_started:
response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data
final_chunk = chunk
# Skip empty chunks
if len(chunk.choices) == 0:
continue
# Handle streamed response chunk
response_chunk: ResponseWithThought = None
response_delta = chunk.choices[0].delta
if response_delta.content:
response_chunk = ResponseWithThought(response=response_delta.content)
aggregated_response += response_chunk.response
elif response_delta.thought:
response_chunk = ResponseWithThought(thought=response_delta.thought)
if response_chunk:
yield response_chunk
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Calculate cost of chat after stream finishes
input_tokens, output_tokens, cost = 0, 0, 0
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
input_tokens = final_chunk.usage.prompt_tokens
output_tokens = final_chunk.usage.completion_tokens
# Estimated costs returned by DeepInfra API
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
# Calculate cost of chat after stream finishes
input_tokens, output_tokens, cost = 0, 0, 0
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
input_tokens = final_chunk.usage.prompt_tokens
output_tokens = final_chunk.usage.completion_tokens
# Estimated costs returned by DeepInfra API
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: