mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Extract image generation code into new image processor for modularity
This commit is contained in:
212
src/khoj/processor/image/generate.py
Normal file
212
src/khoj/processor/image/generate.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from khoj.database.adapters import ConversationAdapters
|
||||||
|
from khoj.database.models import KhojUser, TextToImageModelConfig
|
||||||
|
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
|
||||||
|
from khoj.routers.storage import upload_image
|
||||||
|
from khoj.utils import state
|
||||||
|
from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
|
||||||
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_image(
|
||||||
|
message: str,
|
||||||
|
user: KhojUser,
|
||||||
|
conversation_log: dict,
|
||||||
|
location_data: LocationData,
|
||||||
|
references: List[Dict[str, Any]],
|
||||||
|
online_results: Dict[str, Any],
|
||||||
|
subscribed: bool = False,
|
||||||
|
send_status_func: Optional[Callable] = None,
|
||||||
|
uploaded_image_url: Optional[str] = None,
|
||||||
|
):
|
||||||
|
status_code = 200
|
||||||
|
image = None
|
||||||
|
image_url = None
|
||||||
|
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||||
|
|
||||||
|
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
|
||||||
|
if not text_to_image_config:
|
||||||
|
# If the user has not configured a text to image model, return an unsupported on server error
|
||||||
|
status_code = 501
|
||||||
|
message = "Failed to generate image. Setup image generation on the server."
|
||||||
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
|
return
|
||||||
|
|
||||||
|
text2image_model = text_to_image_config.model_name
|
||||||
|
chat_history = ""
|
||||||
|
for chat in conversation_log.get("chat", [])[-4:]:
|
||||||
|
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
|
||||||
|
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"A: {chat['message']}\n"
|
||||||
|
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
|
||||||
|
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
|
||||||
|
|
||||||
|
if send_status_func:
|
||||||
|
async for event in send_status_func("**Enhancing the Painting Prompt**"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
|
# Generate a better image prompt
|
||||||
|
# Use the user's message, chat history, and other context
|
||||||
|
image_prompt = await generate_better_image_prompt(
|
||||||
|
message,
|
||||||
|
chat_history,
|
||||||
|
location_data=location_data,
|
||||||
|
note_references=references,
|
||||||
|
online_results=online_results,
|
||||||
|
model_type=text_to_image_config.model_type,
|
||||||
|
subscribed=subscribed,
|
||||||
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
if send_status_func:
|
||||||
|
async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
|
# Generate image using the configured model and API
|
||||||
|
with timer(f"Generate image with {text_to_image_config.model_type}", logger):
|
||||||
|
try:
|
||||||
|
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
|
webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
|
||||||
|
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
||||||
|
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
|
||||||
|
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
|
||||||
|
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
|
||||||
|
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||||
|
if "content_policy_violation" in e.message:
|
||||||
|
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||||
|
status_code = e.status_code # type: ignore
|
||||||
|
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
||||||
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
|
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
||||||
|
status_code = e.status_code # type: ignore
|
||||||
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
|
return
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
|
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
|
||||||
|
status_code = 502
|
||||||
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
|
return
|
||||||
|
|
||||||
|
# Decide how to store the generated image
|
||||||
|
with timer("Upload image to S3", logger):
|
||||||
|
image_url = upload_image(webp_image_bytes, user.uuid)
|
||||||
|
if image_url:
|
||||||
|
intent_type = ImageIntentType.TEXT_TO_IMAGE2
|
||||||
|
else:
|
||||||
|
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||||
|
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||||
|
|
||||||
|
yield image_url or image, status_code, image_prompt, intent_type.value
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image_with_openai(
|
||||||
|
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
|
||||||
|
):
|
||||||
|
"Generate image using OpenAI API"
|
||||||
|
|
||||||
|
# Get the API key from the user's configuration
|
||||||
|
if text_to_image_config.api_key:
|
||||||
|
api_key = text_to_image_config.api_key
|
||||||
|
elif text_to_image_config.openai_config:
|
||||||
|
api_key = text_to_image_config.openai_config.api_key
|
||||||
|
elif state.openai_client:
|
||||||
|
api_key = state.openai_client.api_key
|
||||||
|
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||||
|
|
||||||
|
# Generate image using OpenAI API
|
||||||
|
OPENAI_IMAGE_GEN_STYLE = "vivid"
|
||||||
|
response = state.openai_client.images.generate(
|
||||||
|
prompt=improved_image_prompt,
|
||||||
|
model=text2image_model,
|
||||||
|
style=OPENAI_IMAGE_GEN_STYLE,
|
||||||
|
response_format="b64_json",
|
||||||
|
extra_headers=auth_header,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the base64 image from the response
|
||||||
|
image = response.data[0].b64_json
|
||||||
|
# Decode base64 png and convert it to webp for faster loading
|
||||||
|
return convert_image_to_webp(base64.b64decode(image))
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image_with_stability(
|
||||||
|
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
|
||||||
|
):
|
||||||
|
"Generate image using Stability AI"
|
||||||
|
|
||||||
|
# Call Stability AI API to generate image
|
||||||
|
response = requests.post(
|
||||||
|
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
|
||||||
|
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
|
||||||
|
files={"none": ""},
|
||||||
|
data={
|
||||||
|
"prompt": improved_image_prompt,
|
||||||
|
"model": text2image_model,
|
||||||
|
"mode": "text-to-image",
|
||||||
|
"output_format": "png",
|
||||||
|
"aspect_ratio": "1:1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Convert png to webp for faster loading
|
||||||
|
return convert_image_to_webp(response.content)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image_with_replicate(
|
||||||
|
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
|
||||||
|
):
|
||||||
|
"Generate image using Replicate API"
|
||||||
|
|
||||||
|
# Create image generation task on Replicate
|
||||||
|
replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {text_to_image_config.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
json = {
|
||||||
|
"input": {
|
||||||
|
"prompt": improved_image_prompt,
|
||||||
|
"num_outputs": 1,
|
||||||
|
"aspect_ratio": "1:1",
|
||||||
|
"output_format": "webp",
|
||||||
|
"output_quality": 100,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=json).json()
|
||||||
|
|
||||||
|
# Get status of image generation task
|
||||||
|
get_prediction_url = create_prediction["urls"]["get"]
|
||||||
|
get_prediction = requests.get(get_prediction_url, headers=headers).json()
|
||||||
|
status = get_prediction["status"]
|
||||||
|
retry_count = 1
|
||||||
|
|
||||||
|
# Poll the image generation task for completion status
|
||||||
|
while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
|
||||||
|
time.sleep(2)
|
||||||
|
get_prediction = requests.get(get_prediction_url, headers=headers).json()
|
||||||
|
status = get_prediction["status"]
|
||||||
|
retry_count += 1
|
||||||
|
|
||||||
|
# Raise exception if the image generation task fails
|
||||||
|
if status != "succeeded":
|
||||||
|
if retry_count >= 10:
|
||||||
|
raise requests.RequestException("Image generation timed out")
|
||||||
|
raise requests.RequestException(f"Image generation failed with status: {status}")
|
||||||
|
|
||||||
|
# Get the generated image
|
||||||
|
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
|
||||||
|
return io.BytesIO(requests.get(image_url).content).getvalue()
|
||||||
@@ -26,6 +26,7 @@ from khoj.database.adapters import (
|
|||||||
from khoj.database.models import KhojUser
|
from khoj.database.models import KhojUser
|
||||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||||
from khoj.processor.conversation.utils import save_to_conversation_log
|
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||||
|
from khoj.processor.image.generate import text_to_image
|
||||||
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
||||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||||
from khoj.routers.api import extract_references_and_questions
|
from khoj.routers.api import extract_references_and_questions
|
||||||
@@ -44,7 +45,6 @@ from khoj.routers.helpers import (
|
|||||||
is_query_empty,
|
is_query_empty,
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
read_chat_stream,
|
read_chat_stream,
|
||||||
text_to_image,
|
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
validate_conversation_config,
|
validate_conversation_config,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -17,7 +14,6 @@ from typing import (
|
|||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Callable,
|
|
||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
@@ -25,17 +21,15 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from urllib.parse import parse_qs, urlencode, urljoin, urlparse
|
from urllib.parse import parse_qs, urljoin, urlparse
|
||||||
|
|
||||||
import cron_descriptor
|
import cron_descriptor
|
||||||
import openai
|
|
||||||
import pytz
|
import pytz
|
||||||
import requests
|
import requests
|
||||||
from apscheduler.job import Job
|
from apscheduler.job import Job
|
||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
from PIL import Image
|
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
from starlette.requests import URL
|
from starlette.requests import URL
|
||||||
|
|
||||||
@@ -94,7 +88,6 @@ from khoj.processor.conversation.utils import (
|
|||||||
)
|
)
|
||||||
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
|
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
|
||||||
from khoj.routers.email import is_resend_enabled, send_task_email
|
from khoj.routers.email import is_resend_enabled, send_task_email
|
||||||
from khoj.routers.storage import upload_image
|
|
||||||
from khoj.routers.twilio import is_twilio_enabled
|
from khoj.routers.twilio import is_twilio_enabled
|
||||||
from khoj.search_type import text_search
|
from khoj.search_type import text_search
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
@@ -102,8 +95,6 @@ from khoj.utils.config import OfflineChatProcessorModel
|
|||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
LRU,
|
LRU,
|
||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
ImageIntentType,
|
|
||||||
convert_image_to_webp,
|
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_valid_url,
|
is_valid_url,
|
||||||
log_telemetry,
|
log_telemetry,
|
||||||
@@ -922,181 +913,6 @@ def generate_chat_response(
|
|||||||
return chat_response, metadata
|
return chat_response, metadata
|
||||||
|
|
||||||
|
|
||||||
async def text_to_image(
|
|
||||||
message: str,
|
|
||||||
user: KhojUser,
|
|
||||||
conversation_log: dict,
|
|
||||||
location_data: LocationData,
|
|
||||||
references: List[Dict[str, Any]],
|
|
||||||
online_results: Dict[str, Any],
|
|
||||||
subscribed: bool = False,
|
|
||||||
send_status_func: Optional[Callable] = None,
|
|
||||||
uploaded_image_url: Optional[str] = None,
|
|
||||||
):
|
|
||||||
status_code = 200
|
|
||||||
image = None
|
|
||||||
response = None
|
|
||||||
image_url = None
|
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
|
||||||
|
|
||||||
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
|
|
||||||
if not text_to_image_config:
|
|
||||||
# If the user has not configured a text to image model, return an unsupported on server error
|
|
||||||
status_code = 501
|
|
||||||
message = "Failed to generate image. Setup image generation on the server."
|
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
|
||||||
return
|
|
||||||
|
|
||||||
text2image_model = text_to_image_config.model_name
|
|
||||||
chat_history = ""
|
|
||||||
for chat in conversation_log.get("chat", [])[-4:]:
|
|
||||||
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
|
|
||||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
|
||||||
chat_history += f"A: {chat['message']}\n"
|
|
||||||
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
|
|
||||||
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
|
|
||||||
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
|
|
||||||
|
|
||||||
if send_status_func:
|
|
||||||
async for event in send_status_func("**Enhancing the Painting Prompt**"):
|
|
||||||
yield {ChatEvent.STATUS: event}
|
|
||||||
improved_image_prompt = await generate_better_image_prompt(
|
|
||||||
message,
|
|
||||||
chat_history,
|
|
||||||
location_data=location_data,
|
|
||||||
note_references=references,
|
|
||||||
online_results=online_results,
|
|
||||||
model_type=text_to_image_config.model_type,
|
|
||||||
subscribed=subscribed,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
if send_status_func:
|
|
||||||
async for event in send_status_func(f"**Painting to Imagine**:\n{improved_image_prompt}"):
|
|
||||||
yield {ChatEvent.STATUS: event}
|
|
||||||
|
|
||||||
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
|
||||||
with timer("Generate image with OpenAI", logger):
|
|
||||||
if text_to_image_config.api_key:
|
|
||||||
api_key = text_to_image_config.api_key
|
|
||||||
elif text_to_image_config.openai_config:
|
|
||||||
api_key = text_to_image_config.openai_config.api_key
|
|
||||||
elif state.openai_client:
|
|
||||||
api_key = state.openai_client.api_key
|
|
||||||
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
|
||||||
try:
|
|
||||||
response = state.openai_client.images.generate(
|
|
||||||
prompt=improved_image_prompt,
|
|
||||||
model=text2image_model,
|
|
||||||
style="vivid",
|
|
||||||
response_format="b64_json",
|
|
||||||
extra_headers=auth_header,
|
|
||||||
)
|
|
||||||
image = response.data[0].b64_json
|
|
||||||
# Decode base64 png and convert it to webp for faster loading
|
|
||||||
webp_image_bytes = convert_image_to_webp(base64.b64decode(image))
|
|
||||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
|
||||||
if "content_policy_violation" in e.message:
|
|
||||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
|
||||||
status_code = e.status_code # type: ignore
|
|
||||||
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
|
||||||
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
|
||||||
status_code = e.status_code # type: ignore
|
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
|
||||||
return
|
|
||||||
|
|
||||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
|
||||||
with timer("Generate image with Stability AI", logger):
|
|
||||||
try:
|
|
||||||
response = requests.post(
|
|
||||||
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
|
|
||||||
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
|
|
||||||
files={"none": ""},
|
|
||||||
data={
|
|
||||||
"prompt": improved_image_prompt,
|
|
||||||
"model": text2image_model,
|
|
||||||
"mode": "text-to-image",
|
|
||||||
"output_format": "png",
|
|
||||||
"aspect_ratio": "1:1",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Convert png to webp for faster loading
|
|
||||||
webp_image_bytes = convert_image_to_webp(response.content)
|
|
||||||
except requests.RequestException as e:
|
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
|
||||||
message = f"Image generation failed with Stability AI error: {e}"
|
|
||||||
status_code = e.status_code # type: ignore
|
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
|
||||||
return
|
|
||||||
|
|
||||||
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
|
|
||||||
with timer("Generate image using Replicate", logger):
|
|
||||||
try:
|
|
||||||
# Create image generation task on Replicate
|
|
||||||
create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {text_to_image_config.api_key}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
json = {
|
|
||||||
"input": {
|
|
||||||
"prompt": improved_image_prompt,
|
|
||||||
"num_outputs": 1,
|
|
||||||
"aspect_ratio": "1:1",
|
|
||||||
"output_format": "webp",
|
|
||||||
"output_quality": 100,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
create_prediction = requests.post(create_prediction_url, headers=headers, json=json).json()
|
|
||||||
|
|
||||||
# Get status of image generation task
|
|
||||||
get_prediction_url = create_prediction["urls"]["get"]
|
|
||||||
get_prediction = requests.get(get_prediction_url, headers=headers).json()
|
|
||||||
status = get_prediction["status"]
|
|
||||||
retry_count = 1
|
|
||||||
|
|
||||||
# Poll the image generation task for completion status
|
|
||||||
while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
|
|
||||||
time.sleep(2)
|
|
||||||
get_prediction = requests.get(get_prediction_url, headers=headers).json()
|
|
||||||
status = get_prediction["status"]
|
|
||||||
retry_count += 1
|
|
||||||
|
|
||||||
# Raise exception if the image generation task fails
|
|
||||||
if status != "succeeded":
|
|
||||||
if retry_count >= 10:
|
|
||||||
raise requests.RequestException("Image generation timed out")
|
|
||||||
raise requests.RequestException(f"Image generation failed with status: {status}")
|
|
||||||
|
|
||||||
# Get the generated image
|
|
||||||
image_url = (
|
|
||||||
get_prediction["output"][0]
|
|
||||||
if isinstance(get_prediction["output"], list)
|
|
||||||
else get_prediction["output"]
|
|
||||||
)
|
|
||||||
webp_image_bytes = io.BytesIO(requests.get(image_url).content).getvalue()
|
|
||||||
except requests.RequestException as e:
|
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
|
||||||
message = f"Image generation for {text2image_model} failed with Replicate API error: {e}"
|
|
||||||
status_code = 500
|
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
|
||||||
return
|
|
||||||
|
|
||||||
with timer("Upload image to S3", logger):
|
|
||||||
image_url = upload_image(webp_image_bytes, user.uuid)
|
|
||||||
if image_url:
|
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE2
|
|
||||||
else:
|
|
||||||
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
|
||||||
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
|
||||||
|
|
||||||
yield image_url or image, status_code, improved_image_prompt, intent_type.value
|
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
|
|||||||
Reference in New Issue
Block a user