Support using image generation models like Flux via Replicate (#909)

- Support using image generation models like Flux via Replicate
- Modularize the image generation code
- Make the "generate better image prompt" chat actor add composition details
- Generate vivid images with DALLE-3
This commit is contained in:
Debanjum
2024-09-12 20:19:46 -07:00
committed by GitHub
7 changed files with 255 additions and 138 deletions

View File

@@ -0,0 +1,21 @@
# Generated by Django 5.0.7 on 2024-09-12 05:43
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add the "replicate" provider to TextToImageModelConfig.model_type choices."""

    dependencies = [
        ("database", "0060_merge_20240905_1828"),
    ]

    operations = [
        # Re-declare model_type with the expanded set of provider choices;
        # existing rows keep their current value, default remains "openai".
        migrations.AlterField(
            model_name="texttoimagemodelconfig",
            name="model_type",
            field=models.CharField(
                choices=[("openai", "Openai"), ("stability-ai", "Stabilityai"), ("replicate", "Replicate")],
                default="openai",
                max_length=200,
            ),
        ),
    ]

View File

@@ -0,0 +1,14 @@
# Generated by Django 5.0.8 on 2024-09-13 02:22
from typing import List
from django.db import migrations
class Migration(migrations.Migration):
    """No-op merge migration reconciling two parallel 0061 migrations.

    Both branches altered a model_type field independently; this merge gives
    Django a single leaf node in the migration graph. It applies no schema
    changes of its own.
    """

    dependencies = [
        ("database", "0061_alter_chatmodeloptions_model_type"),
        ("database", "0061_alter_texttoimagemodelconfig_model_type"),
    ]

    # Fixed annotation: migration operations are Operation instances, not
    # strings — the original `List[str]` annotation was simply wrong.
    operations: list = []

View File

@@ -280,6 +280,7 @@ class TextToImageModelConfig(BaseModel):
class ModelType(models.TextChoices):
    # Supported image generation provider backends
    OPENAI = "openai"
    STABILITYAI = "stability-ai"
    REPLICATE = "replicate"

# Name of the image generation model to request from the provider, e.g. "dall-e-3"
model_name = models.CharField(max_length=200, default="dall-e-3")
# Which provider backend serves the model; one of ModelType's choices
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)

View File

@@ -128,8 +128,8 @@ User's Notes:
## --
image_generation_improve_prompt_base = """
You are a talented creator with the ability to describe images to compose in vivid, fine detail.
Use the provided context and user prompt to generate a more detailed prompt to create an image:
You are a talented media artist with the ability to describe images to compose in professional, fine detail.
Generate a vivid description of the image to be rendered using the provided context and user prompt below:
Today's Date: {current_date}
User's Location: {location}
@@ -145,10 +145,10 @@ Conversation Log:
User Prompt: "{query}"
Now generate an improved prompt describing the image to generate in vivid, fine detail.
Now generate a professional description of the image to generate in vivid, fine detail.
- Use today's date, user's location, user's notes and online references to weave in any context that will improve the image generation.
- Retain any important information and follow any instructions in the conversation log or user prompt.
- Add specific, fine position details to compose the image.
- Add specific, fine position details. Mention painting style, camera parameters to compose the image.
- Ensure your improved prompt is in prose format."""
image_generation_improve_prompt_dalle = PromptTemplate.from_template(

View File

@@ -0,0 +1,212 @@
import base64
import io
import logging
import time
from typing import Any, Callable, Dict, List, Optional
import openai
import requests
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image
from khoj.utils import state
from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
async def text_to_image(
    message: str,
    user: KhojUser,
    conversation_log: dict,
    location_data: LocationData,
    references: List[Dict[str, Any]],
    online_results: Dict[str, Any],
    subscribed: bool = False,
    send_status_func: Optional[Callable] = None,
    uploaded_image_url: Optional[str] = None,
):
    """Generate an image for the user's message with their configured provider.

    Async generator. Intermediate yields are {ChatEvent.STATUS: event} status
    updates; the final yield is a 4-tuple of
    (image_url or base64 image, status_code, message or improved prompt, intent_type value).
    """
    status_code = 200
    image = None
    image_url = None
    intent_type = ImageIntentType.TEXT_TO_IMAGE_V3

    text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unsupported on server error
        status_code = 501
        message = "Failed to generate image. Setup image generation on the server."
        yield image_url or image, status_code, message, intent_type.value
        return

    text2image_model = text_to_image_config.model_name

    # Summarize recent conversation turns to give the prompt-improver context
    chat_history = ""
    for chat in conversation_log.get("chat", [])[-4:]:
        # Fix: default missing intent type to "" so the `in` containment test
        # below cannot raise TypeError on None
        intent_type_name = chat["intent"].get("type", "") or ""
        if chat["by"] == "khoj" and intent_type_name in ["remember", "reminder"]:
            chat_history += f"Q: {chat['intent']['query']}\n"
            chat_history += f"A: {chat['message']}\n"
        elif chat["by"] == "khoj" and "text-to-image" in intent_type_name:
            chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
            chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"

    if send_status_func:
        async for event in send_status_func("**Enhancing the Painting Prompt**"):
            yield {ChatEvent.STATUS: event}

    # Generate a better image prompt
    # Use the user's message, chat history, and other context
    image_prompt = await generate_better_image_prompt(
        message,
        chat_history,
        location_data=location_data,
        note_references=references,
        online_results=online_results,
        model_type=text_to_image_config.model_type,
        subscribed=subscribed,
        uploaded_image_url=uploaded_image_url,
    )

    if send_status_func:
        async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
            yield {ChatEvent.STATUS: event}

    # Generate image using the configured model and API
    with timer(f"Generate image with {text_to_image_config.model_type}", logger):
        try:
            if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
                webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
            elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
                webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
            elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
                webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
            else:
                # Fix: previously an unknown model type left webp_image_bytes
                # unbound and crashed with NameError further down
                status_code = 500
                message = f"Image generation not supported for model type: {text_to_image_config.model_type}"
                yield image_url or image, status_code, message, intent_type.value
                return
        # Fix: `except A or B or C` only catches the first class; a tuple is
        # required to catch all three exception types
        except (openai.OpenAIError, openai.BadRequestError, openai.APIConnectionError) as e:
            # e.message may be None (e.g. for connection errors); guard the containment test
            if "content_policy_violation" in (e.message or ""):
                logger.error(f"Image Generation blocked by OpenAI: {e}")
                status_code = e.status_code  # type: ignore
                message = f"Image generation blocked by OpenAI: {e.message}"  # type: ignore
                yield image_url or image, status_code, message, intent_type.value
                return
            else:
                logger.error(f"Image Generation failed with {e}", exc_info=True)
                message = f"Image generation failed with OpenAI error: {e.message}"  # type: ignore
                status_code = e.status_code  # type: ignore
                yield image_url or image, status_code, message, intent_type.value
                return
        except requests.RequestException as e:
            logger.error(f"Image Generation failed with {e}", exc_info=True)
            message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
            status_code = 502
            yield image_url or image, status_code, message, intent_type.value
            return

    # Decide how to store the generated image
    with timer("Upload image to S3", logger):
        image_url = upload_image(webp_image_bytes, user.uuid)
    if image_url:
        intent_type = ImageIntentType.TEXT_TO_IMAGE2
    else:
        # No object storage configured: inline the image as base64 instead
        intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
        image = base64.b64encode(webp_image_bytes).decode("utf-8")

    yield image_url or image, status_code, image_prompt, intent_type.value
def generate_image_with_openai(
    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
    """Generate an image via the OpenAI images API and return it as webp bytes.

    Raises openai errors on API failure; caller handles them.
    """
    # Resolve the API key: per-model key > linked openai config > server default client.
    # Fix: initialize to None — previously api_key was unbound (NameError) when
    # none of the three branches matched.
    api_key = None
    if text_to_image_config.api_key:
        api_key = text_to_image_config.api_key
    elif text_to_image_config.openai_config:
        api_key = text_to_image_config.openai_config.api_key
    elif state.openai_client:
        api_key = state.openai_client.api_key
    auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}

    # Generate image using OpenAI API. "vivid" style makes DALL-E-3 produce
    # more dramatic, hyper-real images than the "natural" default.
    OPENAI_IMAGE_GEN_STYLE = "vivid"
    response = state.openai_client.images.generate(
        prompt=improved_image_prompt,
        model=text2image_model,
        style=OPENAI_IMAGE_GEN_STYLE,
        response_format="b64_json",
        extra_headers=auth_header,
    )

    # Extract the base64 image from the response
    image = response.data[0].b64_json

    # Decode base64 png and convert it to webp for faster loading
    return convert_image_to_webp(base64.b64decode(image))
def generate_image_with_stability(
    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
    """Generate an image via the Stability AI API and return it as webp bytes.

    Raises requests.HTTPError on a non-2xx response; caller handles
    requests.RequestException.
    """
    # Call Stability AI API to generate image
    response = requests.post(
        "https://api.stability.ai/v2beta/stable-image/generate/sd3",
        headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
        files={"none": ""},
        data={
            "prompt": improved_image_prompt,
            "model": text2image_model,
            "mode": "text-to-image",
            "output_format": "png",
            "aspect_ratio": "1:1",
        },
    )
    # Fix: fail fast on API errors — previously an error response body (JSON)
    # was passed to the image converter as if it were a PNG
    response.raise_for_status()

    # Convert png to webp for faster loading
    return convert_image_to_webp(response.content)
def generate_image_with_replicate(
    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
    """Generate an image via the Replicate predictions API and return its bytes.

    Creates a prediction, polls it until it reaches a terminal state or the
    retry budget is exhausted, then downloads the generated image.
    Raises requests.RequestException on timeout or task failure.
    """
    # Create image generation task on Replicate
    replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
    headers = {
        "Authorization": f"Bearer {text_to_image_config.api_key}",
        "Content-Type": "application/json",
    }
    # Renamed from `json` to avoid shadowing the stdlib module name
    payload = {
        "input": {
            "prompt": improved_image_prompt,
            "num_outputs": 1,
            "aspect_ratio": "1:1",
            "output_format": "webp",
            "output_quality": 100,
        }
    }
    create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=payload).json()

    # Poll the image generation task until it reaches a terminal status
    # or we exhaust the retry budget (~40s at 2s per poll)
    max_retries = 20
    get_prediction_url = create_prediction["urls"]["get"]
    get_prediction = requests.get(get_prediction_url, headers=headers).json()
    status = get_prediction["status"]
    retry_count = 1
    while status not in ["succeeded", "failed", "canceled"] and retry_count < max_retries:
        time.sleep(2)
        get_prediction = requests.get(get_prediction_url, headers=headers).json()
        status = get_prediction["status"]
        retry_count += 1

    # Raise exception if the image generation task did not succeed.
    # Fix: classify by final status, not retry count — the original raised
    # "timed out" for any task (even a failed/canceled one) once
    # retry_count >= 10, and that threshold disagreed with the loop bound of 20.
    if status != "succeeded":
        if status not in ["failed", "canceled"]:
            # Task still pending after exhausting the retry budget
            raise requests.RequestException("Image generation timed out")
        raise requests.RequestException(f"Image generation failed with status: {status}")

    # Get the generated image
    image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
    return io.BytesIO(requests.get(image_url).content).getvalue()

View File

@@ -26,6 +26,7 @@ from khoj.database.adapters import (
from khoj.database.models import KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.image.generate import text_to_image
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
@@ -44,7 +45,6 @@ from khoj.routers.helpers import (
is_query_empty,
is_ready_to_chat,
read_chat_stream,
text_to_image,
update_telemetry_state,
validate_conversation_config,
)

View File

@@ -1,7 +1,5 @@
import asyncio
import base64
import hashlib
import io
import json
import logging
import math
@@ -16,7 +14,6 @@ from typing import (
Annotated,
Any,
AsyncGenerator,
Callable,
Dict,
Iterator,
List,
@@ -24,17 +21,15 @@ from typing import (
Tuple,
Union,
)
from urllib.parse import parse_qs, urlencode, urljoin, urlparse
from urllib.parse import parse_qs, urljoin, urlparse
import cron_descriptor
import openai
import pytz
import requests
from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from PIL import Image
from starlette.authentication import has_required_scope
from starlette.requests import URL
@@ -93,7 +88,6 @@ from khoj.processor.conversation.utils import (
)
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
from khoj.routers.email import is_resend_enabled, send_task_email
from khoj.routers.storage import upload_image
from khoj.routers.twilio import is_twilio_enabled
from khoj.search_type import text_search
from khoj.utils import state
@@ -101,8 +95,6 @@ from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
LRU,
ConversationCommand,
ImageIntentType,
convert_image_to_webp,
is_none_or_empty,
is_valid_url,
log_telemetry,
@@ -568,7 +560,7 @@ async def generate_better_image_prompt(
references=user_references,
online_results=simplified_online_results,
)
elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
image_prompt = prompts.image_generation_improve_prompt_sd.format(
query=q,
chat_history=conversation_history,
@@ -921,129 +913,6 @@ def generate_chat_response(
return chat_response, metadata
async def text_to_image(
    message: str,
    user: KhojUser,
    conversation_log: dict,
    location_data: LocationData,
    references: List[Dict[str, Any]],
    online_results: Dict[str, Any],
    subscribed: bool = False,
    send_status_func: Optional[Callable] = None,
    uploaded_image_url: Optional[str] = None,
):
    """Generate an image for the user's message with their configured provider.

    Async generator. Intermediate yields are {ChatEvent.STATUS: event} status
    updates; the final yield is a 4-tuple of
    (image_url or base64 image, status_code, message or improved prompt, intent_type value).
    """
    status_code = 200
    image = None
    response = None
    image_url = None
    intent_type = ImageIntentType.TEXT_TO_IMAGE_V3

    text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unsupported on server error
        status_code = 501
        message = "Failed to generate image. Setup image generation on the server."
        yield image_url or image, status_code, message, intent_type.value
        return

    text2image_model = text_to_image_config.model_name

    # Summarize the last few conversation turns to give the prompt-improver context.
    # NOTE(review): `chat["intent"].get("type")` can be None here, which would
    # make the `in` containment test below raise TypeError — confirm upstream
    # always sets an intent type.
    chat_history = ""
    for chat in conversation_log.get("chat", [])[-4:]:
        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
            chat_history += f"Q: {chat['intent']['query']}\n"
            chat_history += f"A: {chat['message']}\n"
        elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
            chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
            chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"

    if send_status_func:
        async for event in send_status_func("**Enhancing the Painting Prompt**"):
            yield {ChatEvent.STATUS: event}

    # Rewrite the user's message into a richer image prompt using chat history and context
    improved_image_prompt = await generate_better_image_prompt(
        message,
        chat_history,
        location_data=location_data,
        note_references=references,
        online_results=online_results,
        model_type=text_to_image_config.model_type,
        subscribed=subscribed,
        uploaded_image_url=uploaded_image_url,
    )

    if send_status_func:
        async for event in send_status_func(f"**Painting to Imagine**:\n{improved_image_prompt}"):
            yield {ChatEvent.STATUS: event}

    if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
        with timer("Generate image with OpenAI", logger):
            # Resolve API key: per-model key > linked openai config > server default client.
            # NOTE(review): api_key is unbound (NameError) if none of the branches match.
            if text_to_image_config.api_key:
                api_key = text_to_image_config.api_key
            elif text_to_image_config.openai_config:
                api_key = text_to_image_config.openai_config.api_key
            elif state.openai_client:
                api_key = state.openai_client.api_key
            auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
            try:
                response = state.openai_client.images.generate(
                    prompt=improved_image_prompt,
                    model=text2image_model,
                    response_format="b64_json",
                    extra_headers=auth_header,
                )
                image = response.data[0].b64_json
                decoded_image = base64.b64decode(image)
            # NOTE(review): `except A or B or C` only catches the first class;
            # should be a tuple of exception types
            except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
                if "content_policy_violation" in e.message:
                    logger.error(f"Image Generation blocked by OpenAI: {e}")
                    status_code = e.status_code  # type: ignore
                    message = f"Image generation blocked by OpenAI: {e.message}"  # type: ignore
                    yield image_url or image, status_code, message, intent_type.value
                    return
                else:
                    logger.error(f"Image Generation failed with {e}", exc_info=True)
                    message = f"Image generation failed with OpenAI error: {e.message}"  # type: ignore
                    status_code = e.status_code  # type: ignore
                    yield image_url or image, status_code, message, intent_type.value
                    return
    elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
        with timer("Generate image with Stability AI", logger):
            try:
                response = requests.post(
                    f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
                    headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
                    files={"none": ""},
                    data={
                        "prompt": improved_image_prompt,
                        "model": text2image_model,
                        "mode": "text-to-image",
                        "output_format": "png",
                        "aspect_ratio": "1:1",
                    },
                )
                decoded_image = response.content
            except requests.RequestException as e:
                logger.error(f"Image Generation failed with {e}", exc_info=True)
                message = f"Image generation failed with Stability AI error: {e}"
                status_code = e.status_code  # type: ignore
                yield image_url or image, status_code, message, intent_type.value
                return

    # NOTE(review): decoded_image is unbound here if model_type matched neither
    # branch above (e.g. a new provider) — would raise NameError
    with timer("Convert image to webp", logger):
        # Convert png to webp for faster loading
        webp_image_bytes = convert_image_to_webp(decoded_image)

    with timer("Upload image to S3", logger):
        image_url = upload_image(webp_image_bytes, user.uuid)
    if image_url:
        intent_type = ImageIntentType.TEXT_TO_IMAGE2
    else:
        # No object storage configured: inline the image as base64 instead
        intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
        image = base64.b64encode(webp_image_bytes).decode("utf-8")

    yield image_url or image, status_code, improved_image_prompt, intent_type.value
class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests