diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index d85cd09e..9f2be46c 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -5,6 +5,7 @@ from copy import deepcopy from time import perf_counter from typing import AsyncGenerator, AsyncIterator, Dict +import httpx from google import genai from google.genai import errors as gerrors from google.genai import types as gtypes @@ -13,6 +14,7 @@ from pydantic import BaseModel from tenacity import ( before_sleep_log, retry, + retry_if_exception, stop_after_attempt, wait_exponential, wait_random_exponential, @@ -61,7 +63,19 @@ SAFETY_SETTINGS = [ ] +def _is_retryable_error(exception: BaseException) -> bool: + """Check if the exception is a retryable error""" + # server errors + if isinstance(exception, gerrors.APIError): + return exception.code in [429, 502, 503, 504] + # client errors + if isinstance(exception, httpx.TimeoutException) or isinstance(exception, httpx.NetworkError): + return True + return False + + @retry( + retry=retry_if_exception(_is_retryable_error), wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(2), before_sleep=before_sleep_log(logger, logging.DEBUG), @@ -104,6 +118,7 @@ def gemini_completion_with_backoff( response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain", response_schema=response_schema, seed=seed, + http_options=gtypes.HttpOptions(client_args={"timeout": httpx.Timeout(30.0, read=60.0)}), ) try: @@ -137,8 +152,9 @@ def gemini_completion_with_backoff( @retry( + retry=retry_if_exception(_is_retryable_error), wait=wait_exponential(multiplier=1, min=4, max=10), - stop=stop_after_attempt(2), + stop=stop_after_attempt(3), before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) @@ -174,6 +190,7 @@ async def gemini_chat_completion_with_backoff( stop_sequences=["Notes:\n["], safety_settings=SAFETY_SETTINGS, seed=seed, + http_options=gtypes.HttpOptions(async_client_args={"timeout": httpx.Timeout(30.0, read=60.0)}), ) aggregated_response = ""