mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 21:29:11 +00:00
Store Khoj generated images as webp instead of png for faster loading
This commit is contained in:
@@ -425,8 +425,7 @@ async def websocket_endpoint(
|
|||||||
api="chat",
|
api="chat",
|
||||||
metadata={"conversation_command": conversation_commands[0].value},
|
metadata={"conversation_command": conversation_commands[0].value},
|
||||||
)
|
)
|
||||||
intent_type = "text-to-image"
|
image, status_code, improved_image_prompt, intent_type = await text_to_image(
|
||||||
image, status_code, improved_image_prompt, image_url = await text_to_image(
|
|
||||||
q,
|
q,
|
||||||
user,
|
user,
|
||||||
meta_log,
|
meta_log,
|
||||||
@@ -445,9 +444,6 @@ async def websocket_endpoint(
|
|||||||
await send_complete_llm_response(json.dumps(content_obj))
|
await send_complete_llm_response(json.dumps(content_obj))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if image_url:
|
|
||||||
intent_type = "text-to-image2"
|
|
||||||
image = image_url
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
image,
|
image,
|
||||||
@@ -621,17 +617,13 @@ async def chat(
|
|||||||
metadata={"conversation_command": conversation_commands[0].value},
|
metadata={"conversation_command": conversation_commands[0].value},
|
||||||
**common.__dict__,
|
**common.__dict__,
|
||||||
)
|
)
|
||||||
intent_type = "text-to-image"
|
image, status_code, improved_image_prompt, intent_type = await text_to_image(
|
||||||
image, status_code, improved_image_prompt, image_url = await text_to_image(
|
|
||||||
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
|
||||||
)
|
)
|
||||||
if image is None:
|
if image is None:
|
||||||
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
|
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
|
||||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||||
|
|
||||||
if image_url:
|
|
||||||
intent_type = "text-to-image2"
|
|
||||||
image = image_url
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
image,
|
image,
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
@@ -18,6 +20,7 @@ from typing import (
|
|||||||
|
|
||||||
import openai
|
import openai
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
|
from PIL import Image
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
|
|
||||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
||||||
@@ -508,18 +511,19 @@ async def text_to_image(
|
|||||||
references: List[str],
|
references: List[str],
|
||||||
online_results: Dict[str, Any],
|
online_results: Dict[str, Any],
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
) -> Tuple[Optional[str], int, Optional[str], Optional[str]]:
|
) -> Tuple[Optional[str], int, Optional[str], str]:
|
||||||
status_code = 200
|
status_code = 200
|
||||||
image = None
|
image = None
|
||||||
response = None
|
response = None
|
||||||
image_url = None
|
image_url = None
|
||||||
|
intent_type = "text-to-image-v3"
|
||||||
|
|
||||||
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
||||||
if not text_to_image_config:
|
if not text_to_image_config:
|
||||||
# If the user has not configured a text to image model, return an unsupported on server error
|
# If the user has not configured a text to image model, return an unsupported on server error
|
||||||
status_code = 501
|
status_code = 501
|
||||||
message = "Failed to generate image. Setup image generation on the server."
|
message = "Failed to generate image. Setup image generation on the server."
|
||||||
return image, status_code, message, image_url
|
return image_url or image, status_code, message, intent_type
|
||||||
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||||
logger.info("Generating image with OpenAI")
|
logger.info("Generating image with OpenAI")
|
||||||
text2image_model = text_to_image_config.model_name
|
text2image_model = text_to_image_config.model_name
|
||||||
@@ -550,21 +554,38 @@ async def text_to_image(
|
|||||||
)
|
)
|
||||||
image = response.data[0].b64_json
|
image = response.data[0].b64_json
|
||||||
|
|
||||||
|
with timer("Convert image to webp", logger):
|
||||||
|
# Convert png to webp for faster loading
|
||||||
|
decoded_image = base64.b64decode(image)
|
||||||
|
image_io = io.BytesIO(decoded_image)
|
||||||
|
png_image = Image.open(image_io)
|
||||||
|
webp_image_io = io.BytesIO()
|
||||||
|
png_image.save(webp_image_io, "WEBP")
|
||||||
|
webp_image_bytes = webp_image_io.getvalue()
|
||||||
|
webp_image_io.close()
|
||||||
|
image_io.close()
|
||||||
|
|
||||||
with timer("Upload image to S3", logger):
|
with timer("Upload image to S3", logger):
|
||||||
image_url = upload_image(image, user.uuid)
|
image_url = upload_image(webp_image_bytes, user.uuid)
|
||||||
return image, status_code, improved_image_prompt, image_url
|
if image_url:
|
||||||
|
intent_type = "text-to-image-v2"
|
||||||
|
else:
|
||||||
|
intent_type = "text-to-image-v3"
|
||||||
|
image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||||
|
|
||||||
|
return image_url or image, status_code, improved_image_prompt, intent_type
|
||||||
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
|
||||||
if "content_policy_violation" in e.message:
|
if "content_policy_violation" in e.message:
|
||||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
||||||
return image, status_code, message, image_url
|
return image_url or image, status_code, message, intent_type
|
||||||
else:
|
else:
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
return image, status_code, message, image_url
|
return image_url or image, status_code, message, intent_type
|
||||||
return image, status_code, response, image_url
|
return image_url or image, status_code, response, intent_type
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import base64
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
@@ -17,16 +16,15 @@ if aws_enabled:
|
|||||||
s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
|
s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
|
||||||
|
|
||||||
|
|
||||||
def upload_image(image: str, user_id: uuid.UUID):
|
def upload_image(image: bytes, user_id: uuid.UUID):
|
||||||
"""Upload the image to the S3 bucket"""
|
"""Upload the image to the S3 bucket"""
|
||||||
if not aws_enabled:
|
if not aws_enabled:
|
||||||
logger.info("AWS is not enabled. Skipping image upload")
|
logger.info("AWS is not enabled. Skipping image upload")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
decoded_image = base64.b64decode(image)
|
image_key = f"{user_id}/{uuid.uuid4()}.webp"
|
||||||
image_key = f"{user_id}/{uuid.uuid4()}.png"
|
|
||||||
try:
|
try:
|
||||||
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read")
|
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read")
|
||||||
url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}"
|
url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}"
|
||||||
return url
|
return url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user