Retry on empty response or error in chat completion by LLM over API

Previously, all exceptions were being caught, so the retry logic
wasn't getting triggered.

Exception catching had been added to close the LLM thread back when
threads, rather than async, were used for final response generation.

This isn't required anymore since moving to async, so we can now
re-enable retries on failures.

Raise an error if the response is empty, so the LLM completion is retried.
This commit is contained in:
Debanjum
2025-05-18 14:32:31 -07:00
parent 7827d317b4
commit cf55582852
3 changed files with 273 additions and 247 deletions

View File

@@ -1,6 +1,6 @@
import logging import logging
from time import perf_counter from time import perf_counter
from typing import Dict, List from typing import AsyncGenerator, Dict, List
import anthropic import anthropic
from langchain_core.messages.chat import ChatMessage from langchain_core.messages.chat import ChatMessage
@@ -100,6 +100,11 @@ def anthropic_completion_with_backoff(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage") model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
) )
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
@@ -112,8 +117,8 @@ def anthropic_completion_with_backoff(
@retry( @retry(
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(2), stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True, reraise=False,
) )
async def anthropic_chat_completion_with_backoff( async def anthropic_chat_completion_with_backoff(
messages: list[ChatMessage], messages: list[ChatMessage],
@@ -126,75 +131,77 @@ async def anthropic_chat_completion_with_backoff(
deepthought=False, deepthought=False,
model_kwargs=None, model_kwargs=None,
tracer={}, tracer={},
): ) -> AsyncGenerator[ResponseWithThought, None]:
try: client = anthropic_async_clients.get(api_key)
client = anthropic_async_clients.get(api_key) if not client:
if not client: client = get_anthropic_async_client(api_key, api_base_url)
client = get_anthropic_async_client(api_key, api_base_url) anthropic_async_clients[api_key] = client
anthropic_async_clients[api_key] = client
model_kwargs = model_kwargs or dict() model_kwargs = model_kwargs or dict()
max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
if deepthought and model_name.startswith("claude-3-7"): if deepthought and model_name.startswith("claude-3-7"):
model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC} model_kwargs["thinking"] = {"type": "enabled", "budget_tokens": MAX_REASONING_TOKENS_ANTHROPIC}
max_tokens += MAX_REASONING_TOKENS_ANTHROPIC max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
# Temperature control not supported when using extended thinking # Temperature control not supported when using extended thinking
temperature = 1.0 temperature = 1.0
formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) formatted_messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
aggregated_response = "" aggregated_response = ""
response_started = False response_started = False
final_message = None final_message = None
start_time = perf_counter() start_time = perf_counter()
async with client.messages.stream( async with client.messages.stream(
messages=formatted_messages, messages=formatted_messages,
model=model_name, # type: ignore model=model_name, # type: ignore
temperature=temperature, temperature=temperature,
system=system_prompt, system=system_prompt,
timeout=20, timeout=20,
max_tokens=max_tokens, max_tokens=max_tokens,
**model_kwargs, **model_kwargs,
) as stream: ) as stream:
async for chunk in stream: async for chunk in stream:
# Log the time taken to start response # Log the time taken to start response
if not response_started: if not response_started:
response_started = True response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Skip empty chunks # Skip empty chunks
if chunk.type != "content_block_delta": if chunk.type != "content_block_delta":
continue continue
# Handle streamed response chunk # Handle streamed response chunk
response_chunk: ResponseWithThought = None response_chunk: ResponseWithThought = None
if chunk.delta.type == "text_delta": if chunk.delta.type == "text_delta":
response_chunk = ResponseWithThought(response=chunk.delta.text) response_chunk = ResponseWithThought(response=chunk.delta.text)
aggregated_response += chunk.delta.text aggregated_response += chunk.delta.text
if chunk.delta.type == "thinking_delta": if chunk.delta.type == "thinking_delta":
response_chunk = ResponseWithThought(thought=chunk.delta.thinking) response_chunk = ResponseWithThought(thought=chunk.delta.thinking)
# Handle streamed response chunk # Handle streamed response chunk
if response_chunk: if response_chunk:
yield response_chunk yield response_chunk
final_message = await stream.get_final_message() final_message = await stream.get_final_message()
# Log the time taken to stream the entire response # Calculate cost of chat
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") input_tokens = final_message.usage.input_tokens
output_tokens = final_message.usage.output_tokens
cache_read_tokens = final_message.usage.cache_read_input_tokens
cache_write_tokens = final_message.usage.cache_creation_input_tokens
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
)
# Calculate cost of chat # Validate the response. If empty, raise an error to retry.
input_tokens = final_message.usage.input_tokens if is_none_or_empty(aggregated_response):
output_tokens = final_message.usage.output_tokens logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
cache_read_tokens = final_message.usage.cache_read_input_tokens raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
cache_write_tokens = final_message.usage.cache_creation_input_tokens
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
)
# Save conversation trace # Log the time taken to stream the entire response
tracer["chat_model"] = model_name logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
tracer["temperature"] = temperature
if is_promptrace_enabled(): # Save conversation trace
commit_conversation_trace(messages, aggregated_response, tracer) tracer["chat_model"] = model_name
except Exception as e: tracer["temperature"] = temperature
logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True) if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer)
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None): def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):

