Improve retry, increase timeouts of gemini api calls

- Catch specific retryable exceptions for retry
- Increase httpx timeout from default of 5s to 20s
This commit is contained in:
Debanjum
2025-05-17 02:40:47 -07:00
parent 20f08ca564
commit 10a5d68a2c

View File

@@ -5,6 +5,7 @@ from copy import deepcopy
from time import perf_counter
from typing import AsyncGenerator, AsyncIterator, Dict
import httpx
from google import genai
from google.genai import errors as gerrors
from google.genai import types as gtypes
@@ -13,6 +14,7 @@ from pydantic import BaseModel
from tenacity import (
before_sleep_log,
retry,
retry_if_exception,
stop_after_attempt,
wait_exponential,
wait_random_exponential,
@@ -61,7 +63,19 @@ SAFETY_SETTINGS = [
]
def _is_retryable_error(exception: BaseException) -> bool:
"""Check if the exception is a retryable error"""
# server errors
if isinstance(exception, gerrors.APIError):
return exception.code in [429, 502, 503, 504]
# client errors
if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError):
return True
return False
@retry(
retry=retry_if_exception(_is_retryable_error),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG),
@@ -104,6 +118,7 @@ def gemini_completion_with_backoff(
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
response_schema=response_schema,
seed=seed,
http_options=gtypes.HttpOptions(client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
)
try:
@@ -137,8 +152,9 @@ def gemini_completion_with_backoff(
@retry(
retry=retry_if_exception(_is_retryable_error),
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(2),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
@@ -174,6 +190,7 @@ async def gemini_chat_completion_with_backoff(
stop_sequences=["Notes:\n["],
safety_settings=SAFETY_SETTINGS,
seed=seed,
http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
)
aggregated_response = ""