Retry on empty response or error in chat completion by LLM over API

Previously, all exceptions were being caught, so the retry logic wasn't
getting triggered.

Exception catching had been added to close the LLM thread back when
threads, instead of async, were used for final response generation.

This isn't required anymore since moving to async, so we can now
re-enable retries on failures.

Raise an error if the response is empty so the LLM completion is retried.
This commit is contained in:
Debanjum
2025-05-18 14:32:31 -07:00
parent 7827d317b4
commit cf55582852
3 changed files with 273 additions and 247 deletions

View File

@@ -1,6 +1,6 @@
import logging import logging
from time import perf_counter from time import perf_counter
from typing import Dict, List from typing import AsyncGenerator, Dict, List
import anthropic import anthropic
from langchain_core.messages.chat import ChatMessage from langchain_core.messages.chat import ChatMessage
@@ -100,6 +100,11 @@ def anthropic_completion_with_backoff(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage") model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
) )
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
@@ -112,8 +117,8 @@ def anthropic_completion_with_backoff(
@retry( @retry(
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(2), stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True, reraise=False,
) )
async def anthropic_chat_completion_with_backoff( async def anthropic_chat_completion_with_backoff(
messages: list[ChatMessage], messages: list[ChatMessage],
@@ -126,8 +131,7 @@ async def anthropic_chat_completion_with_backoff(
deepthought=False, deepthought=False,
model_kwargs=None, model_kwargs=None,
tracer={}, tracer={},
): ) -> AsyncGenerator[ResponseWithThought, None]:
try:
client = anthropic_async_clients.get(api_key) client = anthropic_async_clients.get(api_key)
if not client: if not client:
client = get_anthropic_async_client(api_key, api_base_url) client = get_anthropic_async_client(api_key, api_base_url)
@@ -176,9 +180,6 @@ async def anthropic_chat_completion_with_backoff(
yield response_chunk yield response_chunk
final_message = await stream.get_final_message() final_message = await stream.get_final_message()
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Calculate cost of chat # Calculate cost of chat
input_tokens = final_message.usage.input_tokens input_tokens = final_message.usage.input_tokens
output_tokens = final_message.usage.output_tokens output_tokens = final_message.usage.output_tokens
@@ -188,13 +189,19 @@ async def anthropic_chat_completion_with_backoff(
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage") model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
) )
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
if is_promptrace_enabled(): if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None): def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):

View File

@@ -73,6 +73,9 @@ def _is_retryable_error(exception: BaseException) -> bool:
# client errors # client errors
if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError): if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError):
return True return True
# validation errors
if isinstance(exception, ValueError):
return True
return False return False
@@ -84,8 +87,8 @@ def _is_retryable_error(exception: BaseException) -> bool:
reraise=True, reraise=True,
) )
def gemini_completion_with_backoff( def gemini_completion_with_backoff(
messages, messages: list[ChatMessage],
system_prompt, system_prompt: str,
model_name: str, model_name: str,
temperature=1.0, temperature=1.0,
api_key=None, api_key=None,
@@ -144,6 +147,11 @@ def gemini_completion_with_backoff(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
) )
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(response_text):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
@@ -157,21 +165,20 @@ def gemini_completion_with_backoff(
retry=retry_if_exception(_is_retryable_error), retry=retry_if_exception(_is_retryable_error),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True, reraise=False,
) )
async def gemini_chat_completion_with_backoff( async def gemini_chat_completion_with_backoff(
messages, messages: list[ChatMessage],
model_name, model_name: str,
temperature, temperature: float,
api_key, api_key: str,
api_base_url, api_base_url: str,
system_prompt, system_prompt: str,
model_kwargs=None, model_kwargs=None,
deepthought=False, deepthought=False,
tracer: dict = {}, tracer: dict = {},
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
try:
client = gemini_clients.get(api_key) client = gemini_clients.get(api_key)
if not client: if not client:
client = get_gemini_client(api_key, api_base_url) client = get_gemini_client(api_key, api_base_url)
@@ -210,15 +217,16 @@ async def gemini_chat_completion_with_backoff(
# Keep track of the last chunk for usage data # Keep track of the last chunk for usage data
final_chunk = chunk final_chunk = chunk
# Handle streamed response chunk # Handle streamed response chunk
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback) stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text message = stop_message or chunk.text
aggregated_response += message aggregated_response += message
yield message yield message
if stopped: if stopped:
raise ValueError(message) logger.warning(
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
# Log the time taken to stream the entire response + f"Last Message by {messages[-1].role}: {messages[-1].content}"
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds") )
break
# Calculate cost of chat # Calculate cost of chat
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0 input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
@@ -228,18 +236,19 @@ async def gemini_chat_completion_with_backoff(
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage") model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
) )
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
if is_promptrace_enabled(): if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
except ValueError as e:
logger.warning(
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
except Exception as e:
logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
def handle_gemini_response( def handle_gemini_response(

View File

@@ -38,6 +38,7 @@ from khoj.utils.helpers import (
get_chat_usage_metrics, get_chat_usage_metrics,
get_openai_async_client, get_openai_async_client,
get_openai_client, get_openai_client,
is_none_or_empty,
is_promptrace_enabled, is_promptrace_enabled,
) )
@@ -54,6 +55,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
| retry_if_exception_type(openai._exceptions.APIConnectionError) | retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai._exceptions.RateLimitError) | retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai._exceptions.APIStatusError) | retry_if_exception_type(openai._exceptions.APIStatusError)
| retry_if_exception_type(ValueError)
), ),
wait=wait_random_exponential(min=1, max=10), wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
@@ -136,6 +138,11 @@ def completion_with_backoff(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
) )
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
@@ -152,14 +159,15 @@ def completion_with_backoff(
| retry_if_exception_type(openai._exceptions.APIConnectionError) | retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai._exceptions.RateLimitError) | retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai._exceptions.APIStatusError) | retry_if_exception_type(openai._exceptions.APIStatusError)
| retry_if_exception_type(ValueError)
), ),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True, reraise=False,
) )
async def chat_completion_with_backoff( async def chat_completion_with_backoff(
messages, messages: list[ChatMessage],
model_name: str, model_name: str,
temperature, temperature,
openai_api_key=None, openai_api_key=None,
@@ -168,7 +176,6 @@ async def chat_completion_with_backoff(
model_kwargs: dict = {}, model_kwargs: dict = {},
tracer: dict = {}, tracer: dict = {},
) -> AsyncGenerator[ResponseWithThought, None]: ) -> AsyncGenerator[ResponseWithThought, None]:
try:
client_key = f"{openai_api_key}--{api_base_url}" client_key = f"{openai_api_key}--{api_base_url}"
client = openai_async_clients.get(client_key) client = openai_async_clients.get(client_key)
if not client: if not client:
@@ -260,9 +267,6 @@ async def chat_completion_with_backoff(
if response_chunk: if response_chunk:
yield response_chunk yield response_chunk
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Calculate cost of chat after stream finishes # Calculate cost of chat after stream finishes
input_tokens, output_tokens, cost = 0, 0, 0 input_tokens, output_tokens, cost = 0, 0, 0
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage: if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
@@ -271,17 +275,23 @@ async def chat_completion_with_backoff(
# Estimated costs returned by DeepInfra API # Estimated costs returned by DeepInfra API
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra: if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
cost = final_chunk.usage.model_extra.get("estimated_cost", 0) cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
# Validate the response. If empty, raise an error to retry.
if is_none_or_empty(aggregated_response):
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
# Log the time taken to stream the entire response
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
tracer["usage"] = get_chat_usage_metrics(
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
)
if is_promptrace_enabled(): if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport: