mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Use Gemini suggested retry backoff if set. Improve gemini error handling
This commit is contained in:
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from time import perf_counter
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List
|
||||
@@ -13,6 +14,7 @@ from google.genai import types as gtypes
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
RetryCallState,
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception,
|
||||
@@ -73,7 +75,7 @@ SAFETY_SETTINGS = [
|
||||
def _is_retryable_error(exception: BaseException) -> bool:
|
||||
"""Check if the exception is a retryable error"""
|
||||
# server errors
|
||||
if isinstance(exception, gerrors.APIError):
|
||||
if isinstance(exception, (gerrors.APIError, gerrors.ClientError)):
|
||||
return exception.code in [429, 502, 503, 504]
|
||||
# client errors
|
||||
if (
|
||||
@@ -88,9 +90,48 @@ def _is_retryable_error(exception: BaseException) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _extract_retry_delay(exception: BaseException) -> float:
|
||||
"""Extract retry delay from Gemini error response, return in seconds"""
|
||||
if (
|
||||
isinstance(exception, (gerrors.ClientError, gerrors.APIError))
|
||||
and hasattr(exception, "details")
|
||||
and isinstance(exception.details, dict)
|
||||
):
|
||||
# Look for retryDelay key, value pair. E.g "retryDelay": "54s"
|
||||
if delay_str := exception.details.get("retryDelay"):
|
||||
delay_seconds_match = re.search(r"(\d+)s", delay_str)
|
||||
if delay_seconds_match:
|
||||
delay_seconds = float(delay_seconds_match.group(1))
|
||||
return delay_seconds
|
||||
return None
|
||||
|
||||
|
||||
def _wait_with_gemini_delay(min_wait=4, max_wait=120, multiplier=1, fallback_wait=None):
|
||||
"""Custom wait strategy that respects Gemini's retryDelay if present"""
|
||||
|
||||
def wait_func(retry_state: RetryCallState) -> float:
|
||||
# Use backoff time if last exception suggests a retry delay
|
||||
if retry_state.outcome and retry_state.outcome.failed:
|
||||
exception = retry_state.outcome.exception()
|
||||
gemini_delay = _extract_retry_delay(exception)
|
||||
if gemini_delay:
|
||||
# Use the Gemini-suggested delay, but cap it at max_wait
|
||||
suggested_delay = min(gemini_delay, max_wait)
|
||||
logger.info(f"Using Gemini suggested retry delay: {suggested_delay} seconds")
|
||||
return suggested_delay
|
||||
# Else use fallback backoff if provided
|
||||
if fallback_wait:
|
||||
return fallback_wait(retry_state)
|
||||
# Else use exponential backoff with provided parameters
|
||||
else:
|
||||
return wait_exponential(multiplier=multiplier, min=min_wait, max=max_wait)(retry_state)
|
||||
|
||||
return wait_func
|
||||
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception(_is_retryable_error),
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
wait=_wait_with_gemini_delay(min_wait=1, max_wait=10, fallback_wait=wait_random_exponential(min=1, max=10)),
|
||||
stop=stop_after_attempt(2),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
@@ -169,7 +210,14 @@ def gemini_completion_with_backoff(
|
||||
)
|
||||
except gerrors.ClientError as e:
|
||||
response = None
|
||||
response_text, _ = handle_gemini_response(e.args)
|
||||
# Handle 429 rate limit errors directly
|
||||
if e.code == 429:
|
||||
response_text = f"My brain is exhausted. Can you please try again in a bit?"
|
||||
# Log the full error details for debugging
|
||||
logger.error(f"Gemini ClientError: {e.code} {e.status}. Details: {e.details}")
|
||||
# Handle other errors
|
||||
else:
|
||||
response_text, _ = handle_gemini_response(e.args)
|
||||
# Respond with reason for stopping
|
||||
logger.warning(
|
||||
f"LLM Response Prevented for {model_name}: {response_text}.\n"
|
||||
@@ -206,7 +254,7 @@ def gemini_completion_with_backoff(
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception(_is_retryable_error),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
wait=_wait_with_gemini_delay(multiplier=1, min_wait=4, max_wait=10),
|
||||
stop=stop_after_attempt(3),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
reraise=False,
|
||||
@@ -310,6 +358,13 @@ def handle_gemini_response(
|
||||
candidates: list[gtypes.Candidate], prompt_feedback: gtypes.GenerateContentResponsePromptFeedback = None
|
||||
):
|
||||
"""Check if Gemini response was blocked and return an explanatory error message."""
|
||||
|
||||
# Ensure we have a proper list of candidates
|
||||
if not isinstance(candidates, list):
|
||||
message = f"\nUnexpected response format. Try again."
|
||||
stopped = True
|
||||
return message, stopped
|
||||
|
||||
# Check if the response was blocked due to safety concerns with the prompt
|
||||
if len(candidates) == 0 and prompt_feedback:
|
||||
message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
|
||||
@@ -428,7 +483,18 @@ def format_messages_for_gemini(
|
||||
if len(messages) == 1:
|
||||
messages[0].role = "user"
|
||||
|
||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
||||
# Ensure messages are properly formatted for Content creation
|
||||
valid_messages = []
|
||||
for message in messages:
|
||||
try:
|
||||
# Try create Content object to validate the structure before adding to valid messages
|
||||
gtypes.Content(role=message.role, parts=message.content)
|
||||
valid_messages.append(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Dropping message with invalid content structure: {e}. Message: {message}")
|
||||
continue
|
||||
|
||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in valid_messages]
|
||||
return formatted_messages, system_prompt
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user