From d21f22ffa1517dfe079a7f50f3c84699217c131e Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 13 Apr 2024 13:03:32 +0530 Subject: [PATCH] Store Khoj generated images as webp instead of png for faster loading --- src/khoj/routers/api_chat.py | 12 ++---------- src/khoj/routers/helpers.py | 35 ++++++++++++++++++++++++++++------- src/khoj/routers/storage.py | 8 +++----- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 4e7a8cc9..76cf6d12 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -425,8 +425,7 @@ async def websocket_endpoint( api="chat", metadata={"conversation_command": conversation_commands[0].value}, ) - intent_type = "text-to-image" - image, status_code, improved_image_prompt, image_url = await text_to_image( + image, status_code, improved_image_prompt, intent_type = await text_to_image( q, user, meta_log, @@ -445,9 +444,6 @@ async def websocket_endpoint( await send_complete_llm_response(json.dumps(content_obj)) continue - if image_url: - intent_type = "text-to-image2" - image = image_url await sync_to_async(save_to_conversation_log)( q, image, @@ -621,17 +617,13 @@ async def chat( metadata={"conversation_command": conversation_commands[0].value}, **common.__dict__, ) - intent_type = "text-to-image" - image, status_code, improved_image_prompt, image_url = await text_to_image( + image, status_code, improved_image_prompt, intent_type = await text_to_image( q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results ) if image is None: content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt} return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) - if image_url: - intent_type = "text-to-image2" - image = image_url await sync_to_async(save_to_conversation_log)( q, image, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index f3be3162..cbf29c02 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,4 +1,6 @@ import asyncio +import base64 +import io import json import logging from concurrent.futures import ThreadPoolExecutor @@ -18,6 +20,7 @@ from typing import ( import openai from fastapi import Depends, Header, HTTPException, Request, UploadFile +from PIL import Image from starlette.authentication import has_required_scope from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters @@ -508,18 +511,19 @@ async def text_to_image( references: List[str], online_results: Dict[str, Any], send_status_func: Optional[Callable] = None, -) -> Tuple[Optional[str], int, Optional[str], Optional[str]]: +) -> Tuple[Optional[str], int, Optional[str], str]: status_code = 200 image = None response = None image_url = None + intent_type = "text-to-image-v3" text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 message = "Failed to generate image. Setup image generation on the server." - return image, status_code, message, image_url + return image_url or image, status_code, message, intent_type elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: logger.info("Generating image with OpenAI") text2image_model = text_to_image_config.model_name @@ -550,21 +554,38 @@ async def text_to_image( ) image = response.data[0].b64_json + with timer("Convert image to webp", logger): + # Convert png to webp for faster loading + decoded_image = base64.b64decode(image) + image_io = io.BytesIO(decoded_image) + png_image = Image.open(image_io) + webp_image_io = io.BytesIO() + png_image.save(webp_image_io, "WEBP") + webp_image_bytes = webp_image_io.getvalue() + webp_image_io.close() + image_io.close() + with timer("Upload image to S3", logger): - image_url = upload_image(image, user.uuid) - return image, status_code, improved_image_prompt, image_url + image_url = upload_image(webp_image_bytes, user.uuid) + if image_url: + intent_type = "text-to-image-v2" + else: + intent_type = "text-to-image-v3" + image = base64.b64encode(webp_image_bytes).decode("utf-8") + + return image_url or image, status_code, improved_image_prompt, intent_type except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e: if "content_policy_violation" in e.message: logger.error(f"Image Generation blocked by OpenAI: {e}") status_code = e.status_code # type: ignore message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore - return image, status_code, message, image_url + return image_url or image, status_code, message, intent_type else: logger.error(f"Image Generation failed with {e}", exc_info=True) message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore status_code = e.status_code # type: ignore - return image, status_code, message, image_url - return image, status_code, response, image_url + return image_url or image, status_code, message, intent_type + return image_url or image, status_code, response, intent_type class ApiUserRateLimiter: diff --git a/src/khoj/routers/storage.py b/src/khoj/routers/storage.py index 57c28c5a..9a5d448f 100644 --- a/src/khoj/routers/storage.py +++ b/src/khoj/routers/storage.py @@ -1,4 +1,3 @@ -import base64 import logging import os import uuid @@ -17,16 +16,15 @@ if aws_enabled: s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY) -def upload_image(image: str, user_id: uuid.UUID): +def upload_image(image: bytes, user_id: uuid.UUID): """Upload the image to the S3 bucket""" if not aws_enabled: logger.info("AWS is not enabled. Skipping image upload") return None - decoded_image = base64.b64decode(image) - image_key = f"{user_id}/{uuid.uuid4()}.png" + image_key = f"{user_id}/{uuid.uuid4()}.webp" try: - s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read") + s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read") url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}" return url except Exception as e: