diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index cb51abb4..b6d85726 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -11,6 +11,7 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, anthropic_completion_with_backoff, + format_messages_for_anthropic, ) from khoj.processor.conversation.utils import generate_chatml_messages_with_context from khoj.utils.helpers import ConversationCommand, is_none_or_empty @@ -101,17 +102,7 @@ def anthropic_send_message_to_model(messages, api_key, model): """ Send message to model """ - # Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter - system_prompt = None - - if len(messages) == 1: - messages[0].role = "user" - else: - system_prompt = "" - for message in messages.copy(): - if message.role == "system": - system_prompt += message.content - messages.remove(message) + messages, system_prompt = format_messages_for_anthropic(messages) # Get Response from GPT. Don't use response_type because Anthropic doesn't support it. return anthropic_completion_with_backoff( @@ -192,14 +183,7 @@ def converse_anthropic( model_type=ChatModelOptions.ModelType.ANTHROPIC, ) - if len(messages) > 1: - if messages[0].role == "assistant": - messages = messages[1:] - - for message in messages.copy(): - if message.role == "system": - system_prompt += message.content - messages.remove(message) + messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages}) logger.debug(f"Conversation Context for Claude: {truncated_messages}") diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index 79ccac4e..cc020b0a 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -3,6 +3,7 @@ from threading import Thread from typing import Dict, List import anthropic +from langchain.schema import ChatMessage from tenacity import ( before_sleep_log, retry, @@ -11,7 +12,8 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ThreadedGenerator +from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url +from khoj.utils.helpers import is_none_or_empty logger = logging.getLogger(__name__) @@ -115,3 +117,49 @@ def anthropic_llm_thread( logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True) finally: g.close() + + +def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=None): + """ + Format messages for Anthropic + """ + # Extract system prompt + system_prompt = system_prompt or "" + for message in messages.copy(): + if message.role == "system": + system_prompt += message.content + messages.remove(message) + system_prompt = None if is_none_or_empty(system_prompt) else system_prompt + + # Anthropic requires the first message to be a 'user' message + if len(messages) == 1: + messages[0].role = "user" + elif len(messages) > 1 and messages[0].role == "assistant": + messages = messages[1:] + + # Convert image urls to base64 encoded images in Anthropic message format + for message in messages: + if isinstance(message.content, list): + content = [] + # Sort the content as preferred if text comes after images + message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1) + for idx, part in enumerate(message.content): + if part["type"] == "text": + content.append({"type": "text", "text": part["text"]}) + elif part["type"] == "image_url": + b64_image, media_type = get_image_from_url(part["image_url"]["url"], type="b64") + content.extend( + [ + { + "type": "text", + "text": f"Image {idx + 1}:", + }, + { + "type": "image", + "source": {"type": "base64", "media_type": media_type, "data": b64_image}, + }, + ] + ) + message.content = content + + return messages, system_prompt