from django.core.management.base import BaseCommand

from khoj.database.models import Conversation
from khoj.utils.helpers import ImageIntentType, is_none_or_empty


class Command(BaseCommand):
    """Rewrite the URL prefix of Khoj generated images stored in conversation logs.

    Every chat message that was sent by khoj, has the text-to-image intent
    type, starts with the source URL and ends with ``.webp`` gets its leading
    source URL swapped for the destination URL. Pass ``--reverse`` to swap the
    roles of the two URLs and undo a previous run.
    """

    help = "Serve Khoj generated images from a different URL."

    def add_arguments(self, parser):
        """Register the --source, --destination and --reverse command line arguments."""
        # Pass Source URL
        parser.add_argument(
            "--source",
            action="store",
            help="URL from which generated images are currently served.",
        )
        # Pass Destination URL
        parser.add_argument(
            "--destination",
            action="store",
            help="URL to serve generated image from going forward.",
        )
        # Swap the roles of the source and destination URLs
        parser.add_argument(
            "--reverse",
            action="store_true",
            help="Revert to serve generated images from source instead of destination URL.",
        )

    def handle(self, *args, **options):
        """Convert generated image URLs in all conversations and report the outcome.

        Writes an error and returns without touching the database when either
        --source or --destination is missing.
        """
        source = options.get("source")
        destination = options.get("destination")
        if not source or not destination:
            # Only the command line args are read here, so do not point the user at env vars
            self.stdout.write(self.style.ERROR("Both --source and --destination args need to be set."))
            return

        if options.get("reverse"):
            source, destination = destination, source

        if source == destination:
            # Rewriting a URL to itself would only cause needless database writes
            self.stdout.write(self.style.WARNING("Source and destination URLs are identical. Nothing to convert."))
            return

        updated_count = 0
        # Stream conversations rather than caching the whole table in memory
        for conversation in Conversation.objects.all().iterator():
            # Tolerate conversations whose log or chat list is unset
            chat_log = (conversation.conversation_log or {}).get("chat") or []
            converted = self._convert_image_urls(chat_log, source, destination)
            if converted:
                conversation.save()
                self.stdout.write(f"Saved updated conversation {conversation.pk}")
                updated_count += converted

        if updated_count > 0:
            success = f"Successfully converted {updated_count} image URLs from {source} to {destination}."
            self.stdout.write(self.style.SUCCESS(success))
        else:
            self.stdout.write(self.style.WARNING(f"No generated image URLs starting with {source} were found."))

    @staticmethod
    def _convert_image_urls(chat_log, source, destination):
        """Swap the leading source URL for destination in matching chats, in place.

        Returns the number of chat messages that were rewritten.
        """
        image_intent = ImageIntentType.TEXT_TO_IMAGE2.value
        converted = 0
        for chat in chat_log:
            if not isinstance(chat, dict):
                continue
            message = chat.get("message")
            # Stored intent may be None, so fall back to an empty dict before reading its type
            intent_type = (chat.get("intent") or {}).get("type")
            if (
                chat.get("by") == "khoj"
                and isinstance(message, str)
                and not is_none_or_empty(message)
                and message.startswith(source)
                and message.endswith(".webp")
                and intent_type == image_intent
            ):
                # Replace only the matched prefix, not later occurrences of source in the URL
                chat["message"] = destination + message[len(source) :]
                converted += 1
        return converted
For example, CNAME generated.khoj.dev generated.khoj.dev.s3.amazonaws.com AWS_UPLOAD_IMAGE_BUCKET_NAME = os.getenv("AWS_IMAGE_UPLOAD_BUCKET") aws_enabled = AWS_ACCESS_KEY is not None and AWS_SECRET_KEY is not None and AWS_UPLOAD_IMAGE_BUCKET_NAME is not None @@ -25,7 +28,7 @@ def upload_image(image: bytes, user_id: uuid.UUID): image_key = f"{user_id}/{uuid.uuid4()}.webp" try: s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read") - url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}" + url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}/{image_key}" return url except Exception as e: logger.error(f"Failed to upload image to S3: {e}")