From 82eac5a0438a6603f73a57d7f0dd28a3fb374112 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 23 Oct 2024 03:52:46 -0700 Subject: [PATCH 1/5] Make the get image from url function more versatile and reusable It was previously added under the google utils. Now it can be used by other conversation processors as well. The updated function - can get both base64 encoded and PIL formatted images from url - will return the media type of the image as well in response --- .../processor/conversation/google/utils.py | 17 ++----------- src/khoj/processor/conversation/utils.py | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index d19b02f2..a4041a94 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -1,11 +1,8 @@ import logging import random -from io import BytesIO from threading import Thread import google.generativeai as genai -import PIL.Image -import requests from google.generativeai.types.answer_types import FinishReason from google.generativeai.types.generation_types import StopCandidateException from google.generativeai.types.safety_types import ( @@ -22,7 +19,7 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ThreadedGenerator +from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url from khoj.utils.helpers import is_none_or_empty logger = logging.getLogger(__name__) @@ -207,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = if isinstance(message.content, list): # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini) message.content = [ - get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"] + get_image_from_url(item["image_url"]["url"])[0] if item["type"] == 
"image_url" else item["text"] for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1) ] elif isinstance(message.content, str): @@ -220,13 +217,3 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = messages[0].role = "user" return messages, system_prompt - - -def get_image_from_url(image_url: str) -> PIL.Image: - try: - response = requests.get(image_url) - response.raise_for_status() # Check if the request was successful - return PIL.Image.open(BytesIO(response.content)) - except requests.exceptions.RequestException as e: - logger.error(f"Failed to get image from URL {image_url}: {e}") - return None diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index e8e96314..cb2c2ba3 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -1,10 +1,15 @@ +import base64 import logging import math +import mimetypes import queue from datetime import datetime +from io import BytesIO from time import perf_counter from typing import Any, Dict, List, Optional +import PIL.Image +import requests import tiktoken from langchain.schema import ChatMessage from llama_cpp.llama import Llama @@ -306,3 +311,22 @@ def reciprocal_conversation_to_chatml(message_pair): def remove_json_codeblock(response: str): """Remove any markdown json codeblock formatting if present. 
Useful for non schema enforceable models""" return response.removeprefix("```json").removesuffix("```") + + +def get_image_from_url(image_url: str, type="pil"): + try: + response = requests.get(image_url) + response.raise_for_status() # Check if the request was successful + + # Get content type from response or infer from URL + content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp" + + if type == "b64": + return base64.b64encode(response.content).decode("utf-8"), content_type + elif type == "pil": + return PIL.Image.open(BytesIO(response.content)), content_type + else: + raise ValueError(f"Invalid image type: {type}") + except requests.exceptions.RequestException as e: + logger.error(f"Failed to get image from URL {image_url}: {e}") + return None, None From 6fd50a5956d70adaab2cf351b256de02a5b5654a Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 23 Oct 2024 03:57:55 -0700 Subject: [PATCH 2/5] Reuse logic to format messages for chat with anthropic models --- .../conversation/anthropic/anthropic_chat.py | 22 ++------ .../processor/conversation/anthropic/utils.py | 50 ++++++++++++++++++- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index cb51abb4..b6d85726 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -11,6 +11,7 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, anthropic_completion_with_backoff, + format_messages_for_anthropic, ) from khoj.processor.conversation.utils import generate_chatml_messages_with_context from khoj.utils.helpers import ConversationCommand, is_none_or_empty @@ -101,17 +102,7 @@ def anthropic_send_message_to_model(messages, api_key, model): """ Send 
message to model """ - # Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter - system_prompt = None - - if len(messages) == 1: - messages[0].role = "user" - else: - system_prompt = "" - for message in messages.copy(): - if message.role == "system": - system_prompt += message.content - messages.remove(message) + messages, system_prompt = format_messages_for_anthropic(messages) # Get Response from GPT. Don't use response_type because Anthropic doesn't support it. return anthropic_completion_with_backoff( @@ -192,14 +183,7 @@ def converse_anthropic( model_type=ChatModelOptions.ModelType.ANTHROPIC, ) - if len(messages) > 1: - if messages[0].role == "assistant": - messages = messages[1:] - - for message in messages.copy(): - if message.role == "system": - system_prompt += message.content - messages.remove(message) + messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) truncated_messages = "\n".join({f"{message.content[:40]}..." 
for message in messages}) logger.debug(f"Conversation Context for Claude: {truncated_messages}") diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 79ccac4e..cc020b0a 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -3,6 +3,7 @@ from threading import Thread from typing import Dict, List import anthropic +from langchain.schema import ChatMessage from tenacity import ( before_sleep_log, retry, @@ -11,7 +12,8 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ThreadedGenerator +from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url +from khoj.utils.helpers import is_none_or_empty logger = logging.getLogger(__name__) @@ -115,3 +117,49 @@ def anthropic_llm_thread( logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True) finally: g.close() + + +def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=None): + """ + Format messages for Anthropic + """ + # Extract system prompt + system_prompt = system_prompt or "" + for message in messages.copy(): + if message.role == "system": + system_prompt += message.content + messages.remove(message) + system_prompt = None if is_none_or_empty(system_prompt) else system_prompt + + # Anthropic requires the first message to be a 'user' message + if len(messages) == 1: + messages[0].role = "user" + elif len(messages) > 1 and messages[0].role == "assistant": + messages = messages[1:] + + # Convert image urls to base64 encoded images in Anthropic message format + for message in messages: + if isinstance(message.content, list): + content = [] + # Sort the content as preferred if text comes after images + message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1) + for idx, part in enumerate(message.content): + if part["type"] == "text": + content.append({"type": "text", 
"text": part["text"]}) + elif part["type"] == "image_url": + b64_image, media_type = get_image_from_url(part["image_url"]["url"], type="b64") + content.extend( + [ + { + "type": "text", + "text": f"Image {idx + 1}:", + }, + { + "type": "image", + "source": {"type": "base64", "media_type": media_type, "data": b64_image}, + }, + ] + ) + message.content = content + + return messages, system_prompt From abad5348a06e87aca8ececdabca2bd90055bebbc Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 23 Oct 2024 04:00:44 -0700 Subject: [PATCH 3/5] Give Vision to Anthropic models in Khoj --- .../conversation/anthropic/anthropic_chat.py | 20 +++++++++++++++++-- src/khoj/processor/conversation/utils.py | 6 +++++- src/khoj/routers/api.py | 2 ++ src/khoj/routers/helpers.py | 9 +++++++-- 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index b6d85726..5e403c7b 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -13,7 +13,10 @@ from khoj.processor.conversation.anthropic.utils import ( anthropic_completion_with_backoff, format_messages_for_anthropic, ) -from khoj.processor.conversation.utils import generate_chatml_messages_with_context +from khoj.processor.conversation.utils import ( + construct_structured_message, + generate_chatml_messages_with_context, +) from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.rawconfig import LocationData @@ -28,6 +31,8 @@ def extract_questions_anthropic( temperature=0.7, location_data: LocationData = None, user: KhojUser = None, + query_images: Optional[list[str]] = None, + vision_enabled: bool = False, personality_context: Optional[str] = None, ): """ @@ -69,6 +74,13 @@ def extract_questions_anthropic( text=text, ) + prompt = construct_structured_message( + message=prompt, + 
images=query_images, + model_type=ChatModelOptions.ModelType.ANTHROPIC, + vision_enabled=vision_enabled, + ) + messages = [ChatMessage(content=prompt, role="user")] response = anthropic_completion_with_backoff( @@ -118,7 +130,7 @@ def converse_anthropic( user_query, online_results: Optional[Dict[str, Dict]] = None, conversation_log={}, - model: Optional[str] = "claude-instant-1.2", + model: Optional[str] = "claude-3-5-sonnet-20241022", api_key: Optional[str] = None, completion_func=None, conversation_commands=[ConversationCommand.Default], @@ -127,6 +139,8 @@ def converse_anthropic( location_data: LocationData = None, user_name: str = None, agent: Agent = None, + query_images: Optional[list[str]] = None, + vision_available: bool = False, ): """ Converse with user using Anthropic's Claude @@ -180,6 +194,8 @@ def converse_anthropic( model_name=model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, + query_images=query_images, + vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.ANTHROPIC, ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index cb2c2ba3..943c5616 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -157,7 +157,11 @@ def construct_structured_message(message: str, images: list[str], model_type: st if not images or not vision_enabled: return message - if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]: + if model_type in [ + ChatModelOptions.ModelType.OPENAI, + ChatModelOptions.ModelType.GOOGLE, + ChatModelOptions.ModelType.ANTHROPIC, + ]: return [ {"type": "text", "text": message}, *[{"type": "image_url", "image_url": {"url": image}} for image in images], diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index c542b1f3..388024fa 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -447,11 +447,13 @@ async def extract_references_and_questions( chat_model 
= conversation_config.chat_model inferred_queries = extract_questions_anthropic( defiltered_query, + query_images=query_images, model=chat_model, api_key=api_key, conversation_log=meta_log, location_data=location_data, user=user, + vision_enabled=vision_enabled, personality_context=personality_context, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c587c4bd..8425a09a 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -825,10 +825,13 @@ async def send_message_to_model_wrapper( conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) vision_available = conversation_config.vision_enabled if not vision_available and query_images: + logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.") vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() if vision_enabled_config: conversation_config = vision_enabled_config vision_available = True + if vision_available and query_images: + logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.") subscribed = await ais_user_subscribed(user) chat_model = conversation_config.chat_model @@ -1109,8 +1112,9 @@ def generate_chat_response( chat_response = converse_anthropic( compiled_references, q, - online_results, - meta_log, + query_images=query_images, + online_results=online_results, + conversation_log=meta_log, model=conversation_config.chat_model, api_key=api_key, completion_func=partial_completion, @@ -1120,6 +1124,7 @@ def generate_chat_response( location_data=location_data, user_name=user_name, agent=agent, + vision_available=vision_available, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key From 8d588e0765c70fc7553ca908563e299342bd5897 Mon Sep 17 00:00:00 2001 From: 
Debanjum Singh Solanky Date: Wed, 23 Oct 2024 04:03:15 -0700 Subject: [PATCH 4/5] Encourage output mode chat actor to output only json and nothing else Latest claude model wanted to say more than just give the json output. The updated prompt encourages the model to output just json. This is similar to what is already being done for other prompts --- src/khoj/processor/conversation/prompts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 38db7477..7988cc43 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -619,7 +619,7 @@ AI: It's currently 28°C and partly cloudy in Bali. Q: Share a painting using the weather for Bali every morning. Khoj: {{"output": "automation"}} -Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. +Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. Do not say anything else. 
Chat History: {chat_history} From 01d740debd4d1857ec9bc6595d227dc17678449c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 24 Oct 2024 17:49:37 -0700 Subject: [PATCH 5/5] Return typed image from image_with_url function for readability --- .../processor/conversation/anthropic/utils.py | 8 +++++--- src/khoj/processor/conversation/google/utils.py | 2 +- src/khoj/processor/conversation/utils.py | 16 +++++++++++++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index cc020b0a..a4a71a6d 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -141,13 +141,15 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non for message in messages: if isinstance(message.content, list): content = [] - # Sort the content as preferred if text comes after images + # Sort the content. Anthropic models prefer that text comes after images. message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1) for idx, part in enumerate(message.content): if part["type"] == "text": content.append({"type": "text", "text": part["text"]}) elif part["type"] == "image_url": - b64_image, media_type = get_image_from_url(part["image_url"]["url"], type="b64") + image = get_image_from_url(part["image_url"]["url"], type="b64") + # Prefix each image with text block enumerating the image number + # This helps the model reference the image in its response. 
Recommended by Anthropic content.extend( [ { @@ -156,7 +158,7 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non }, { "type": "image", - "source": {"type": "base64", "media_type": media_type, "data": b64_image}, + "source": {"type": "base64", "media_type": image.type, "data": image.content}, }, ] ) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index a4041a94..964fe80b 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -204,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = if isinstance(message.content, list): # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini) message.content = [ - get_image_from_url(item["image_url"]["url"])[0] if item["type"] == "image_url" else item["text"] + get_image_from_url(item["image_url"]["url"]).content if item["type"] == "image_url" else item["text"] for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1) ] elif isinstance(message.content, str): diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 943c5616..fb6d1909 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -3,6 +3,7 @@ import logging import math import mimetypes import queue +from dataclasses import dataclass from datetime import datetime from io import BytesIO from time import perf_counter @@ -317,6 +318,12 @@ def remove_json_codeblock(response: str): return response.removeprefix("```json").removesuffix("```") +@dataclass +class ImageWithType: + content: Any + type: str + + def get_image_from_url(image_url: str, type="pil"): try: response = requests.get(image_url) @@ -325,12 +332,15 @@ def get_image_from_url(image_url: str, type="pil"): # Get content type from response or infer from URL content_type = 
response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp" + # Convert image to desired format if type == "b64": - return base64.b64encode(response.content).decode("utf-8"), content_type + image_data = base64.b64encode(response.content).decode("utf-8") elif type == "pil": - return PIL.Image.open(BytesIO(response.content)), content_type + image_data = PIL.Image.open(BytesIO(response.content)) else: raise ValueError(f"Invalid image type: {type}") + + return ImageWithType(content=image_data, type=content_type) except requests.exceptions.RequestException as e: logger.error(f"Failed to get image from URL {image_url}: {e}") - return None, None + return ImageWithType(content=None, type=None)