mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Use Gemini-suggested retry backoff if set. Improve Gemini error handling.
This commit is contained in:
@@ -2,6 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List
|
||||||
@@ -13,6 +14,7 @@ from google.genai import types as gtypes
|
|||||||
from langchain_core.messages.chat import ChatMessage
|
from langchain_core.messages.chat import ChatMessage
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
|
RetryCallState,
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
retry,
|
retry,
|
||||||
retry_if_exception,
|
retry_if_exception,
|
||||||
@@ -73,7 +75,7 @@ SAFETY_SETTINGS = [
|
|||||||
def _is_retryable_error(exception: BaseException) -> bool:
|
def _is_retryable_error(exception: BaseException) -> bool:
|
||||||
"""Check if the exception is a retryable error"""
|
"""Check if the exception is a retryable error"""
|
||||||
# server errors
|
# server errors
|
||||||
if isinstance(exception, gerrors.APIError):
|
if isinstance(exception, (gerrors.APIError, gerrors.ClientError)):
|
||||||
return exception.code in [429, 502, 503, 504]
|
return exception.code in [429, 502, 503, 504]
|
||||||
# client errors
|
# client errors
|
||||||
if (
|
if (
|
||||||
@@ -88,9 +90,48 @@ def _is_retryable_error(exception: BaseException) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _find_retry_delay(payload: Any) -> Any:
    """Return the first truthy "retryDelay" value found in a nested dict/list payload, else None."""
    if isinstance(payload, dict):
        # A key on the current level wins over anything nested deeper
        if delay := payload.get("retryDelay"):
            return delay
        children = list(payload.values())
    elif isinstance(payload, (list, tuple)):
        children = list(payload)
    else:
        return None

    for child in children:
        found = _find_retry_delay(child)
        if found is not None:
            return found
    return None


def _extract_retry_delay(exception: BaseException) -> "float | None":
    """Extract the retry delay suggested in a Gemini error response, in seconds.

    Args:
        exception: The exception raised by the Gemini call.

    Returns:
        The suggested delay in seconds, or None when the exception is not a
        Gemini API/client error, carries no dict `details`, or holds no
        parsable `retryDelay` value.
    """
    if not isinstance(exception, (gerrors.ClientError, gerrors.APIError)):
        return None

    details = getattr(exception, "details", None)
    if not isinstance(details, dict):
        return None

    # Look for a retryDelay key, value pair. E.g "retryDelay": "54s".
    # NOTE(review): Google API errors usually nest this under
    # error -> details[] -> RetryInfo rather than at the top level, so the
    # whole payload is searched. Confirm against a live 429 response.
    delay_str = _find_retry_delay(details)
    if not isinstance(delay_str, str):
        # Avoid a TypeError from the regex engine on unexpected value types
        return None

    # Durations may be fractional (e.g "1.5s"). Matching only trailing digits
    # would read that as 5 seconds, so capture the optional fraction as well.
    delay_seconds_match = re.search(r"(\d+(?:\.\d+)?)s", delay_str)
    if delay_seconds_match is None:
        return None
    return float(delay_seconds_match.group(1))
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_with_gemini_delay(min_wait=4, max_wait=120, multiplier=1, fallback_wait=None):
    """Build a tenacity wait strategy that prefers the retry delay suggested by Gemini.

    Args:
        min_wait: Lower bound passed to the exponential backoff used when no
            fallback strategy is supplied.
        max_wait: Cap applied to the Gemini suggested delay. Also the upper
            bound passed to the exponential backoff.
        multiplier: Multiplier passed to the exponential backoff.
        fallback_wait: Optional wait strategy to call when the last failure
            carries no usable Gemini delay. Replaces the exponential backoff.

    Returns:
        A callable taking the tenacity retry state and returning the wait.
    """

    def wait_func(retry_state: RetryCallState) -> float:
        # Only a failed attempt can carry a retry delay from Gemini
        outcome = retry_state.outcome
        gemini_delay = _extract_retry_delay(outcome.exception()) if outcome and outcome.failed else None

        if gemini_delay:
            # Honour the suggested delay, but never wait longer than max_wait
            suggested_delay = min(gemini_delay, max_wait)
            logger.info(f"Using Gemini suggested retry delay: {suggested_delay} seconds")
            return suggested_delay

        # No usable suggestion. Fall back to exponential backoff unless the caller gave a strategy
        if not fallback_wait:
            return wait_exponential(multiplier=multiplier, min=min_wait, max=max_wait)(retry_state)
        return fallback_wait(retry_state)

    return wait_func
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=retry_if_exception(_is_retryable_error),
|
retry=retry_if_exception(_is_retryable_error),
|
||||||
wait=wait_random_exponential(min=1, max=10),
|
wait=_wait_with_gemini_delay(min_wait=1, max_wait=10, fallback_wait=wait_random_exponential(min=1, max=10)),
|
||||||
stop=stop_after_attempt(2),
|
stop=stop_after_attempt(2),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
@@ -169,7 +210,14 @@ def gemini_completion_with_backoff(
|
|||||||
)
|
)
|
||||||
except gerrors.ClientError as e:
|
except gerrors.ClientError as e:
|
||||||
response = None
|
response = None
|
||||||
response_text, _ = handle_gemini_response(e.args)
|
# Handle 429 rate limit errors directly
|
||||||
|
if e.code == 429:
|
||||||
|
response_text = f"My brain is exhausted. Can you please try again in a bit?"
|
||||||
|
# Log the full error details for debugging
|
||||||
|
logger.error(f"Gemini ClientError: {e.code} {e.status}. Details: {e.details}")
|
||||||
|
# Handle other errors
|
||||||
|
else:
|
||||||
|
response_text, _ = handle_gemini_response(e.args)
|
||||||
# Respond with reason for stopping
|
# Respond with reason for stopping
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"LLM Response Prevented for {model_name}: {response_text}.\n"
|
f"LLM Response Prevented for {model_name}: {response_text}.\n"
|
||||||
@@ -206,7 +254,7 @@ def gemini_completion_with_backoff(
|
|||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=retry_if_exception(_is_retryable_error),
|
retry=retry_if_exception(_is_retryable_error),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=_wait_with_gemini_delay(multiplier=1, min_wait=4, max_wait=10),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
reraise=False,
|
reraise=False,
|
||||||
@@ -310,6 +358,13 @@ def handle_gemini_response(
|
|||||||
candidates: list[gtypes.Candidate], prompt_feedback: gtypes.GenerateContentResponsePromptFeedback = None
|
candidates: list[gtypes.Candidate], prompt_feedback: gtypes.GenerateContentResponsePromptFeedback = None
|
||||||
):
|
):
|
||||||
"""Check if Gemini response was blocked and return an explanatory error message."""
|
"""Check if Gemini response was blocked and return an explanatory error message."""
|
||||||
|
|
||||||
|
# Ensure we have a proper list of candidates
|
||||||
|
if not isinstance(candidates, list):
|
||||||
|
message = f"\nUnexpected response format. Try again."
|
||||||
|
stopped = True
|
||||||
|
return message, stopped
|
||||||
|
|
||||||
# Check if the response was blocked due to safety concerns with the prompt
|
# Check if the response was blocked due to safety concerns with the prompt
|
||||||
if len(candidates) == 0 and prompt_feedback:
|
if len(candidates) == 0 and prompt_feedback:
|
||||||
message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
|
message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
|
||||||
@@ -428,7 +483,18 @@ def format_messages_for_gemini(
|
|||||||
if len(messages) == 1:
|
if len(messages) == 1:
|
||||||
messages[0].role = "user"
|
messages[0].role = "user"
|
||||||
|
|
||||||
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]
|
# Ensure messages are properly formatted for Content creation
|
||||||
|
valid_messages = []
|
||||||
|
for message in messages:
|
||||||
|
try:
|
||||||
|
# Try create Content object to validate the structure before adding to valid messages
|
||||||
|
gtypes.Content(role=message.role, parts=message.content)
|
||||||
|
valid_messages.append(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Dropping message with invalid content structure: {e}. Message: {message}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in valid_messages]
|
||||||
return formatted_messages, system_prompt
|
return formatted_messages, system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user