From 7d8e8eb0cf831556c34c4e47ab28ccbbab03623b Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 15 Apr 2024 16:44:56 +0530 Subject: [PATCH] Use Enum to type text-to-image intent of Khoj chat response --- src/khoj/database/admin.py | 15 +++++++++++---- .../migrations/0035_convert_png_to_webp.py | 14 ++++++++------ src/khoj/routers/helpers.py | 17 +++++++++-------- src/khoj/utils/helpers.py | 14 ++++++++++++++ 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 97a0f3ed..9b82029b 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -23,6 +23,7 @@ from khoj.database.models import ( TextToImageModelConfig, UserSearchModelConfig, ) +from khoj.utils.helpers import ImageIntentType class KhojUserAdmin(UserAdmin): @@ -104,9 +105,12 @@ class ConversationAdmin(admin.ModelAdmin): log["by"] == "khoj" and log["intent"] and log["intent"]["type"] - and log["intent"]["type"] == "text-to-image" + and ( + log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value + or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value + ) ): - log["message"] = "image redacted for space" + log["message"] = "inline image redacted for space" chat_log[idx] = log modified_log["chat"] = chat_log @@ -144,9 +148,12 @@ class ConversationAdmin(admin.ModelAdmin): log["by"] == "khoj" and log["intent"] and log["intent"]["type"] - and log["intent"]["type"] == "text-to-image" + and ( + log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value + or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value + ) ): - updated_log["message"] = "image redacted for space" + updated_log["message"] = "inline image redacted for space" chat_log[idx] = updated_log return_log["chat"] = chat_log diff --git a/src/khoj/database/migrations/0035_convert_png_to_webp.py b/src/khoj/database/migrations/0035_convert_png_to_webp.py index 7d28a07d..6ffa024b 100644 --- a/src/khoj/database/migrations/0035_convert_png_to_webp.py +++ b/src/khoj/database/migrations/0035_convert_png_to_webp.py @@ -6,13 +6,15 @@ import io from django.db import migrations from PIL import Image +from khoj.utils.helpers import ImageIntentType + def convert_png_images_to_webp(apps, schema_editor): # Get the model from the versioned app registry to ensure the correct version is used Conversations = apps.get_model("database", "Conversation") for conversation in Conversations.objects.all(): for chat in conversation.conversation_log["chat"]: - if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image": + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value: # Decode the base64 encoded PNG image decoded_image = base64.b64decode(chat["message"]) @@ -25,10 +27,10 @@ def convert_png_images_to_webp(apps, schema_editor): # Encode the WebP image back to base64 webp_image_bytes = webp_image_io.getvalue() chat["message"] = base64.b64encode(webp_image_bytes).decode() - chat["intent"]["type"] = "text-to-image-v3" + chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value webp_image_io.close() - if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image2": + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value: print("❗️ Please MANUALLY update PNG images created by Khoj in your AWS S3 bucket to WebP format.") # Convert PNG url to WebP url chat["message"] = chat["message"].replace(".png", ".webp") @@ -42,7 +44,7 @@ def convert_webp_images_to_png(apps, schema_editor): Conversations = apps.get_model("database", "Conversation") for conversation in Conversations.objects.all(): for chat in conversation.conversation_log["chat"]: - if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image": + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value: # Decode the base64 encoded PNG image decoded_image = base64.b64decode(chat["message"]) @@ -55,10 +57,10 @@ def convert_webp_images_to_png(apps, schema_editor): # Encode the WebP image back to base64 webp_image_bytes = webp_image_io.getvalue() chat["message"] = base64.b64encode(webp_image_bytes).decode() - chat["intent"]["type"] = "text-to-image" + chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value webp_image_io.close() - if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image2": + if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value: # Convert WebP url to PNG url print("❗️ Please MANUALLY update WebP images created by Khoj in your AWS S3 bucket to PNG format.") chat["message"] = chat["message"].replace(".webp", ".png") diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 06d849ca..3c93385d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -49,6 +49,7 @@ from khoj.utils import state from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( ConversationCommand, + ImageIntentType, is_none_or_empty, is_valid_url, log_telemetry, @@ -520,14 +521,14 @@ async def text_to_image( image = None response = None image_url = None - intent_type = "text-to-image-v3" + intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 message = "Failed to generate image. Setup image generation on the server." - return image_url or image, status_code, message, intent_type + return image_url or image, status_code, message, intent_type.value elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: logger.info("Generating image with OpenAI") text2image_model = text_to_image_config.model_name @@ -572,24 +573,24 @@ async def text_to_image( with timer("Upload image to S3", logger): image_url = upload_image(webp_image_bytes, user.uuid) if image_url: - intent_type = "text-to-image-v2" + intent_type = ImageIntentType.TEXT_TO_IMAGE2 else: - intent_type = "text-to-image-v3" + intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 image = base64.b64encode(webp_image_bytes).decode("utf-8") - return image_url or image, status_code, improved_image_prompt, intent_type + return image_url or image, status_code, improved_image_prompt, intent_type.value except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: if "content_policy_violation" in e.message: logger.error(f"Image Generation blocked by OpenAI: {e}") status_code = e.status_code # type: ignore message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore - return image_url or image, status_code, message, intent_type + return image_url or image, status_code, message, intent_type.value else: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore status_code = e.status_code # type: ignore - return image_url or image, status_code, message, intent_type - return image_url or image, status_code, response, intent_type + return image_url or image, status_code, message, intent_type.value + return image_url or image, status_code, response, intent_type.value class ApiUserRateLimiter: diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 04974b7d..e621f53e 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -329,6 +329,20 @@ mode_descriptions_for_llm = { } +class ImageIntentType(Enum): + """ + Chat message intent by Khoj for image responses. + Marks the schema used to reference image in chat messages + """ + + # Images as Inline PNG + TEXT_TO_IMAGE = "text-to-image" + # Images as URLs + TEXT_TO_IMAGE2 = "text-to-image2" + # Images as Inline WebP + TEXT_TO_IMAGE_V3 = "text-to-image-v3" + + def generate_random_name(): # List of adjectives and nouns to choose from adjectives = [