mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 05:39:11 +00:00
Retry on empty response or error in chat completion by LLM over API
Previously, all exceptions were being caught, so the retry logic was never triggered. Exception catching had been added to close the LLM thread back when threads, rather than async, were used for final response generation. This is no longer required since moving to async, so retry on failure can be re-enabled. An error is now raised when the response is empty, causing the LLM completion to be retried.
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from typing import Dict, List
|
from typing import AsyncGenerator, Dict, List
|
||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
from langchain_core.messages.chat import ChatMessage
|
from langchain_core.messages.chat import ChatMessage
|
||||||
@@ -100,6 +100,11 @@ def anthropic_completion_with_backoff(
|
|||||||
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the response. If empty, raise an error to retry.
|
||||||
|
if is_none_or_empty(aggregated_response):
|
||||||
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
@@ -112,8 +117,8 @@ def anthropic_completion_with_backoff(
|
|||||||
@retry(
|
@retry(
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
stop=stop_after_attempt(2),
|
stop=stop_after_attempt(2),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
reraise=True,
|
reraise=False,
|
||||||
)
|
)
|
||||||
async def anthropic_chat_completion_with_backoff(
|
async def anthropic_chat_completion_with_backoff(
|
||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
@@ -126,8 +131,7 @@ async def anthropic_chat_completion_with_backoff(
|
|||||||
deepthought=False,
|
deepthought=False,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
tracer={},
|
tracer={},
|
||||||
):
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
try:
|
|
||||||
client = anthropic_async_clients.get(api_key)
|
client = anthropic_async_clients.get(api_key)
|
||||||
if not client:
|
if not client:
|
||||||
client = get_anthropic_async_client(api_key, api_base_url)
|
client = get_anthropic_async_client(api_key, api_base_url)
|
||||||
@@ -176,9 +180,6 @@ async def anthropic_chat_completion_with_backoff(
|
|||||||
yield response_chunk
|
yield response_chunk
|
||||||
final_message = await stream.get_final_message()
|
final_message = await stream.get_final_message()
|
||||||
|
|
||||||
# Log the time taken to stream the entire response
|
|
||||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = final_message.usage.input_tokens
|
input_tokens = final_message.usage.input_tokens
|
||||||
output_tokens = final_message.usage.output_tokens
|
output_tokens = final_message.usage.output_tokens
|
||||||
@@ -188,13 +189,19 @@ async def anthropic_chat_completion_with_backoff(
|
|||||||
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
model_name, input_tokens, output_tokens, cache_read_tokens, cache_write_tokens, usage=tracer.get("usage")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the response. If empty, raise an error to retry.
|
||||||
|
if is_none_or_empty(aggregated_response):
|
||||||
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
|
|
||||||
|
# Log the time taken to stream the entire response
|
||||||
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
|
||||||
|
|||||||
@@ -73,6 +73,9 @@ def _is_retryable_error(exception: BaseException) -> bool:
|
|||||||
# client errors
|
# client errors
|
||||||
if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError):
|
if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError):
|
||||||
return True
|
return True
|
||||||
|
# validation errors
|
||||||
|
if isinstance(exception, ValueError):
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@@ -84,8 +87,8 @@ def _is_retryable_error(exception: BaseException) -> bool:
|
|||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def gemini_completion_with_backoff(
|
def gemini_completion_with_backoff(
|
||||||
messages,
|
messages: list[ChatMessage],
|
||||||
system_prompt,
|
system_prompt: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
api_key=None,
|
api_key=None,
|
||||||
@@ -144,6 +147,11 @@ def gemini_completion_with_backoff(
|
|||||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the response. If empty, raise an error to retry.
|
||||||
|
if is_none_or_empty(response_text):
|
||||||
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
@@ -157,21 +165,20 @@ def gemini_completion_with_backoff(
|
|||||||
retry=retry_if_exception(_is_retryable_error),
|
retry=retry_if_exception(_is_retryable_error),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
reraise=True,
|
reraise=False,
|
||||||
)
|
)
|
||||||
async def gemini_chat_completion_with_backoff(
|
async def gemini_chat_completion_with_backoff(
|
||||||
messages,
|
messages: list[ChatMessage],
|
||||||
model_name,
|
model_name: str,
|
||||||
temperature,
|
temperature: float,
|
||||||
api_key,
|
api_key: str,
|
||||||
api_base_url,
|
api_base_url: str,
|
||||||
system_prompt,
|
system_prompt: str,
|
||||||
model_kwargs=None,
|
model_kwargs=None,
|
||||||
deepthought=False,
|
deepthought=False,
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
try:
|
|
||||||
client = gemini_clients.get(api_key)
|
client = gemini_clients.get(api_key)
|
||||||
if not client:
|
if not client:
|
||||||
client = get_gemini_client(api_key, api_base_url)
|
client = get_gemini_client(api_key, api_base_url)
|
||||||
@@ -210,15 +217,16 @@ async def gemini_chat_completion_with_backoff(
|
|||||||
# Keep track of the last chunk for usage data
|
# Keep track of the last chunk for usage data
|
||||||
final_chunk = chunk
|
final_chunk = chunk
|
||||||
# Handle streamed response chunk
|
# Handle streamed response chunk
|
||||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
stop_message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||||
message = message or chunk.text
|
message = stop_message or chunk.text
|
||||||
aggregated_response += message
|
aggregated_response += message
|
||||||
yield message
|
yield message
|
||||||
if stopped:
|
if stopped:
|
||||||
raise ValueError(message)
|
logger.warning(
|
||||||
|
f"LLM Response Prevented for {model_name}: {stop_message}.\n"
|
||||||
# Log the time taken to stream the entire response
|
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
||||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
)
|
||||||
|
break
|
||||||
|
|
||||||
# Calculate cost of chat
|
# Calculate cost of chat
|
||||||
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
input_tokens = final_chunk.usage_metadata.prompt_token_count or 0 if final_chunk else 0
|
||||||
@@ -228,18 +236,19 @@ async def gemini_chat_completion_with_backoff(
|
|||||||
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
model_name, input_tokens, output_tokens, thought_tokens=thought_tokens, usage=tracer.get("usage")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the response. If empty, raise an error to retry.
|
||||||
|
if is_none_or_empty(aggregated_response):
|
||||||
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
|
|
||||||
|
# Log the time taken to stream the entire response
|
||||||
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except ValueError as e:
|
|
||||||
logger.warning(
|
|
||||||
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
|
|
||||||
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in gemini_chat_completion_with_backoff stream: {e}", exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_gemini_response(
|
def handle_gemini_response(
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from khoj.utils.helpers import (
|
|||||||
get_chat_usage_metrics,
|
get_chat_usage_metrics,
|
||||||
get_openai_async_client,
|
get_openai_async_client,
|
||||||
get_openai_client,
|
get_openai_client,
|
||||||
|
is_none_or_empty,
|
||||||
is_promptrace_enabled,
|
is_promptrace_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,6 +55,7 @@ openai_async_clients: Dict[str, openai.AsyncOpenAI] = {}
|
|||||||
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||||
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||||
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||||
|
| retry_if_exception_type(ValueError)
|
||||||
),
|
),
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=wait_random_exponential(min=1, max=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
@@ -136,6 +138,11 @@ def completion_with_backoff(
|
|||||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the response. If empty, raise an error to retry.
|
||||||
|
if is_none_or_empty(aggregated_response):
|
||||||
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
@@ -152,14 +159,15 @@ def completion_with_backoff(
|
|||||||
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||||
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||||
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||||
|
| retry_if_exception_type(ValueError)
|
||||||
),
|
),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
reraise=True,
|
reraise=False,
|
||||||
)
|
)
|
||||||
async def chat_completion_with_backoff(
|
async def chat_completion_with_backoff(
|
||||||
messages,
|
messages: list[ChatMessage],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
temperature,
|
temperature,
|
||||||
openai_api_key=None,
|
openai_api_key=None,
|
||||||
@@ -168,7 +176,6 @@ async def chat_completion_with_backoff(
|
|||||||
model_kwargs: dict = {},
|
model_kwargs: dict = {},
|
||||||
tracer: dict = {},
|
tracer: dict = {},
|
||||||
) -> AsyncGenerator[ResponseWithThought, None]:
|
) -> AsyncGenerator[ResponseWithThought, None]:
|
||||||
try:
|
|
||||||
client_key = f"{openai_api_key}--{api_base_url}"
|
client_key = f"{openai_api_key}--{api_base_url}"
|
||||||
client = openai_async_clients.get(client_key)
|
client = openai_async_clients.get(client_key)
|
||||||
if not client:
|
if not client:
|
||||||
@@ -260,9 +267,6 @@ async def chat_completion_with_backoff(
|
|||||||
if response_chunk:
|
if response_chunk:
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
||||||
# Log the time taken to stream the entire response
|
|
||||||
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
|
||||||
|
|
||||||
# Calculate cost of chat after stream finishes
|
# Calculate cost of chat after stream finishes
|
||||||
input_tokens, output_tokens, cost = 0, 0, 0
|
input_tokens, output_tokens, cost = 0, 0, 0
|
||||||
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
|
if final_chunk and hasattr(final_chunk, "usage") and final_chunk.usage:
|
||||||
@@ -271,17 +275,23 @@ async def chat_completion_with_backoff(
|
|||||||
# Estimated costs returned by DeepInfra API
|
# Estimated costs returned by DeepInfra API
|
||||||
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
|
if final_chunk.usage.model_extra and "estimated_cost" in final_chunk.usage.model_extra:
|
||||||
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
|
cost = final_chunk.usage.model_extra.get("estimated_cost", 0)
|
||||||
|
tracer["usage"] = get_chat_usage_metrics(
|
||||||
|
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate the response. If empty, raise an error to retry.
|
||||||
|
if is_none_or_empty(aggregated_response):
|
||||||
|
logger.warning(f"No response by {model_name}\nLast Message by {messages[-1].role}: {messages[-1].content}.")
|
||||||
|
raise ValueError(f"Empty or no response by {model_name} over API. Retry if needed.")
|
||||||
|
|
||||||
|
# Log the time taken to stream the entire response
|
||||||
|
logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
|
||||||
|
|
||||||
# Save conversation trace
|
# Save conversation trace
|
||||||
tracer["chat_model"] = model_name
|
tracer["chat_model"] = model_name
|
||||||
tracer["temperature"] = temperature
|
tracer["temperature"] = temperature
|
||||||
tracer["usage"] = get_chat_usage_metrics(
|
|
||||||
model_name, input_tokens, output_tokens, usage=tracer.get("usage"), cost=cost
|
|
||||||
)
|
|
||||||
if is_promptrace_enabled():
|
if is_promptrace_enabled():
|
||||||
commit_conversation_trace(messages, aggregated_response, tracer)
|
commit_conversation_trace(messages, aggregated_response, tracer)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in chat_completion_with_backoff stream: {e}", exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
|
||||||
|
|||||||
Reference in New Issue
Block a user