Reuse logic to format messages for chat with anthropic models

This commit is contained in:
Debanjum Singh Solanky
2024-10-23 03:57:55 -07:00
parent 82eac5a043
commit 6fd50a5956
2 changed files with 52 additions and 20 deletions

View File

@@ -11,6 +11,7 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.utils import (
anthropic_chat_completion_with_backoff,
anthropic_completion_with_backoff,
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
@@ -101,17 +102,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
"""
Send message to model
"""
# Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter
system_prompt = None
if len(messages) == 1:
messages[0].role = "user"
else:
system_prompt = ""
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
messages, system_prompt = format_messages_for_anthropic(messages)
# Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
return anthropic_completion_with_backoff(
@@ -192,14 +183,7 @@ def converse_anthropic(
model_type=ChatModelOptions.ModelType.ANTHROPIC,
)
if len(messages) > 1:
if messages[0].role == "assistant":
messages = messages[1:]
for message in messages.copy():
if message.role == "system":
system_prompt += message.content
messages.remove(message)
messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for Claude: {truncated_messages}")

View File

@@ -3,6 +3,7 @@ from threading import Thread
from typing import Dict, List
import anthropic
from langchain.schema import ChatMessage
from tenacity import (
before_sleep_log,
retry,
@@ -11,7 +12,8 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import ThreadedGenerator
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
from khoj.utils.helpers import is_none_or_empty
logger = logging.getLogger(__name__)
@@ -115,3 +117,49 @@ def anthropic_llm_thread(
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
finally:
g.close()
def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=None):
    """
    Prepare a chatml-style message list for the Anthropic API.

    Folds system messages into a separate system prompt string (Anthropic
    takes the system prompt outside the messages parameter), ensures the
    conversation starts with a 'user' message, and converts image url
    parts into Anthropic's base64 image content-block format.

    Returns the (possibly mutated) message list and the system prompt,
    which is None when no system text was collected.
    """
    # Pull system messages out of the conversation into the system prompt
    system_prompt = system_prompt or ""
    for candidate in messages.copy():
        if candidate.role == "system":
            system_prompt += candidate.content
            messages.remove(candidate)
    if is_none_or_empty(system_prompt):
        system_prompt = None

    # Anthropic requires the conversation to open with a 'user' message
    if len(messages) == 1:
        messages[0].role = "user"
    elif len(messages) > 1 and messages[0].role == "assistant":
        messages = messages[1:]

    # Rewrite multi-part message content into Anthropic's format,
    # fetching image urls as base64 encoded images
    for message in messages:
        if not isinstance(message.content, list):
            continue
        # Stable sort places image parts ahead of text parts, as preferred
        message.content.sort(key=lambda part: 0 if part["type"] == "image_url" else 1)
        formatted_parts = []
        for position, part in enumerate(message.content):
            if part["type"] == "text":
                formatted_parts.append({"type": "text", "text": part["text"]})
            elif part["type"] == "image_url":
                b64_image, media_type = get_image_from_url(part["image_url"]["url"], type="b64")
                # Label each image so the model can reference it by number
                formatted_parts.append({"type": "text", "text": f"Image {position + 1}:"})
                formatted_parts.append(
                    {"type": "image", "source": {"type": "base64", "media_type": media_type, "data": b64_image}}
                )
        message.content = formatted_parts
    return messages, system_prompt