From ca9109455bc0274928cc6d0f4330c302a2467ffd Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 20 Jun 2025 00:37:55 -0700 Subject: [PATCH] Retry on intermitted image generation failure for resilient generation --- src/khoj/processor/image/generate.py | 43 ++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 124a9ce8..164e3991 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -8,6 +8,14 @@ import openai import requests from google import genai from google.genai import types as gtypes +from tenacity import ( + retry, + retry_if_exception, + retry_if_exception_type, + stop_after_attempt, + wait_random_exponential, +) +from tenacity.before_sleep import before_sleep_log from khoj.database.adapters import ConversationAdapters from khoj.database.models import ( @@ -16,6 +24,7 @@ from khoj.database.models import ( KhojUser, TextToImageModelConfig, ) +from khoj.processor.conversation.google.utils import _is_retryable_error from khoj.routers.helpers import ChatEvent, generate_better_image_prompt from khoj.routers.storage import upload_generated_image_to_bucket from khoj.utils import state @@ -131,6 +140,19 @@ async def text_to_image( yield image_url or image, status_code, image_prompt +@retry( + retry=( + retry_if_exception_type(openai.APITimeoutError) + | retry_if_exception_type(openai.APIError) + | retry_if_exception_type(openai.APIConnectionError) + | retry_if_exception_type(openai.RateLimitError) + | retry_if_exception_type(openai.APIStatusError) + ), + wait=wait_random_exponential(min=1, max=10), + stop=stop_after_attempt(3), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) def generate_image_with_openai( improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str ): @@ -163,6 +185,13 @@ def generate_image_with_openai( return convert_image_to_webp(base64.b64decode(image)) +@retry( + retry=retry_if_exception_type(requests.RequestException), + wait=wait_random_exponential(min=1, max=10), + stop=stop_after_attempt(3), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) def generate_image_with_stability( improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str ): @@ -185,6 +214,13 @@ def generate_image_with_stability( return convert_image_to_webp(response.content) +@retry( + retry=retry_if_exception_type(requests.RequestException), + wait=wait_random_exponential(min=1, max=10), + stop=stop_after_attempt(3), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) def generate_image_with_replicate( improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str ): @@ -232,6 +268,13 @@ def generate_image_with_replicate( return io.BytesIO(requests.get(image_url).content).getvalue() +@retry( + retry=retry_if_exception(_is_retryable_error), + wait=wait_random_exponential(min=1, max=10), + stop=stop_after_attempt(3), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) def generate_image_with_google( improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str ):