mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 13:21:18 +00:00
Fix passing inline images to vision models
- Fix regression: Inline images were not getting passed to the AI models since #992 - Format inline images passed to Gemini models correctly - Format inline images passed to Anthropic models correctly Verified vision working with inline and url images for OpenAI, Anthropic and Gemini models. Resolves #1112
This commit is contained in:
@@ -15,6 +15,7 @@ from tenacity import (
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_base64,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
@@ -232,7 +233,11 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non
|
||||
if part["type"] == "text":
|
||||
content.append({"type": "text", "text": part["text"]})
|
||||
elif part["type"] == "image_url":
|
||||
image = get_image_from_url(part["image_url"]["url"], type="b64")
|
||||
image_data = part["image_url"]["url"]
|
||||
if image_data.startswith("http"):
|
||||
image = get_image_from_url(image_data, type="b64")
|
||||
else:
|
||||
image = get_image_from_base64(image_data, type="b64")
|
||||
# Prefix each image with text block enumerating the image number
|
||||
# This helps the model reference the image in its response. Recommended by Anthropic
|
||||
content.extend(
|
||||
|
||||
@@ -18,6 +18,7 @@ from tenacity import (
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
commit_conversation_trace,
|
||||
get_image_from_base64,
|
||||
get_image_from_url,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
@@ -245,7 +246,11 @@ def format_messages_for_gemini(
|
||||
message_content = []
|
||||
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1):
|
||||
if item["type"] == "image_url":
|
||||
image = get_image_from_url(item["image_url"]["url"], type="bytes")
|
||||
image_data = item["image_url"]["url"]
|
||||
if image_data.startswith("http"):
|
||||
image = get_image_from_url(image_data, type="bytes")
|
||||
else:
|
||||
image = get_image_from_base64(image_data, type="bytes")
|
||||
message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)]
|
||||
else:
|
||||
message_content += [gtypes.Part.from_text(text=item.get("text", ""))]
|
||||
|
||||
@@ -345,8 +345,7 @@ def construct_structured_message(
|
||||
constructed_messages.append({"type": "text", "text": attached_file_context})
|
||||
if vision_enabled and images:
|
||||
for image in images:
|
||||
if image.startswith("https://"):
|
||||
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
|
||||
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
|
||||
return constructed_messages
|
||||
|
||||
if not is_none_or_empty(attached_file_context):
|
||||
@@ -664,6 +663,23 @@ class ImageWithType:
|
||||
type: str
|
||||
|
||||
|
||||
def get_image_from_base64(image: str, type="b64"):
|
||||
# Extract image type and base64 data from inline image data
|
||||
image_base64 = image.split(",", 1)[1]
|
||||
image_type = image.split(";", 1)[0].split(":", 1)[1]
|
||||
|
||||
# Convert image to desired format
|
||||
if type == "b64":
|
||||
return ImageWithType(content=image_base64, type=image_type)
|
||||
elif type == "pil":
|
||||
image_data = base64.b64decode(image_base64)
|
||||
image_pil = PIL.Image.open(BytesIO(image_data))
|
||||
return ImageWithType(content=image_pil, type=image_type)
|
||||
elif type == "bytes":
|
||||
image_data = base64.b64decode(image_base64)
|
||||
return ImageWithType(content=image_data, type=image_type)
|
||||
|
||||
|
||||
def get_image_from_url(image_url: str, type="pil"):
|
||||
try:
|
||||
response = requests.get(image_url)
|
||||
|
||||
@@ -675,8 +675,10 @@ async def chat(
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
webp_image_bytes = convert_image_to_webp(image_bytes)
|
||||
uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id)
|
||||
if uploaded_image:
|
||||
uploaded_images.append(uploaded_image)
|
||||
if not uploaded_image:
|
||||
base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||
uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
|
||||
uploaded_images.append(uploaded_image)
|
||||
|
||||
query_files: Dict[str, str] = {}
|
||||
if raw_query_files:
|
||||
|
||||
Reference in New Issue
Block a user