mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Fix passing inline images to vision models
- Fix regression: Inline images were not getting passed to the AI models since #992
- Format inline images passed to Gemini models correctly
- Format inline images passed to Anthropic models correctly

Verified vision working with inline and url images for OpenAI, Anthropic and Gemini models.

Resolves #1112
This commit is contained in:
@@ -15,6 +15,7 @@ from tenacity import (
|
|||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ThreadedGenerator,
|
ThreadedGenerator,
|
||||||
commit_conversation_trace,
|
commit_conversation_trace,
|
||||||
|
get_image_from_base64,
|
||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
@@ -232,7 +233,11 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non
|
|||||||
if part["type"] == "text":
|
if part["type"] == "text":
|
||||||
content.append({"type": "text", "text": part["text"]})
|
content.append({"type": "text", "text": part["text"]})
|
||||||
elif part["type"] == "image_url":
|
elif part["type"] == "image_url":
|
||||||
image = get_image_from_url(part["image_url"]["url"], type="b64")
|
image_data = part["image_url"]["url"]
|
||||||
|
if image_data.startswith("http"):
|
||||||
|
image = get_image_from_url(image_data, type="b64")
|
||||||
|
else:
|
||||||
|
image = get_image_from_base64(image_data, type="b64")
|
||||||
# Prefix each image with text block enumerating the image number
|
# Prefix each image with text block enumerating the image number
|
||||||
# This helps the model reference the image in its response. Recommended by Anthropic
|
# This helps the model reference the image in its response. Recommended by Anthropic
|
||||||
content.extend(
|
content.extend(
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from tenacity import (
|
|||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ThreadedGenerator,
|
ThreadedGenerator,
|
||||||
commit_conversation_trace,
|
commit_conversation_trace,
|
||||||
|
get_image_from_base64,
|
||||||
get_image_from_url,
|
get_image_from_url,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
@@ -245,7 +246,11 @@ def format_messages_for_gemini(
|
|||||||
message_content = []
|
message_content = []
|
||||||
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1):
|
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1):
|
||||||
if item["type"] == "image_url":
|
if item["type"] == "image_url":
|
||||||
image = get_image_from_url(item["image_url"]["url"], type="bytes")
|
image_data = item["image_url"]["url"]
|
||||||
|
if image_data.startswith("http"):
|
||||||
|
image = get_image_from_url(image_data, type="bytes")
|
||||||
|
else:
|
||||||
|
image = get_image_from_base64(image_data, type="bytes")
|
||||||
message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)]
|
message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)]
|
||||||
else:
|
else:
|
||||||
message_content += [gtypes.Part.from_text(text=item.get("text", ""))]
|
message_content += [gtypes.Part.from_text(text=item.get("text", ""))]
|
||||||
|
|||||||
@@ -345,7 +345,6 @@ def construct_structured_message(
|
|||||||
constructed_messages.append({"type": "text", "text": attached_file_context})
|
constructed_messages.append({"type": "text", "text": attached_file_context})
|
||||||
if vision_enabled and images:
|
if vision_enabled and images:
|
||||||
for image in images:
|
for image in images:
|
||||||
if image.startswith("https://"):
|
|
||||||
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
|
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
|
||||||
return constructed_messages
|
return constructed_messages
|
||||||
|
|
||||||
@@ -664,6 +663,23 @@ class ImageWithType:
|
|||||||
type: str
|
type: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_from_base64(image: str, type="b64"):
|
||||||
|
# Extract image type and base64 data from inline image data
|
||||||
|
image_base64 = image.split(",", 1)[1]
|
||||||
|
image_type = image.split(";", 1)[0].split(":", 1)[1]
|
||||||
|
|
||||||
|
# Convert image to desired format
|
||||||
|
if type == "b64":
|
||||||
|
return ImageWithType(content=image_base64, type=image_type)
|
||||||
|
elif type == "pil":
|
||||||
|
image_data = base64.b64decode(image_base64)
|
||||||
|
image_pil = PIL.Image.open(BytesIO(image_data))
|
||||||
|
return ImageWithType(content=image_pil, type=image_type)
|
||||||
|
elif type == "bytes":
|
||||||
|
image_data = base64.b64decode(image_base64)
|
||||||
|
return ImageWithType(content=image_data, type=image_type)
|
||||||
|
|
||||||
|
|
||||||
def get_image_from_url(image_url: str, type="pil"):
|
def get_image_from_url(image_url: str, type="pil"):
|
||||||
try:
|
try:
|
||||||
response = requests.get(image_url)
|
response = requests.get(image_url)
|
||||||
|
|||||||
@@ -675,7 +675,9 @@ async def chat(
|
|||||||
image_bytes = base64.b64decode(base64_data)
|
image_bytes = base64.b64decode(base64_data)
|
||||||
webp_image_bytes = convert_image_to_webp(image_bytes)
|
webp_image_bytes = convert_image_to_webp(image_bytes)
|
||||||
uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id)
|
uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id)
|
||||||
if uploaded_image:
|
if not uploaded_image:
|
||||||
|
base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
|
||||||
|
uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
|
||||||
uploaded_images.append(uploaded_image)
|
uploaded_images.append(uploaded_image)
|
||||||
|
|
||||||
query_files: Dict[str, str] = {}
|
query_files: Dict[str, str] = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user