diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index b022d0e2..8744cfe4 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -15,6 +15,7 @@ from tenacity import ( from khoj.processor.conversation.utils import ( ThreadedGenerator, commit_conversation_trace, + get_image_from_base64, get_image_from_url, ) from khoj.utils.helpers import ( @@ -232,7 +233,11 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non if part["type"] == "text": content.append({"type": "text", "text": part["text"]}) elif part["type"] == "image_url": - image = get_image_from_url(part["image_url"]["url"], type="b64") + image_data = part["image_url"]["url"] + if image_data.startswith("http"): + image = get_image_from_url(image_data, type="b64") + else: + image = get_image_from_base64(image_data, type="b64") # Prefix each image with text block enumerating the image number # This helps the model reference the image in its response. Recommended by Anthropic content.extend( diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index b1a5fe77..014faff1 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -18,6 +18,7 @@ from tenacity import ( from khoj.processor.conversation.utils import ( ThreadedGenerator, commit_conversation_trace, + get_image_from_base64, get_image_from_url, ) from khoj.utils.helpers import ( @@ -245,7 +246,11 @@ def format_messages_for_gemini( message_content = [] for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1): if item["type"] == "image_url": - image = get_image_from_url(item["image_url"]["url"], type="bytes") + image_data = item["image_url"]["url"] + if image_data.startswith("http"): + image = get_image_from_url(image_data, type="bytes") + else: + image = get_image_from_base64(image_data, type="bytes") message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)] else: message_content += [gtypes.Part.from_text(text=item.get("text", ""))] diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index de82f067..dab26094 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -345,8 +345,7 @@ def construct_structured_message( constructed_messages.append({"type": "text", "text": attached_file_context}) if vision_enabled and images: for image in images: - if image.startswith("https://"): - constructed_messages.append({"type": "image_url", "image_url": {"url": image}}) + constructed_messages.append({"type": "image_url", "image_url": {"url": image}}) return constructed_messages if not is_none_or_empty(attached_file_context): @@ -664,6 +663,23 @@ class ImageWithType: type: str +def get_image_from_base64(image: str, type="b64"): + # Extract image type and base64 data from inline image data + image_base64 = image.split(",", 1)[1] + image_type = image.split(";", 1)[0].split(":", 1)[1] + + # Convert image to desired format + if type == "b64": + return ImageWithType(content=image_base64, type=image_type) + elif type == "pil": + image_data = base64.b64decode(image_base64) + image_pil = PIL.Image.open(BytesIO(image_data)) + return ImageWithType(content=image_pil, type=image_type) + elif type == "bytes": + image_data = base64.b64decode(image_base64) + return ImageWithType(content=image_data, type=image_type) + + def get_image_from_url(image_url: str, type="pil"): try: response = requests.get(image_url) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index a63db09f..acdf48b5 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -675,8 +675,10 @@ async def chat( image_bytes = base64.b64decode(base64_data) webp_image_bytes = convert_image_to_webp(image_bytes) uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id) - if uploaded_image: - uploaded_images.append(uploaded_image) + if not uploaded_image: + base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8") + uploaded_image = f"data:image/webp;base64,{base64_webp_image}" + uploaded_images.append(uploaded_image) query_files: Dict[str, str] = {} if raw_query_files: