From abad5348a06e87aca8ececdabca2bd90055bebbc Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 23 Oct 2024 04:00:44 -0700 Subject: [PATCH] Give Vision to Anthropic models in Khoj --- .../conversation/anthropic/anthropic_chat.py | 20 +++++++++++++++++-- src/khoj/processor/conversation/utils.py | 6 +++++- src/khoj/routers/api.py | 2 ++ src/khoj/routers/helpers.py | 9 +++++++-- 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index b6d85726..5e403c7b 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -13,7 +13,10 @@ from khoj.processor.conversation.anthropic.utils import ( anthropic_completion_with_backoff, format_messages_for_anthropic, ) -from khoj.processor.conversation.utils import generate_chatml_messages_with_context +from khoj.processor.conversation.utils import ( + construct_structured_message, + generate_chatml_messages_with_context, +) from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.rawconfig import LocationData @@ -28,6 +31,8 @@ def extract_questions_anthropic( temperature=0.7, location_data: LocationData = None, user: KhojUser = None, + query_images: Optional[list[str]] = None, + vision_enabled: bool = False, personality_context: Optional[str] = None, ): """ @@ -69,6 +74,13 @@ def extract_questions_anthropic( text=text, ) + prompt = construct_structured_message( + message=prompt, + images=query_images, + model_type=ChatModelOptions.ModelType.ANTHROPIC, + vision_enabled=vision_enabled, + ) + messages = [ChatMessage(content=prompt, role="user")] response = anthropic_completion_with_backoff( @@ -118,7 +130,7 @@ def converse_anthropic( user_query, online_results: Optional[Dict[str, Dict]] = None, conversation_log={}, - model: Optional[str] = "claude-instant-1.2", + model: Optional[str] = "claude-3-5-sonnet-20241022", api_key: Optional[str] = None, completion_func=None, conversation_commands=[ConversationCommand.Default], @@ -127,6 +139,8 @@ def converse_anthropic( location_data: LocationData = None, user_name: str = None, agent: Agent = None, + query_images: Optional[list[str]] = None, + vision_available: bool = False, ): """ Converse with user using Anthropic's Claude @@ -180,6 +194,8 @@ def converse_anthropic( model_name=model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, + query_images=query_images, + vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.ANTHROPIC, ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index cb2c2ba3..943c5616 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -157,7 +157,11 @@ def construct_structured_message(message: str, images: list[str], model_type: st if not images or not vision_enabled: return message - if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]: + if model_type in [ + ChatModelOptions.ModelType.OPENAI, + ChatModelOptions.ModelType.GOOGLE, + ChatModelOptions.ModelType.ANTHROPIC, + ]: return [ {"type": "text", "text": message}, *[{"type": "image_url", "image_url": {"url": image}} for image in images], diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index c542b1f3..388024fa 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -447,11 +447,13 @@ async def extract_references_and_questions( chat_model = conversation_config.chat_model inferred_queries = extract_questions_anthropic( defiltered_query, + query_images=query_images, model=chat_model, api_key=api_key, conversation_log=meta_log, location_data=location_data, user=user, + vision_enabled=vision_enabled, personality_context=personality_context, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c587c4bd..8425a09a 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -825,10 +825,13 @@ async def send_message_to_model_wrapper( conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) vision_available = conversation_config.vision_enabled if not vision_available and query_images: + logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.") vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() if vision_enabled_config: conversation_config = vision_enabled_config vision_available = True + if vision_available and query_images: + logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.") subscribed = await ais_user_subscribed(user) chat_model = conversation_config.chat_model @@ -1109,8 +1112,9 @@ def generate_chat_response( chat_response = converse_anthropic( compiled_references, q, - online_results, - meta_log, + query_images=query_images, + online_results=online_results, + conversation_log=meta_log, model=conversation_config.chat_model, api_key=api_key, completion_func=partial_completion, @@ -1120,6 +1124,7 @@ def generate_chat_response( location_data=location_data, user_name=user_name, agent=agent, + vision_available=vision_available, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key