diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index d19b02f2..a4041a94 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -1,11 +1,8 @@ import logging import random -from io import BytesIO from threading import Thread import google.generativeai as genai -import PIL.Image -import requests from google.generativeai.types.answer_types import FinishReason from google.generativeai.types.generation_types import StopCandidateException from google.generativeai.types.safety_types import ( @@ -22,7 +19,7 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ThreadedGenerator +from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url from khoj.utils.helpers import is_none_or_empty logger = logging.getLogger(__name__) @@ -207,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = if isinstance(message.content, list): # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini) message.content = [ - get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"] + get_image_from_url(item["image_url"]["url"])[0] if item["type"] == "image_url" else item["text"] for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1) ] elif isinstance(message.content, str): @@ -220,13 +217,3 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = messages[0].role = "user" return messages, system_prompt - - -def get_image_from_url(image_url: str) -> PIL.Image: - try: - response = requests.get(image_url) - response.raise_for_status() # Check if the request was successful - return PIL.Image.open(BytesIO(response.content)) - except requests.exceptions.RequestException as e: - logger.error(f"Failed to get image from URL {image_url}: {e}") - return 
def get_image_from_url(image_url: str, type="pil", timeout: float = 30.0):
    """Download an image and return it together with its content type.

    Args:
        image_url: URL to fetch the image from.
        type: Output representation. "pil" returns the image opened with
            PIL.Image.open; "b64" returns the downloaded bytes as a
            base64-encoded UTF-8 string. The parameter name shadows the
            builtin but is kept so existing keyword callers continue to work.
        timeout: Seconds to wait on the server before giving up. Passed
            straight to requests.get so a stalled host cannot block forever.

    Returns:
        An (image, content_type) tuple. content_type comes from the response
        "content-type" header, else is guessed from the URL, else defaults to
        "image/webp". Returns (None, None) when the download fails or, for
        type="pil", when the payload cannot be opened as an image.

    Raises:
        ValueError: If type is neither "pil" nor "b64".
    """
    # Reject unsupported output types up front, before spending a network round trip.
    if type not in ("pil", "b64"):
        raise ValueError(f"Invalid image type: {type}")

    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()  # Check if the request was successful
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to get image from URL {image_url}: {e}")
        return None, None

    # Get content type from response or infer from URL
    content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"

    if type == "b64":
        return base64.b64encode(response.content).decode("utf-8"), content_type

    try:
        image = PIL.Image.open(BytesIO(response.content))
    except OSError as e:
        # PIL signals an unreadable payload with UnidentifiedImageError, an OSError
        # subclass. Treat it like a failed download instead of crashing the caller.
        logger.error(f"Failed to open image from URL {image_url}: {e}")
        return None, None
    return image, content_type