From 1b82aea753096a66e59799b41cba53f552fcc9f1 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 12 Sep 2024 00:29:04 -0700 Subject: [PATCH] Support using image generation models like Flux via Replicate Enables using any image generation model on Replicate's Predictions API endpoints. The server admin just needs to add text-to-image model on the server/admin panel in organization/model_name format and input their Replicate API key with it Create db migration (including merge) --- ...alter_texttoimagemodelconfig_model_type.py | 21 ++++++ .../migrations/0062_merge_20240913_0222.py | 14 ++++ src/khoj/database/models/__init__.py | 1 + src/khoj/routers/helpers.py | 64 +++++++++++++++++-- 4 files changed, 94 insertions(+), 6 deletions(-) create mode 100644 src/khoj/database/migrations/0061_alter_texttoimagemodelconfig_model_type.py create mode 100644 src/khoj/database/migrations/0062_merge_20240913_0222.py diff --git a/src/khoj/database/migrations/0061_alter_texttoimagemodelconfig_model_type.py b/src/khoj/database/migrations/0061_alter_texttoimagemodelconfig_model_type.py new file mode 100644 index 00000000..4431a9d8 --- /dev/null +++ b/src/khoj/database/migrations/0061_alter_texttoimagemodelconfig_model_type.py @@ -0,0 +1,21 @@ +# Generated by Django 5.0.7 on 2024-09-12 05:43 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0060_merge_20240905_1828"), + ] + + operations = [ + migrations.AlterField( + model_name="texttoimagemodelconfig", + name="model_type", + field=models.CharField( + choices=[("openai", "Openai"), ("stability-ai", "Stabilityai"), ("replicate", "Replicate")], + default="openai", + max_length=200, + ), + ), + ] diff --git a/src/khoj/database/migrations/0062_merge_20240913_0222.py b/src/khoj/database/migrations/0062_merge_20240913_0222.py new file mode 100644 index 00000000..51175c50 --- /dev/null +++ b/src/khoj/database/migrations/0062_merge_20240913_0222.py @@ -0,0 +1,14 @@ +# Generated by Django 5.0.8 on 2024-09-13 02:22 + +from typing import List + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0061_alter_chatmodeloptions_model_type"), + ("database", "0061_alter_texttoimagemodelconfig_model_type"), + ] + + operations: List[str] = [] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 80769de8..4029cf3c 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -280,6 +280,7 @@ class TextToImageModelConfig(BaseModel): class ModelType(models.TextChoices): OPENAI = "openai" STABILITYAI = "stability-ai" + REPLICATE = "replicate" model_name = models.CharField(max_length=200, default="dall-e-3") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index f1b8ddd6..5ccaa4dd 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -7,6 +7,7 @@ import logging import math import os import re +import time from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone from enum import Enum @@ -568,7 +569,7 @@ async def generate_better_image_prompt( references=user_references, online_results=simplified_online_results, ) - elif model_type == TextToImageModelConfig.ModelType.STABILITYAI: + elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]: image_prompt = prompts.image_generation_improve_prompt_sd.format( query=q, chat_history=conversation_history, @@ -991,7 +992,8 @@ async def text_to_image( extra_headers=auth_header, ) image = response.data[0].b64_json - decoded_image = base64.b64decode(image) + # Decode base64 png and convert it to webp for faster loading + webp_image_bytes = convert_image_to_webp(base64.b64decode(image)) except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: if "content_policy_violation" in e.message: logger.error(f"Image Generation blocked by OpenAI: {e}") @@ -1021,7 +1023,8 @@ async def text_to_image( "aspect_ratio": "1:1", }, ) - decoded_image = response.content + # Convert png to webp for faster loading + webp_image_bytes = convert_image_to_webp(response.content) except requests.RequestException as e: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed with Stability AI error: {e}" @@ -1029,9 +1032,58 @@ async def text_to_image( yield image_url or image, status_code, message, intent_type.value return - with timer("Convert image to webp", logger): - # Convert png to webp for faster loading - webp_image_bytes = convert_image_to_webp(decoded_image) + elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE: + with timer("Generate image using Replicate", logger): + try: + # Create image generation task on Replicate + create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions" + headers = { + "Authorization": f"Bearer {text_to_image_config.api_key}", + "Content-Type": "application/json", + } + json = { + "input": { + "prompt": improved_image_prompt, + "num_outputs": 1, + "aspect_ratio": "1:1", + "output_format": "webp", + "output_quality": 100, + } + } + create_prediction = requests.post(create_prediction_url, headers=headers, json=json).json() + + # Get status of image generation task + get_prediction_url = create_prediction["urls"]["get"] + get_prediction = requests.get(get_prediction_url, headers=headers).json() + status = get_prediction["status"] + retry_count = 1 + + # Poll the image generation task for completion status + while status not in ["succeeded", "failed", "canceled"] and retry_count < 20: + time.sleep(2) + get_prediction = requests.get(get_prediction_url, headers=headers).json() + status = get_prediction["status"] + retry_count += 1 + + # Raise exception if the image generation task fails + if status != "succeeded": + if retry_count >= 10: + raise requests.RequestException("Image generation timed out") + raise requests.RequestException(f"Image generation failed with status: {status}") + + # Get the generated image + image_url = ( + get_prediction["output"][0] + if isinstance(get_prediction["output"], list) + else get_prediction["output"] + ) + webp_image_bytes = io.BytesIO(requests.get(image_url).content).getvalue() + except requests.RequestException as e: + logger.error(f"Image Generation failed with {e}", exc_info=True) + message = f"Image generation for {text2image_model} failed with Replicate API error: {e}" + status_code = 500 + yield image_url or image, status_code, message, intent_type.value + return with timer("Upload image to S3", logger): image_url = upload_image(webp_image_bytes, user.uuid)