Handle truncation when single long non-system chat message

Previously was assuming the system prompt is being always passed as
the first message. So expected there to be at least 2 messages in logs.

This broke chat actors querying with single long non system message.

A more robust way to extract system prompt is via the message role
instead
This commit is contained in:
Debanjum Singh Solanky
2024-03-15 14:52:29 +05:30
parent ec0c35b7ed
commit ecddf98430
2 changed files with 67 additions and 29 deletions

View File

@@ -199,19 +199,26 @@ def truncate_messages(
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing." f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
) )
system_message = messages.pop() # Extract system message from messages
assert type(system_message.content) == str system_message = None
system_message_tokens = len(encoder.encode(system_message.content)) for idx, message in enumerate(messages):
if message.role == "system":
system_message = messages.pop(idx)
break
system_message_tokens = (
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
)
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
# Drop older messages until under max supported prompt size by model
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1: while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
messages.pop() messages.pop()
assert type(system_message.content) == str
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str]) tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
# Truncate current message if still over max supported prompt size by model # Truncate current message if still over max supported prompt size by model
if (tokens + system_message_tokens) > max_prompt_size: if (tokens + system_message_tokens) > max_prompt_size:
assert type(system_message.content) == str
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else "" current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else "" original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
original_question = f"\n{original_question}" original_question = f"\n{original_question}"
@@ -223,7 +230,7 @@ def truncate_messages(
) )
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)] messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
return messages + [system_message] return messages + [system_message] if system_message else messages
def reciprocal_conversation_to_chatml(message_pair): def reciprocal_conversation_to_chatml(message_pair):

View File

@@ -19,49 +19,80 @@ class TestTruncateMessage:
encoder = tiktoken.encoding_for_model(model_name) encoder = tiktoken.encoding_for_model(model_name)
def test_truncate_message_all_small(self): def test_truncate_message_all_small(self):
chat_messages = ChatMessageFactory.build_batch(500) # Arrange
chat_history = ChatMessageFactory.build_batch(500)
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) # Act
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt]) truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties # The original object has been modified. Verify certain properties
assert len(chat_messages) < 500 assert len(chat_history) < 500
assert len(chat_messages) > 1 assert len(chat_history) > 1
assert tokens <= self.max_prompt_size assert tokens <= self.max_prompt_size
def test_truncate_message_first_large(self): def test_truncate_message_first_large(self):
chat_messages = ChatMessageFactory.build_batch(25) # Arrange
chat_history = ChatMessageFactory.build_batch(25)
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000)) big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?" big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy() copy_big_chat_message = big_chat_message.copy()
chat_messages.insert(0, big_chat_message) chat_history.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) # Act
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt]) truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties # The original object has been modified. Verify certain properties
assert len(chat_messages) == 1 assert len(chat_history) == 1
assert prompt[0] != copy_big_chat_message assert truncated_chat_history[0] != copy_big_chat_message
assert tokens <= self.max_prompt_size assert tokens <= self.max_prompt_size
def test_truncate_message_last_large(self): def test_truncate_message_last_large(self):
chat_messages = ChatMessageFactory.build_batch(25) # Arrange
chat_history = ChatMessageFactory.build_batch(25)
chat_history[0].role = "system" # Mark the first message as system message
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000)) big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?" big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy() copy_big_chat_message = big_chat_message.copy()
chat_messages.insert(0, big_chat_message) chat_history.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages]) initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name) # Act
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt]) truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties. # The original object has been modified. Verify certain properties.
assert len(prompt) == ( assert len(truncated_chat_history) == (
len(chat_messages) + 1 len(chat_history) + 1
) # Because the system_prompt is popped off from the chat_messages lsit ) # Because the system_prompt is popped off from the chat_messages lsit
assert len(prompt) < 26 assert len(truncated_chat_history) < 26
assert len(prompt) > 1 assert len(truncated_chat_history) > 1
assert prompt[0] != copy_big_chat_message assert truncated_chat_history[0] != copy_big_chat_message
assert tokens <= self.max_prompt_size assert initial_tokens > self.max_prompt_size
assert final_tokens <= self.max_prompt_size
def test_truncate_single_large_non_system_message(self):
# Arrange
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
big_chat_message.role = "user"
copy_big_chat_message = big_chat_message.copy()
chat_messages = [big_chat_message]
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
# Act
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert initial_tokens > self.max_prompt_size
assert final_tokens <= self.max_prompt_size
assert len(chat_messages) == 1
assert truncated_chat_history[0] != copy_big_chat_message