Deduplicate, clean code for S3 images uploads

This commit is contained in:
Debanjum
2025-03-20 08:56:42 +05:30
parent f15a95dccf
commit 1ce1d2f5ab
3 changed files with 32 additions and 33 deletions

View File

@@ -12,7 +12,7 @@ from google.genai import types as gtypes
from khoj.database.adapters import ConversationAdapters from khoj.database.adapters import ConversationAdapters
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image from khoj.routers.storage import upload_generated_image_to_bucket
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import convert_image_to_webp, timer from khoj.utils.helpers import convert_image_to_webp, timer
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@@ -118,7 +118,7 @@ async def text_to_image(
# Decide how to store the generated image # Decide how to store the generated image
with timer("Upload image to S3", logger): with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid) image_url = upload_generated_image_to_bucket(webp_image_bytes, user.uuid)
if not image_url: if not image_url:
image = f"data:image/webp;base64,{base64.b64encode(webp_image_bytes).decode('utf-8')}" image = f"data:image/webp;base64,{base64.b64encode(webp_image_bytes).decode('utf-8')}"

View File

@@ -64,7 +64,7 @@ from khoj.routers.research import (
InformationCollectionIteration, InformationCollectionIteration,
execute_information_collection, execute_information_collection,
) )
from khoj.routers.storage import upload_image_to_bucket from khoj.routers.storage import upload_user_image_to_bucket
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import ( from khoj.utils.helpers import (
AsyncIteratorWrapper, AsyncIteratorWrapper,
@@ -674,7 +674,7 @@ async def chat(
base64_data = decoded_string.split(",", 1)[1] base64_data = decoded_string.split(",", 1)[1]
image_bytes = base64.b64decode(base64_data) image_bytes = base64.b64decode(base64_data)
webp_image_bytes = convert_image_to_webp(image_bytes) webp_image_bytes = convert_image_to_webp(image_bytes)
uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id) uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id)
if uploaded_image: if uploaded_image:
uploaded_images.append(uploaded_image) uploaded_images.append(uploaded_image)

View File

@@ -9,9 +9,10 @@ AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
# S3 supports serving assets via your domain. Khoj expects this to be used in production. To enable it: # S3 supports serving assets via your domain. Khoj expects this to be used in production. To enable it:
# 1. Your bucket name for images should be of the form sub.domain.tld. For example, generated.khoj.dev # 1. Your bucket name for images should be of the form sub.domain.tld. For example, generated.khoj.dev
# 2. Add CNAME entry to your domain's DNS records pointing to the S3 bucket. For example, CNAME generated.khoj.dev generated-khoj-dev.s3.amazonaws.com # 2. Add CNAME entry to your domain's DNS records pointing to the S3 bucket. For example, CNAME generated.khoj.dev generated-khoj-dev.s3.amazonaws.com
AWS_UPLOAD_IMAGE_BUCKET_NAME = os.getenv("AWS_IMAGE_UPLOAD_BUCKET") AWS_KHOJ_IMAGES_BUCKET_NAME = os.getenv("AWS_IMAGE_UPLOAD_BUCKET")
AWS_USER_IMAGES_BUCKET_NAME = os.getenv("AWS_USER_UPLOADED_IMAGES_BUCKET_NAME")
aws_enabled = AWS_ACCESS_KEY is not None and AWS_SECRET_KEY is not None and AWS_UPLOAD_IMAGE_BUCKET_NAME is not None aws_enabled = AWS_ACCESS_KEY is not None and AWS_SECRET_KEY is not None
if aws_enabled: if aws_enabled:
from boto3 import client from boto3 import client
@@ -19,45 +20,43 @@ if aws_enabled:
s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY) s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
def upload_image(image: bytes, user_id: uuid.UUID): def upload_image_to_bucket(webp_image: bytes, user_id: uuid.UUID, bucket_name: str):
"""Upload the image to the S3 bucket""" """Upload webp image to an S3 bucket"""
if not aws_enabled: if not aws_enabled:
logger.info("AWS is not enabled. Skipping image upload") logger.info("AWS is not enabled. Skipping image upload")
return None return None
if not bucket_name:
image_key = f"{user_id}/{uuid.uuid4()}.webp" logger.error(f"{bucket_name} is not set")
try:
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read")
url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}/{image_key}"
return url
except Exception as e:
logger.error(f"Failed to upload image to S3: {e}")
return None
AWS_USER_UPLOADED_IMAGES_BUCKET_NAME = os.getenv("AWS_USER_UPLOADED_IMAGES_BUCKET_NAME")
def upload_image_to_bucket(image: bytes, user_id: uuid.UUID):
"""Upload the image to the S3 bucket"""
if not aws_enabled:
logger.info("AWS is not enabled. Skipping image upload")
return None return None
image_key = f"{user_id}/{uuid.uuid4()}.webp" image_key = f"{user_id}/{uuid.uuid4()}.webp"
if not AWS_USER_UPLOADED_IMAGES_BUCKET_NAME:
logger.error("AWS_USER_UPLOADED_IMAGES_BUCKET_NAME is not set")
return None
try: try:
s3_client.put_object( s3_client.put_object(
Bucket=AWS_USER_UPLOADED_IMAGES_BUCKET_NAME, Bucket=bucket_name,
Key=image_key, Key=image_key,
Body=image, Body=webp_image,
ACL="public-read", ACL="public-read",
ContentType="image/webp", ContentType="image/webp",
) )
return f"https://{AWS_USER_UPLOADED_IMAGES_BUCKET_NAME}/{image_key}" return f"https://{bucket_name}/{image_key}"
except Exception as e: except Exception as e:
logger.error(f"Failed to upload image to S3: {e}") logger.error(f"Failed to upload image to S3: {e}")
return None return None
def upload_generated_image_to_bucket(image: bytes, user_id: uuid.UUID):
"""Upload khoj generated image to an S3 bucket"""
return upload_image_to_bucket(
webp_image=image,
user_id=user_id,
bucket_name=AWS_KHOJ_IMAGES_BUCKET_NAME,
)
def upload_user_image_to_bucket(image: bytes, user_id: uuid.UUID):
"""Upload user attached image to an S3 bucket"""
return upload_image_to_bucket(
webp_image=image,
user_id=user_id,
bucket_name=AWS_USER_IMAGES_BUCKET_NAME,
)