From 82eac5a0438a6603f73a57d7f0dd28a3fb374112 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Wed, 23 Oct 2024 03:52:46 -0700
Subject: [PATCH] Make the get image from url function more versatile and reusable

It was previously added under the google utils. Now it can be used by
other conversation processors as well.

The updated function
- can get both base64 encoded and PIL formatted images from url
- will return the media type of the image as well in response
---
Review notes (text in this section is ignored when the patch is applied):
- get_image_from_url rejects an unsupported type before downloading
- requests.get is given a timeout so a stalled server cannot hang the caller
- undecodable image data (PIL.UnidentifiedImageError) yields (None, None)
- parameters such as charset are stripped from the Content-Type media type
- the index lines keep the blob ids recorded by the original patch

 .../processor/conversation/google/utils.py | 17 ++-------
 src/khoj/processor/conversation/utils.py   | 40 ++++++++++++++++++++++
 2 files changed, 42 insertions(+), 15 deletions(-)

diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index d19b02f2..a4041a94 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -1,11 +1,8 @@
 import logging
 import random
-from io import BytesIO
 from threading import Thread
 
 import google.generativeai as genai
-import PIL.Image
-import requests
 from google.generativeai.types.answer_types import FinishReason
 from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (
@@ -22,7 +19,7 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
 from khoj.utils.helpers import is_none_or_empty
 
 logger = logging.getLogger(__name__)
@@ -207,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
         if isinstance(message.content, list):
             # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
             message.content = [
-                get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
+                get_image_from_url(item["image_url"]["url"])[0] if item["type"] == "image_url" else item["text"]
                 for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
             ]
         elif isinstance(message.content, str):
@@ -220,13 +217,3 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
         messages[0].role = "user"
 
     return messages, system_prompt
-
-
-def get_image_from_url(image_url: str) -> PIL.Image:
-    try:
-        response = requests.get(image_url)
-        response.raise_for_status()  # Check if the request was successful
-        return PIL.Image.open(BytesIO(response.content))
-    except requests.exceptions.RequestException as e:
-        logger.error(f"Failed to get image from URL {image_url}: {e}")
-        return None
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index e8e96314..cb2c2ba3 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -1,10 +1,15 @@
+import base64
 import logging
 import math
+import mimetypes
 import queue
 from datetime import datetime
+from io import BytesIO
 from time import perf_counter
 from typing import Any, Dict, List, Optional
 
+import PIL.Image
+import requests
 import tiktoken
 from langchain.schema import ChatMessage
 from llama_cpp.llama import Llama
@@ -306,3 +311,38 @@ def reciprocal_conversation_to_chatml(message_pair):
 def remove_json_codeblock(response: str):
     """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
     return response.removeprefix("```json").removesuffix("```")
+
+
+def get_image_from_url(image_url: str, type="pil"):
+    """Download an image from a URL and return it along with its media type.
+
+    Args:
+        image_url: URL to download the image from.
+        type: Format to return the image in. "pil" for a PIL image, "b64" for a base64 encoded string.
+
+    Returns:
+        Tuple of (image, media type). Both are None if the image could not be downloaded or decoded.
+
+    Raises:
+        ValueError: If type is neither "pil" nor "b64".
+    """
+    # Reject unsupported formats up front to avoid a wasted network request
+    if type not in ("pil", "b64"):
+        raise ValueError(f"Invalid image type: {type}")
+
+    try:
+        # Use a timeout so an unresponsive server cannot block the caller indefinitely
+        response = requests.get(image_url, timeout=30)
+        response.raise_for_status()  # Check if the request was successful
+
+        # Get media type from response (ignoring parameters like charset) or infer from URL
+        content_type = response.headers.get("content-type", "").split(";")[0].strip()
+        content_type = content_type or mimetypes.guess_type(image_url)[0] or "image/webp"
+
+        if type == "b64":
+            return base64.b64encode(response.content).decode("utf-8"), content_type
+        return PIL.Image.open(BytesIO(response.content)), content_type
+    except (requests.exceptions.RequestException, PIL.UnidentifiedImageError) as e:
+        # Treat undecodable image data like a failed download so callers only handle (None, None)
+        logger.error(f"Failed to get image from URL {image_url}: {e}")
+        return None, None