Improve retry, increase timeouts of gemini api calls

- Catch specific retryable exceptions for retry
- Increase httpx timeout from default of 5s to 20s
This commit is contained in:
Debanjum
2025-05-17 02:40:47 -07:00
parent 20f08ca564
commit 10a5d68a2c

View File

@@ -5,6 +5,7 @@ from copy import deepcopy
from time import perf_counter from time import perf_counter
from typing import AsyncGenerator, AsyncIterator, Dict from typing import AsyncGenerator, AsyncIterator, Dict
import httpx
from google import genai from google import genai
from google.genai import errors as gerrors from google.genai import errors as gerrors
from google.genai import types as gtypes from google.genai import types as gtypes
@@ -13,6 +14,7 @@ from pydantic import BaseModel
from tenacity import ( from tenacity import (
before_sleep_log, before_sleep_log,
retry, retry,
retry_if_exception,
stop_after_attempt, stop_after_attempt,
wait_exponential, wait_exponential,
wait_random_exponential, wait_random_exponential,
@@ -61,7 +63,19 @@ SAFETY_SETTINGS = [
] ]
def _is_retryable_error(exception: BaseException) -> bool:
"""Check if the exception is a retryable error"""
# server errors
if isinstance(exception, gerrors.APIError):
return exception.code in [429, 502, 503, 504]
# client errors
if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError):
return True
return False
@retry( @retry(
retry=retry_if_exception(_is_retryable_error),
wait=wait_random_exponential(min=1, max=10), wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(2), stop=stop_after_attempt(2),
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
@@ -104,6 +118,7 @@ def gemini_completion_with_backoff(
response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain", response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
response_schema=response_schema, response_schema=response_schema,
seed=seed, seed=seed,
http_options=gtypes.HttpOptions(client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
) )
try: try:
@@ -137,8 +152,9 @@ def gemini_completion_with_backoff(
@retry( @retry(
retry=retry_if_exception(_is_retryable_error),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(2), stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
@@ -174,6 +190,7 @@ async def gemini_chat_completion_with_backoff(
stop_sequences=["Notes:\n["], stop_sequences=["Notes:\n["],
safety_settings=SAFETY_SETTINGS, safety_settings=SAFETY_SETTINGS,
seed=seed, seed=seed,
http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}),
) )
aggregated_response = "" aggregated_response = ""