mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 21:29:11 +00:00
Support using image generation models like Flux via Replicate
Enables using any image generation model on Replicate's Predictions API endpoints. The server admin just needs to add the text-to-image model on the server/admin panel in organization/model_name format and input their Replicate API key with it. Create db migration (including merge).
This commit is contained in:
@@ -0,0 +1,21 @@
|
|||||||
|
# Generated by Django 5.0.7 on 2024-09-12 05:43
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Add "replicate" to the ``model_type`` choices of ``texttoimagemodelconfig``.

    Redeclares the ``model_type`` field so its choices are "openai",
    "stability-ai" and "replicate"; the default stays "openai" and the
    maximum length stays 200.
    """

    # Must run after the previous merge migration of the "database" app.
    dependencies = [("database", "0060_merge_20240905_1828")]

    operations = [
        migrations.AlterField(
            model_name="texttoimagemodelconfig",
            name="model_type",
            field=models.CharField(
                max_length=200,
                choices=[
                    ("openai", "Openai"),
                    ("stability-ai", "Stabilityai"),
                    ("replicate", "Replicate"),
                ],
                default="openai",
            ),
        ),
    ]
|
||||||
14
src/khoj/database/migrations/0062_merge_20240913_0222.py
Normal file
14
src/khoj/database/migrations/0062_merge_20240913_0222.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# Generated by Django 5.0.8 on 2024-09-13 02:22
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Merge migration joining the two 0061 migrations of the "database" app.

    It lists both 0061 migrations as dependencies and applies no operations
    of its own.
    """

    dependencies = [
        ("database", "0061_alter_chatmodeloptions_model_type"),
        ("database", "0061_alter_texttoimagemodelconfig_model_type"),
    ]

    # Intentionally empty: this migration only ties the two branches together.
    operations: List[str] = []
|
||||||
@@ -280,6 +280,7 @@ class TextToImageModelConfig(BaseModel):
|
|||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
STABILITYAI = "stability-ai"
|
STABILITYAI = "stability-ai"
|
||||||
|
REPLICATE = "replicate"
|
||||||
|
|
||||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -568,7 +569,7 @@ async def generate_better_image_prompt(
|
|||||||
references=user_references,
|
references=user_references,
|
||||||
online_results=simplified_online_results,
|
online_results=simplified_online_results,
|
||||||
)
|
)
|
||||||
elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
|
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
|
||||||
image_prompt = prompts.image_generation_improve_prompt_sd.format(
|
image_prompt = prompts.image_generation_improve_prompt_sd.format(
|
||||||
query=q,
|
query=q,
|
||||||
chat_history=conversation_history,
|
chat_history=conversation_history,
|
||||||
@@ -991,7 +992,8 @@ async def text_to_image(
|
|||||||
extra_headers=auth_header,
|
extra_headers=auth_header,
|
||||||
)
|
)
|
||||||
image = response.data[0].b64_json
|
image = response.data[0].b64_json
|
||||||
decoded_image = base64.b64decode(image)
|
# Decode base64 png and convert it to webp for faster loading
|
||||||
|
webp_image_bytes = convert_image_to_webp(base64.b64decode(image))
|
||||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||||
if "content_policy_violation" in e.message:
|
if "content_policy_violation" in e.message:
|
||||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||||
@@ -1021,7 +1023,8 @@ async def text_to_image(
|
|||||||
"aspect_ratio": "1:1",
|
"aspect_ratio": "1:1",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
decoded_image = response.content
|
# Convert png to webp for faster loading
|
||||||
|
webp_image_bytes = convert_image_to_webp(response.content)
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation failed with Stability AI error: {e}"
|
message = f"Image generation failed with Stability AI error: {e}"
|
||||||
@@ -1029,9 +1032,58 @@ async def text_to_image(
|
|||||||
yield image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
return
|
return
|
||||||
|
|
||||||
with timer("Convert image to webp", logger):
|
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
|
||||||
# Convert png to webp for faster loading
|
with timer("Generate image using Replicate", logger):
|
||||||
webp_image_bytes = convert_image_to_webp(decoded_image)
|
try:
|
||||||
|
# Create image generation task on Replicate
|
||||||
|
create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {text_to_image_config.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
json = {
|
||||||
|
"input": {
|
||||||
|
"prompt": improved_image_prompt,
|
||||||
|
"num_outputs": 1,
|
||||||
|
"aspect_ratio": "1:1",
|
||||||
|
"output_format": "webp",
|
||||||
|
"output_quality": 100,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
create_prediction = requests.post(create_prediction_url, headers=headers, json=json).json()
|
||||||
|
|
||||||
|
# Get status of image generation task
|
||||||
|
get_prediction_url = create_prediction["urls"]["get"]
|
||||||
|
get_prediction = requests.get(get_prediction_url, headers=headers).json()
|
||||||
|
status = get_prediction["status"]
|
||||||
|
retry_count = 1
|
||||||
|
|
||||||
|
# Poll the image generation task for completion status
|
||||||
|
while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
|
||||||
|
time.sleep(2)
|
||||||
|
get_prediction = requests.get(get_prediction_url, headers=headers).json()
|
||||||
|
status = get_prediction["status"]
|
||||||
|
retry_count += 1
|
||||||
|
|
||||||
|
# Raise exception if the image generation task fails
|
||||||
|
if status != "succeeded":
|
||||||
|
if retry_count >= 10:
|
||||||
|
raise requests.RequestException("Image generation timed out")
|
||||||
|
raise requests.RequestException(f"Image generation failed with status: {status}")
|
||||||
|
|
||||||
|
# Get the generated image
|
||||||
|
image_url = (
|
||||||
|
get_prediction["output"][0]
|
||||||
|
if isinstance(get_prediction["output"], list)
|
||||||
|
else get_prediction["output"]
|
||||||
|
)
|
||||||
|
webp_image_bytes = io.BytesIO(requests.get(image_url).content).getvalue()
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
|
message = f"Image generation for {text2image_model} failed with Replicate API error: {e}"
|
||||||
|
status_code = 500
|
||||||
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
|
return
|
||||||
|
|
||||||
with timer("Upload image to S3", logger):
|
with timer("Upload image to S3", logger):
|
||||||
image_url = upload_image(webp_image_bytes, user.uuid)
|
image_url = upload_image(webp_image_bytes, user.uuid)
|
||||||
|
|||||||
Reference in New Issue
Block a user