View File

@@ -73,6 +73,9 @@ def _is_retryable_error(exception: BaseException) -> bool:
# client errors # client errors
if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError): if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError):
return True return True
# validation errors
if isinstance(exception, ValueError):
return True
return False return False
@@ -84,8 +87,8 @@ def _is_retryable_error(exception: BaseException) -> bool:
reraise=True, reraise=True,
) )
def gemini_completion_with_backoff( def gemini_completion_with_backoff(
messages, messages: list[ChatMessage],
system_prompt, system_prompt: str,
model_name: str, model_name: str,
temperature=1.0, temperature=1.0,
api_key=None, api_key=None,
@@ -144,6 +147,11 @@ def gemini_completion_with_backoff(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
) )
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(response_text):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
@@ -157,89 +165,90 @@ def gemini_completion_with_backoff(
retry=retry_if_exception(_is_retryable_error), retry=retry_if_exception(_is_retryable_error),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True, reraise=False,
) )
async def gemini_chat_completion_with_backoff( async def gemini_chat_completion_with_backoff(
messages, messages: list[ChatMessage],
model_name, model_name: str,
temperature, temperature: float,
api_key, api_key: str,
api_base_url, api_base_url: str,
system_prompt, system_prompt: str,
model_kwargs=None, model_kwargs=None,
deepthought=False, deepthought=False,
tracer: dict = {}, tracer: dict = {},
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
try: client = gemini_clients.get(api_key)
client = gemini_clients.get(api_key) if not client:
if not client: client = get_gemini_client(api_key, api_base_url)
client = get_gemini_client(api_key, api_base_url) gemini_clients[api_key] = client
gemini_clients[api_key] = client
formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt) formatted_messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
thinking_config = None thinking_config = None
if deepthought and model_name.startswith("gemini-2-5"): if deepthought and model_name.startswith("gemini-2-5"):
thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI) thinking_config = gtypes.ThinkingConfig(thinking_budget=MAX_REASONING_TOKENS_GEMINI)
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
config = gtypes.GenerateContentConfig( config = gtypes.GenerateContentConfig(
system_instruction=system_prompt, system_instruction=system_prompt,
temperature=temperature, temperature=temperature,
thinking_config=thinking_config, thinking_config=thinking_config,
max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI, max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
stop_sequences=["Notes:\n["], stop_sequences=["Notes:\n["],
safety_settings=SAFETY_SETTINGS, safety_settings=SAFETY_SETTINGS,
seed=seed, seed=seed,
http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}), http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
) )
aggregated_response = "" aggregated_response = ""
final_chunk = None final_chunk = None
response_started = False response_started = False
start_time = perf_counter() start_time = perf_counter()
chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream( chat_stream: AsyncIterator[gtypes.GenerateContentResponse] = await client.aio.models.generate_content_stream(
model=model_name, config=config, contents=formatted_messages model=model_name, config=config, contents=formatted_messages
) )
async for chunk in chat_stream: async for chunk in chat_stream:
# Log the time taken to start response # Log the time taken to start response
if not response_started: if not response_started:
response_started = True response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data # Keep track of the last chunk for usage data
final_chunk = chunk final_chunk = chunk
# Handle streamed response chunk # Handle streamed response chunk
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback) stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text message = stop_message or chunk.text
aggregated_response += message aggregated_response += message
yield message yield message
if stopped: if stopped:
raise ValueError(message) logger.warning(
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
break
# Log the time taken to stream the entire response # Calculate cost of chat
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
)
# Calculate cost of chat # Validate the response. If empty, raise an error to retry.
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0 if is_none_or_empty(aggregated_response):
output_tokens = final_chunk.usage_metadata.candidates_token_count or 0 if final_chunk else 0 logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
thought_tokens = final_chunk.usage_metadata.thoughts_token_count or 0 if final_chunk else 0 raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
)
# Save conversation trace # Log the time taken to stream the entire response
tracer["chat_model"] = model_name logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
tracer["temperature"] = temperature
if is_promptrace_enabled(): # Save conversation trace
commit_conversation_trace(messages, aggregated_response, tracer) tracer["chat_model"] = model_name
except ValueError as e: tracer["temperature"] = temperature
logger.warning( if is_promptrace_enabled():
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n" commit_conversation_trace(messages, aggregated_response, tracer)
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
except Exception as e:
logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
def handle_gemini_response( def handle_gemini_response(

View File

@@ -38,6 +38,7 @@ from khoj.utils.helpers import (
get_chat_usage_metrics, get_chat_usage_metrics,
get_openai_async_client, get_openai_async_client,
get_openai_client, get_openai_client,
is_none_or_empty,
is_promptrace_enabled, is_promptrace_enabled,
) )
@@ -54,6 +55,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
| retry_if_exception_type(openai._exceptions.APIConnectionError) | retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai._exceptions.RateLimitError) | retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai._exceptions.APIStatusError) | retry_if_exception_type(openai._exceptions.APIStatusError)
| retry_if_exception_type(ValueError)
), ),
wait=wait_random_exponential(min=1, max=10), wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
@@ -136,6 +138,11 @@ def completion_with_backoff(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
) )
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
@@ -152,14 +159,15 @@ def completion_with_backoff(
| retry_if_exception_type(openai._exceptions.APIConnectionError) | retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai._exceptions.RateLimitError) | retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai._exceptions.APIStatusError) | retry_if_exception_type(openai._exceptions.APIStatusError)
| retry_if_exception_type(ValueError)
), ),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True, reraise=False,
) )
async def chat_completion_with_backoff( async def chat_completion_with_backoff(
messages, messages: list[ChatMessage],
model_name: str, model_name: str,
temperature, temperature,
openai_api_key=None, openai_api_key=None,
@@ -168,120 +176,122 @@ async def chat_completion_with_backoff(
model_kwargs: dict = {}, model_kwargs: dict = {},
tracer: dict = {}, tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]: ) -> AsyncGenerator[ResponseWithThought, None]:
try: client_key = f"{openai_api_key}--{api_base_url}"
client_key = f"{openai_api_key}--{api_base_url}" client = openai_async_clients.get(client_key)
client = openai_async_clients.get(client_key) if not client:
if not client: client = get_openai_async_client(openai_api_key, api_base_url)
client = get_openai_async_client(openai_api_key, api_base_url) openai_async_clients[client_key] = client
openai_async_clients[client_key] = client
stream_processor = adefault_stream_processor stream_processor = adefault_stream_processor
formatted_messages = format_message_for_api(messages, api_base_url) formatted_messages = format_message_for_api(messages, api_base_url)
# Configure thinking for openai reasoning models # Configure thinking for openai reasoning models
if is_openai_reasoning_model(model_name, api_base_url): if is_openai_reasoning_model(model_name, api_base_url):
temperature = 1 temperature = 1
reasoning_effort = "medium" if deepthought else "low" reasoning_effort = "medium" if deepthought else "low"
model_kwargs["reasoning_effort"] = reasoning_effort model_kwargs["reasoning_effort"] = reasoning_effort
model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models model_kwargs.pop("stop", None) # Remove unsupported stop param for reasoning models
# Get the first system message and add the string `Formatting re-enabled` to it. # Get the first system message and add the string `Formatting re-enabled` to it.
# See https://platform.openai.com/docs/guides/reasoning-best-practices # See https://platform.openai.com/docs/guides/reasoning-best-practices
if len(formatted_messages) > 0: if len(formatted_messages) > 0:
system_messages = [ system_messages = [
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system" (i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
] ]
if len(system_messages) > 0: if len(system_messages) > 0:
first_system_message_index, first_system_message = system_messages[0] first_system_message_index, first_system_message = system_messages[0]
first_system_message_content = first_system_message["content"] first_system_message_content = first_system_message["content"]
formatted_messages[first_system_message_index][ formatted_messages[first_system_message_index][
"content" "content"
] = f"{first_system_message_content}\nFormatting re-enabled" ] = f"{first_system_message_content}\nFormatting re-enabled"
elif is_twitter_reasoning_model(model_name, api_base_url): elif is_twitter_reasoning_model(model_name, api_base_url):
stream_processor = adeepseek_stream_processor stream_processor = adeepseek_stream_processor
reasoning_effort = "high" if deepthought else "low" reasoning_effort = "high" if deepthought else "low"
model_kwargs["reasoning_effort"] = reasoning_effort model_kwargs["reasoning_effort"] = reasoning_effort
elif model_name.startswith("deepseek-reasoner"): elif model_name.startswith("deepseek-reasoner"):
stream_processor = adeepseek_stream_processor stream_processor = adeepseek_stream_processor
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role. # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
# The first message should always be a user message (except system message). # The first message should always be a user message (except system message).
updated_messages: List[dict] = [] updated_messages: List[dict] = []
for i, message in enumerate(formatted_messages): for i, message in enumerate(formatted_messages):
if i > 0 and message["role"] == formatted_messages[i - 1]["role"]: if i > 0 and message["role"] == formatted_messages[i - 1]["role"]:
updated_messages[-1]["content"] += " " + message["content"] updated_messages[-1]["content"] += " " + message["content"]
elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant": elif i == 1 and formatted_messages[i - 1]["role"] == "system" and message["role"] == "assistant":
updated_messages[-1]["content"] += " " + message["content"] updated_messages[-1]["content"] += " " + message["content"]
else: else:
updated_messages.append(message) updated_messages.append(message)
formatted_messages = updated_messages formatted_messages = updated_messages
elif is_qwen_reasoning_model(model_name, api_base_url): elif is_qwen_reasoning_model(model_name, api_base_url):
stream_processor = partial(ain_stream_thought_processor, thought_tag="think") stream_processor = partial(ain_stream_thought_processor, thought_tag="think")
# Reasoning is enabled by default. Disable when deepthought is False. # Reasoning is enabled by default. Disable when deepthought is False.
# See https://qwenlm.github.io/blog/qwen3/#advanced-usages # See https://qwenlm.github.io/blog/qwen3/#advanced-usages
if not deepthought and len(formatted_messages) > 0: if not deepthought and len(formatted_messages) > 0:
formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think" formatted_messages[-1]["content"] = formatted_messages[-1]["content"] + " /no_think"
stream = True stream = True
read_timeout = 300 if is_local_api(api_base_url) else 60 read_timeout = 300 if is_local_api(api_base_url) else 60
model_kwargs["stream_options"] = {"include_usage": True} model_kwargs["stream_options"] = {"include_usage": True}
if os.getenv("KHOJ_LLM_SEED"): if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
aggregated_response = "" aggregated_response = ""
final_chunk = None final_chunk = None
response_started = False response_started = False
start_time = perf_counter() start_time = perf_counter()
chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( chat_stream: openai.AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
messages=formatted_messages, # type: ignore messages=formatted_messages, # type: ignore
model=model_name, model=model_name,
stream=stream, stream=stream,
temperature=temperature, temperature=temperature,
timeout=httpx.Timeout(30, read=read_timeout), timeout=httpx.Timeout(30, read=read_timeout),
**model_kwargs, **model_kwargs,
) )
async for chunk in stream_processor(chat_stream): async for chunk in stream_processor(chat_stream):
# Log the time taken to start response # Log the time taken to start response
if not response_started: if not response_started:
response_started = True response_started = True
logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds") logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
# Keep track of the last chunk for usage data # Keep track of the last chunk for usage data
final_chunk = chunk final_chunk = chunk
# Skip empty chunks # Skip empty chunks
if len(chunk.choices) == 0: if len(chunk.choices) == 0:
continue continue
# Handle streamed response chunk # Handle streamed response chunk
response_chunk: ResponseWithThought = None response_chunk: ResponseWithThought = None
response_delta = chunk.choices[0].delta response_delta = chunk.choices[0].delta
if response_delta.content: if response_delta.content:
response_chunk = ResponseWithThought(response=response_delta.content) response_chunk = ResponseWithThought(response=response_delta.content)
aggregated_response += response_chunk.response aggregated_response += response_chunk.response
elif response_delta.thought: elif response_delta.thought:
response_chunk = ResponseWithThought(thought=response_delta.thought) response_chunk = ResponseWithThought(thought=response_delta.thought)
if response_chunk: if response_chunk:
yield response_chunk yield response_chunk
# Log the time taken to stream the entire response # Calculate cost of chat after stream finishes
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") input_tokens, output_tokens, cost = 0, 0, 0
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
input_tokens = final_chunk.usage.prompt_tokens
output_tokens = final_chunk.usage.completion_tokens
# Estimated costs returned by DeepInfra API
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
# Calculate cost of chat after stream finishes # Validate the response. If empty, raise an error to retry.
input_tokens, output_tokens, cost = 0, 0, 0 if is_none_or_empty(aggregated_response):
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage: logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
input_tokens = final_chunk.usage.prompt_tokens raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
output_tokens = final_chunk.usage.completion_tokens
# Estimated costs returned by DeepInfra API
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
# Save conversation trace # Log the time taken to stream the entire response
tracer["chat_model"] = model_name logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
tracer["temperature"] = temperature
tracer["usage"] = get_chat_usage_metrics( # Save conversation trace
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost tracer["chat_model"] = model_name
) tracer["temperature"] = temperature
if is_promptrace_enabled(): if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: