Support using image generation models like Flux via Replicate

Enables using any image generation model on Replicate's Predictions
API endpoints.

The server admin just needs to add a text-to-image model on the
server/admin panel, with its name in organization/model_name format,
and enter their Replicate API key along with it.

Create the DB migration for this change (including a merge migration)
This commit is contained in:
Debanjum Singh Solanky
2024-09-12 00:29:04 -07:00
parent 1d512b4986
commit 1b82aea753
4 changed files with 94 additions and 6 deletions

View File

@@ -0,0 +1,21 @@
# Generated by Django 5.0.7 on 2024-09-12 05:43
from django.db import migrations, models
# Schema migration for the "database" app: re-declares the `model_type`
# field of `texttoimagemodelconfig` so that its allowed choices are
# "openai", "stability-ai" and "replicate".
class Migration(migrations.Migration):
    # Must be applied after the 0060 merge migration of the same app.
    dependencies = [
        ("database", "0060_merge_20240905_1828"),
    ]
    operations = [
        # AlterField replaces the whole field definition, so the default and
        # max_length are restated here alongside the choices list.
        migrations.AlterField(
            model_name="texttoimagemodelconfig",
            name="model_type",
            field=models.CharField(
                # Each entry is a (stored value, display label) pair.
                choices=[("openai", "Openai"), ("stability-ai", "Stabilityai"), ("replicate", "Replicate")],
                default="openai",
                max_length=200,
            ),
        ),
    ]

View File

@@ -0,0 +1,14 @@
# Generated by Django 5.0.8 on 2024-09-13 02:22
from typing import List
from django.db import migrations
# Merge migration for the "database" app: it performs no schema operations
# and only declares both migrations numbered 0061 as its dependencies.
class Migration(migrations.Migration):
    # The two 0061 migrations that must both be applied before this one.
    dependencies = [
        ("database", "0061_alter_chatmodeloptions_model_type"),
        ("database", "0061_alter_texttoimagemodelconfig_model_type"),
    ]
    # No operations are run by this migration.
    # NOTE(review): the List[str] annotation is presumably there only to give
    # the empty list a type for a static type checker — confirm.
    operations: List[str] = []

View File

@@ -280,6 +280,7 @@ class TextToImageModelConfig(BaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
OPENAI = "openai" OPENAI = "openai"
STABILITYAI = "stability-ai" STABILITYAI = "stability-ai"
REPLICATE = "replicate"
model_name = models.CharField(max_length=200, default="dall-e-3") model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)

View File

@@ -7,6 +7,7 @@ import logging
import math import math
import os import os
import re import re
import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum from enum import Enum
@@ -568,7 +569,7 @@ async def generate_better_image_prompt(
references=user_references, references=user_references,
online_results=simplified_online_results, online_results=simplified_online_results,
) )
elif model_type == TextToImageModelConfig.ModelType.STABILITYAI: elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
image_prompt = prompts.image_generation_improve_prompt_sd.format( image_prompt = prompts.image_generation_improve_prompt_sd.format(
query=q, query=q,
chat_history=conversation_history, chat_history=conversation_history,
@@ -991,7 +992,8 @@ async def text_to_image(
extra_headers=auth_header, extra_headers=auth_header,
) )
image = response.data[0].b64_json image = response.data[0].b64_json
decoded_image = base64.b64decode(image) # Decode base64 png and convert it to webp for faster loading
webp_image_bytes = convert_image_to_webp(base64.b64decode(image))
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message: if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}") logger.error(f"Image Generation blocked by OpenAI: {e}")
@@ -1021,7 +1023,8 @@ async def text_to_image(
"aspect_ratio": "1:1", "aspect_ratio": "1:1",
}, },
) )
decoded_image = response.content # Convert png to webp for faster loading
webp_image_bytes = convert_image_to_webp(response.content)
except requests.RequestException as e: except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True) logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with Stability AI error: {e}" message = f"Image generation failed with Stability AI error: {e}"
@@ -1029,9 +1032,58 @@ async def text_to_image(
yield image_url or image, status_code, message, intent_type.value yield image_url or image, status_code, message, intent_type.value
return return
with timer("Convert image to webp", logger): elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
# Convert png to webp for faster loading with timer("Generate image using Replicate", logger):
webp_image_bytes = convert_image_to_webp(decoded_image) try:
# Create image generation task on Replicate
create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
headers = {
"Authorization": f"Bearer {text_to_image_config.api_key}",
"Content-Type": "application/json",
}
json = {
"input": {
"prompt": improved_image_prompt,
"num_outputs": 1,
"aspect_ratio": "1:1",
"output_format": "webp",
"output_quality": 100,
}
}
create_prediction = requests.post(create_prediction_url, headers=headers, json=json).json()
# Get status of image generation task
get_prediction_url = create_prediction["urls"]["get"]
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count = 1
# Poll the image generation task for completion status
while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
time.sleep(2)
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count += 1
# Raise exception if the image generation task fails
if status != "succeeded":
if retry_count >= 10:
raise requests.RequestException("Image generation timed out")
raise requests.RequestException(f"Image generation failed with status: {status}")
# Get the generated image
image_url = (
get_prediction["output"][0]
if isinstance(get_prediction["output"], list)
else get_prediction["output"]
)
webp_image_bytes = io.BytesIO(requests.get(image_url).content).getvalue()
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation for {text2image_model} failed with Replicate API error: {e}"
status_code = 500
yield image_url or image, status_code, message, intent_type.value
return
with timer("Upload image to S3", logger): with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid) image_url = upload_image(webp_image_bytes, user.uuid)