Use Enum to type text-to-image intent of Khoj chat response

This commit is contained in:
Debanjum Singh Solanky
2024-04-15 16:44:56 +05:30
parent 128829c477
commit 7d8e8eb0cf
4 changed files with 42 additions and 18 deletions

View File

@@ -23,6 +23,7 @@ from khoj.database.models import (
TextToImageModelConfig,
UserSearchModelConfig,
)
from khoj.utils.helpers import ImageIntentType
class KhojUserAdmin(UserAdmin):
@@ -104,9 +105,12 @@ class ConversationAdmin(admin.ModelAdmin):
log["by"] == "khoj"
and log["intent"]
and log["intent"]["type"]
and log["intent"]["type"] == "text-to-image"
and (
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
)
):
log["message"] = "image redacted for space"
log["message"] = "inline image redacted for space"
chat_log[idx] = log
modified_log["chat"] = chat_log
@@ -144,9 +148,12 @@ class ConversationAdmin(admin.ModelAdmin):
log["by"] == "khoj"
and log["intent"]
and log["intent"]["type"]
and log["intent"]["type"] == "text-to-image"
and (
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
)
):
updated_log["message"] = "image redacted for space"
updated_log["message"] = "inline image redacted for space"
chat_log[idx] = updated_log
return_log["chat"] = chat_log

View File

@@ -6,13 +6,15 @@ import io
from django.db import migrations
from PIL import Image
from khoj.utils.helpers import ImageIntentType
def convert_png_images_to_webp(apps, schema_editor):
# Get the model from the versioned app registry to ensure the correct version is used
Conversations = apps.get_model("database", "Conversation")
for conversation in Conversations.objects.all():
for chat in conversation.conversation_log["chat"]:
if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image":
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value:
# Decode the base64 encoded PNG image
decoded_image = base64.b64decode(chat["message"])
@@ -25,10 +27,10 @@ def convert_png_images_to_webp(apps, schema_editor):
# Encode the WebP image back to base64
webp_image_bytes = webp_image_io.getvalue()
chat["message"] = base64.b64encode(webp_image_bytes).decode()
chat["intent"]["type"] = "text-to-image-v3"
chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value
webp_image_io.close()
if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image2":
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value:
print("❗️ Please MANUALLY update PNG images created by Khoj in your AWS S3 bucket to WebP format.")
# Convert PNG url to WebP url
chat["message"] = chat["message"].replace(".png", ".webp")
@@ -42,7 +44,7 @@ def convert_webp_images_to_png(apps, schema_editor):
Conversations = apps.get_model("database", "Conversation")
for conversation in Conversations.objects.all():
for chat in conversation.conversation_log["chat"]:
if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image":
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value:
# Decode the base64 encoded PNG image
decoded_image = base64.b64decode(chat["message"])
@@ -55,10 +57,10 @@ def convert_webp_images_to_png(apps, schema_editor):
# Encode the WebP image back to base64
webp_image_bytes = webp_image_io.getvalue()
chat["message"] = base64.b64encode(webp_image_bytes).decode()
chat["intent"]["type"] = "text-to-image"
chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value
webp_image_io.close()
if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image2":
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value:
# Convert WebP url to PNG url
print("❗️ Please MANUALLY update WebP images created by Khoj in your AWS S3 bucket to PNG format.")
chat["message"] = chat["message"].replace(".webp", ".png")

View File

@@ -49,6 +49,7 @@ from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
ConversationCommand,
ImageIntentType,
is_none_or_empty,
is_valid_url,
log_telemetry,
@@ -520,14 +521,14 @@ async def text_to_image(
image = None
response = None
image_url = None
intent_type = "text-to-image-v3"
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
return image_url or image, status_code, message, intent_type
return image_url or image, status_code, message, intent_type.value
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
logger.info("Generating image with OpenAI")
text2image_model = text_to_image_config.model_name
@@ -572,24 +573,24 @@ async def text_to_image(
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url:
intent_type = "text-to-image-v2"
intent_type = ImageIntentType.TEXT_TO_IMAGE2
else:
intent_type = "text-to-image-v3"
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")
return image_url or image, status_code, improved_image_prompt, intent_type
return image_url or image, status_code, improved_image_prompt, intent_type.value
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image_url or image, status_code, message, intent_type
return image_url or image, status_code, message, intent_type.value
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type
return image_url or image, status_code, response, intent_type
return image_url or image, status_code, message, intent_type.value
return image_url or image, status_code, response, intent_type.value
class ApiUserRateLimiter:

View File

@@ -329,6 +329,20 @@ mode_descriptions_for_llm = {
}
class ImageIntentType(Enum):
"""
Chat message intent by Khoj for image responses.
Marks the schema used to reference image in chat messages
"""
# Images as Inline PNG
TEXT_TO_IMAGE = "text-to-image"
# Images as URLs
TEXT_TO_IMAGE2 = "text-to-image2"
# Images as Inline WebP
TEXT_TO_IMAGE_V3 = "text-to-image-v3"
def generate_random_name():
# List of adjectives and nouns to choose from
adjectives = [