mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Give Vision to Anthropic models in Khoj
This commit is contained in:
@@ -13,7 +13,10 @@ from khoj.processor.conversation.anthropic.utils import (
|
|||||||
anthropic_completion_with_backoff,
|
anthropic_completion_with_backoff,
|
||||||
format_messages_for_anthropic,
|
format_messages_for_anthropic,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
from khoj.processor.conversation.utils import (
|
||||||
|
construct_structured_message,
|
||||||
|
generate_chatml_messages_with_context,
|
||||||
|
)
|
||||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
@@ -28,6 +31,8 @@ def extract_questions_anthropic(
|
|||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
|
query_images: Optional[list[str]] = None,
|
||||||
|
vision_enabled: bool = False,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -69,6 +74,13 @@ def extract_questions_anthropic(
|
|||||||
text=text,
|
text=text,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prompt = construct_structured_message(
|
||||||
|
message=prompt,
|
||||||
|
images=query_images,
|
||||||
|
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
|
)
|
||||||
|
|
||||||
messages = [ChatMessage(content=prompt, role="user")]
|
messages = [ChatMessage(content=prompt, role="user")]
|
||||||
|
|
||||||
response = anthropic_completion_with_backoff(
|
response = anthropic_completion_with_backoff(
|
||||||
@@ -118,7 +130,7 @@ def converse_anthropic(
|
|||||||
user_query,
|
user_query,
|
||||||
online_results: Optional[Dict[str, Dict]] = None,
|
online_results: Optional[Dict[str, Dict]] = None,
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: Optional[str] = "claude-instant-1.2",
|
model: Optional[str] = "claude-3-5-sonnet-20241022",
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
completion_func=None,
|
completion_func=None,
|
||||||
conversation_commands=[ConversationCommand.Default],
|
conversation_commands=[ConversationCommand.Default],
|
||||||
@@ -127,6 +139,8 @@ def converse_anthropic(
|
|||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
query_images: Optional[list[str]] = None,
|
||||||
|
vision_available: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Converse with user using Anthropic's Claude
|
Converse with user using Anthropic's Claude
|
||||||
@@ -180,6 +194,8 @@ def converse_anthropic(
|
|||||||
model_name=model,
|
model_name=model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
|
query_images=query_images,
|
||||||
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
model_type=ChatModelOptions.ModelType.ANTHROPIC,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -157,7 +157,11 @@ def construct_structured_message(message: str, images: list[str], model_type: st
|
|||||||
if not images or not vision_enabled:
|
if not images or not vision_enabled:
|
||||||
return message
|
return message
|
||||||
|
|
||||||
if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
|
if model_type in [
|
||||||
|
ChatModelOptions.ModelType.OPENAI,
|
||||||
|
ChatModelOptions.ModelType.GOOGLE,
|
||||||
|
ChatModelOptions.ModelType.ANTHROPIC,
|
||||||
|
]:
|
||||||
return [
|
return [
|
||||||
{"type": "text", "text": message},
|
{"type": "text", "text": message},
|
||||||
*[{"type": "image_url", "image_url": {"url": image}} for image in images],
|
*[{"type": "image_url", "image_url": {"url": image}} for image in images],
|
||||||
|
|||||||
@@ -447,11 +447,13 @@ async def extract_references_and_questions(
|
|||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
inferred_queries = extract_questions_anthropic(
|
inferred_queries = extract_questions_anthropic(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
|
query_images=query_images,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user=user,
|
user=user,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
|
|||||||
@@ -825,10 +825,13 @@ async def send_message_to_model_wrapper(
|
|||||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = conversation_config.vision_enabled
|
||||||
if not vision_available and query_images:
|
if not vision_available and query_images:
|
||||||
|
logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
|
||||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||||
if vision_enabled_config:
|
if vision_enabled_config:
|
||||||
conversation_config = vision_enabled_config
|
conversation_config = vision_enabled_config
|
||||||
vision_available = True
|
vision_available = True
|
||||||
|
if vision_available and query_images:
|
||||||
|
logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")
|
||||||
|
|
||||||
subscribed = await ais_user_subscribed(user)
|
subscribed = await ais_user_subscribed(user)
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
@@ -1109,8 +1112,9 @@ def generate_chat_response(
|
|||||||
chat_response = converse_anthropic(
|
chat_response = converse_anthropic(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
q,
|
q,
|
||||||
online_results,
|
query_images=query_images,
|
||||||
meta_log,
|
online_results=online_results,
|
||||||
|
conversation_log=meta_log,
|
||||||
model=conversation_config.chat_model,
|
model=conversation_config.chat_model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
@@ -1120,6 +1124,7 @@ def generate_chat_response(
|
|||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
vision_available=vision_available,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
|||||||
Reference in New Issue
Block a user