mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Use Enum to type text-to-image intent of Khoj chat response
This commit is contained in:
@@ -23,6 +23,7 @@ from khoj.database.models import (
|
|||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
UserSearchModelConfig,
|
UserSearchModelConfig,
|
||||||
)
|
)
|
||||||
|
from khoj.utils.helpers import ImageIntentType
|
||||||
|
|
||||||
|
|
||||||
class KhojUserAdmin(UserAdmin):
|
class KhojUserAdmin(UserAdmin):
|
||||||
@@ -104,9 +105,12 @@ class ConversationAdmin(admin.ModelAdmin):
|
|||||||
log["by"] == "khoj"
|
log["by"] == "khoj"
|
||||||
and log["intent"]
|
and log["intent"]
|
||||||
and log["intent"]["type"]
|
and log["intent"]["type"]
|
||||||
and log["intent"]["type"] == "text-to-image"
|
and (
|
||||||
|
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
|
||||||
|
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
|
||||||
|
)
|
||||||
):
|
):
|
||||||
log["message"] = "image redacted for space"
|
log["message"] = "inline image redacted for space"
|
||||||
chat_log[idx] = log
|
chat_log[idx] = log
|
||||||
modified_log["chat"] = chat_log
|
modified_log["chat"] = chat_log
|
||||||
|
|
||||||
@@ -144,9 +148,12 @@ class ConversationAdmin(admin.ModelAdmin):
|
|||||||
log["by"] == "khoj"
|
log["by"] == "khoj"
|
||||||
and log["intent"]
|
and log["intent"]
|
||||||
and log["intent"]["type"]
|
and log["intent"]["type"]
|
||||||
and log["intent"]["type"] == "text-to-image"
|
and (
|
||||||
|
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
|
||||||
|
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
|
||||||
|
)
|
||||||
):
|
):
|
||||||
updated_log["message"] = "image redacted for space"
|
updated_log["message"] = "inline image redacted for space"
|
||||||
chat_log[idx] = updated_log
|
chat_log[idx] = updated_log
|
||||||
return_log["chat"] = chat_log
|
return_log["chat"] = chat_log
|
||||||
|
|
||||||
|
|||||||
@@ -6,13 +6,15 @@ import io
|
|||||||
from django.db import migrations
|
from django.db import migrations
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from khoj.utils.helpers import ImageIntentType
|
||||||
|
|
||||||
|
|
||||||
def convert_png_images_to_webp(apps, schema_editor):
|
def convert_png_images_to_webp(apps, schema_editor):
|
||||||
# Get the model from the versioned app registry to ensure the correct version is used
|
# Get the model from the versioned app registry to ensure the correct version is used
|
||||||
Conversations = apps.get_model("database", "Conversation")
|
Conversations = apps.get_model("database", "Conversation")
|
||||||
for conversation in Conversations.objects.all():
|
for conversation in Conversations.objects.all():
|
||||||
for chat in conversation.conversation_log["chat"]:
|
for chat in conversation.conversation_log["chat"]:
|
||||||
if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image":
|
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value:
|
||||||
# Decode the base64 encoded PNG image
|
# Decode the base64 encoded PNG image
|
||||||
decoded_image = base64.b64decode(chat["message"])
|
decoded_image = base64.b64decode(chat["message"])
|
||||||
|
|
||||||
@@ -25,10 +27,10 @@ def convert_png_images_to_webp(apps, schema_editor):
|
|||||||
# Encode the WebP image back to base64
|
# Encode the WebP image back to base64
|
||||||
webp_image_bytes = webp_image_io.getvalue()
|
webp_image_bytes = webp_image_io.getvalue()
|
||||||
chat["message"] = base64.b64encode(webp_image_bytes).decode()
|
chat["message"] = base64.b64encode(webp_image_bytes).decode()
|
||||||
chat["intent"]["type"] = "text-to-image-v3"
|
chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value
|
||||||
webp_image_io.close()
|
webp_image_io.close()
|
||||||
|
|
||||||
if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image2":
|
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value:
|
||||||
print("❗️ Please MANUALLY update PNG images created by Khoj in your AWS S3 bucket to WebP format.")
|
print("❗️ Please MANUALLY update PNG images created by Khoj in your AWS S3 bucket to WebP format.")
|
||||||
# Convert PNG url to WebP url
|
# Convert PNG url to WebP url
|
||||||
chat["message"] = chat["message"].replace(".png", ".webp")
|
chat["message"] = chat["message"].replace(".png", ".webp")
|
||||||
@@ -42,7 +44,7 @@ def convert_webp_images_to_png(apps, schema_editor):
|
|||||||
Conversations = apps.get_model("database", "Conversation")
|
Conversations = apps.get_model("database", "Conversation")
|
||||||
for conversation in Conversations.objects.all():
|
for conversation in Conversations.objects.all():
|
||||||
for chat in conversation.conversation_log["chat"]:
|
for chat in conversation.conversation_log["chat"]:
|
||||||
if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image":
|
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value:
|
||||||
# Decode the base64 encoded PNG image
|
# Decode the base64 encoded PNG image
|
||||||
decoded_image = base64.b64decode(chat["message"])
|
decoded_image = base64.b64decode(chat["message"])
|
||||||
|
|
||||||
@@ -55,10 +57,10 @@ def convert_webp_images_to_png(apps, schema_editor):
|
|||||||
# Encode the WebP image back to base64
|
# Encode the WebP image back to base64
|
||||||
webp_image_bytes = webp_image_io.getvalue()
|
webp_image_bytes = webp_image_io.getvalue()
|
||||||
chat["message"] = base64.b64encode(webp_image_bytes).decode()
|
chat["message"] = base64.b64encode(webp_image_bytes).decode()
|
||||||
chat["intent"]["type"] = "text-to-image"
|
chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value
|
||||||
webp_image_io.close()
|
webp_image_io.close()
|
||||||
|
|
||||||
if chat["by"] == "khoj" and chat["intent"]["type"] == "text-to-image2":
|
if chat["by"] == "khoj" and chat["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE2.value:
|
||||||
# Convert WebP url to PNG url
|
# Convert WebP url to PNG url
|
||||||
print("❗️ Please MANUALLY update WebP images created by Khoj in your AWS S3 bucket to PNG format.")
|
print("❗️ Please MANUALLY update WebP images created by Khoj in your AWS S3 bucket to PNG format.")
|
||||||
chat["message"] = chat["message"].replace(".webp", ".png")
|
chat["message"] = chat["message"].replace(".webp", ".png")
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ from khoj.utils import state
|
|||||||
from khoj.utils.config import OfflineChatProcessorModel
|
from khoj.utils.config import OfflineChatProcessorModel
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
|
ImageIntentType,
|
||||||
is_none_or_empty,
|
is_none_or_empty,
|
||||||
is_valid_url,
|
is_valid_url,
|
||||||
log_telemetry,
|
log_telemetry,
|
||||||
@@ -520,14 +521,14 @@ async def text_to_image(
|
|||||||
image = None
|
image = None
|
||||||
response = None
|
response = None
|
||||||
image_url = None
|
image_url = None
|
||||||
intent_type = "text-to-image-v3"
|
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||||
|
|
||||||
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
||||||
if not text_to_image_config:
|
if not text_to_image_config:
|
||||||
# If the user has not configured a text to image model, return an unsupported on server error
|
# If the user has not configured a text to image model, return an unsupported on server error
|
||||||
status_code = 501
|
status_code = 501
|
||||||
message = "Failed to generate image. Setup image generation on the server."
|
message = "Failed to generate image. Setup image generation on the server."
|
||||||
return image_url or image, status_code, message, intent_type
|
return image_url or image, status_code, message, intent_type.value
|
||||||
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
logger.info("Generating image with OpenAI")
|
logger.info("Generating image with OpenAI")
|
||||||
text2image_model = text_to_image_config.model_name
|
text2image_model = text_to_image_config.model_name
|
||||||
@@ -572,24 +573,24 @@ async def text_to_image(
|
|||||||
with timer("Upload image to S3", logger):
|
with timer("Upload image to S3", logger):
|
||||||
image_url = upload_image(webp_image_bytes, user.uuid)
|
image_url = upload_image(webp_image_bytes, user.uuid)
|
||||||
if image_url:
|
if image_url:
|
||||||
intent_type = "text-to-image-v2"
|
intent_type = ImageIntentType.TEXT_TO_IMAGE2
|
||||||
else:
|
else:
|
||||||
intent_type = "text-to-image-v3"
|
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
|
||||||
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||||
|
|
||||||
return image_url or image, status_code, improved_image_prompt, intent_type
|
return image_url or image, status_code, improved_image_prompt, intent_type.value
|
||||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||||
if "content_policy_violation" in e.message:
|
if "content_policy_violation" in e.message:
|
||||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
||||||
return image_url or image, status_code, message, intent_type
|
return image_url or image, status_code, message, intent_type.value
|
||||||
else:
|
else:
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
return image_url or image, status_code, message, intent_type
|
return image_url or image, status_code, message, intent_type.value
|
||||||
return image_url or image, status_code, response, intent_type
|
return image_url or image, status_code, response, intent_type.value
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
|
|||||||
@@ -329,6 +329,20 @@ mode_descriptions_for_llm = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ImageIntentType(Enum):
    """
    Chat message intent by Khoj for image responses.

    Marks the schema used to reference the image in a chat message.
    Each member's ``value`` is the string kept in the message's
    ``intent["type"]`` field. This is a plain ``Enum``, so members do not
    compare equal to raw strings: compare against, or store, ``member.value``.
    """

    # Images as inline PNG (message holds the base64 encoded image)
    TEXT_TO_IMAGE = "text-to-image"
    # Images as URLs (message holds a link to the image)
    TEXT_TO_IMAGE2 = "text-to-image2"
    # Images as inline WebP (message holds the base64 encoded image)
    TEXT_TO_IMAGE_V3 = "text-to-image-v3"
|
||||||
|
|
||||||
|
|
||||||
def generate_random_name():
|
def generate_random_name():
|
||||||
# List of adjectives and nouns to choose from
|
# List of adjectives and nouns to choose from
|
||||||
adjectives = [
|
adjectives = [
|
||||||
|
|||||||
Reference in New Issue
Block a user