Retry on intermitted image generation failure for resilient generation

This commit is contained in:
Debanjum
2025-06-20 00:37:55 -07:00
parent 4448ab665c
commit ca9109455b

View File

@@ -8,6 +8,14 @@ import openai
import requests import requests
from google import genai from google import genai
from google.genai import types as gtypes from google.genai import types as gtypes
from tenacity import (
retry,
retry_if_exception,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from tenacity.before_sleep import before_sleep_log
from khoj.database.adapters import ConversationAdapters from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ( from khoj.database.models import (
@@ -16,6 +24,7 @@ from khoj.database.models import (
KhojUser, KhojUser,
TextToImageModelConfig, TextToImageModelConfig,
) )
from khoj.processor.conversation.google.utils import _is_retryable_error
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_generated_image_to_bucket from khoj.routers.storage import upload_generated_image_to_bucket
from khoj.utils import state from khoj.utils import state
@@ -131,6 +140,19 @@ async def text_to_image(
yield image_url or image, status_code, image_prompt yield image_url or image, status_code, image_prompt
@retry(
retry=(
retry_if_exception_type(openai.APITimeoutError)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.APIStatusError)
),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def generate_image_with_openai( def generate_image_with_openai(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
): ):
@@ -163,6 +185,13 @@ def generate_image_with_openai(
return convert_image_to_webp(base64.b64decode(image)) return convert_image_to_webp(base64.b64decode(image))
@retry(
retry=retry_if_exception_type(requests.RequestException),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def generate_image_with_stability( def generate_image_with_stability(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
): ):
@@ -185,6 +214,13 @@ def generate_image_with_stability(
return convert_image_to_webp(response.content) return convert_image_to_webp(response.content)
@retry(
retry=retry_if_exception_type(requests.RequestException),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def generate_image_with_replicate( def generate_image_with_replicate(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
): ):
@@ -232,6 +268,13 @@ def generate_image_with_replicate(
return io.BytesIO(requests.get(image_url).content).getvalue() return io.BytesIO(requests.get(image_url).content).getvalue()
@retry(
retry=retry_if_exception(_is_retryable_error),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def generate_image_with_google( def generate_image_with_google(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
): ):