Rename truncate messages method and update unit tests to simplify assertion logic

This commit is contained in:
Saba
2023-06-06 23:25:43 -07:00
parent 7119ed0849
commit 5d5ebcbf7c
2 changed files with 9 additions and 19 deletions

View File

@@ -97,13 +97,13 @@ def generate_chatml_messages_with_context(
     messages = user_chatml_message + rest_backnforths + system_chatml_message

     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_message(messages, max_prompt_size[model_name], model_name)
+    messages = truncate_messages(messages, max_prompt_size[model_name], model_name)

     # Return message in chronological order
     return messages[::-1]


-def truncate_message(messages, max_prompt_size, model_name):
+def truncate_messages(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
     encoder = tiktoken.encoding_for_model(model_name)
     tokens = sum([len(encoder.encode(message.content)) for message in messages])

View File

@@ -19,18 +19,15 @@ class TestTruncateMessage:
     def test_truncate_message_all_small(self):
         chat_messages = ChatMessageFactory.build_batch(500)
-        assert len(chat_messages) == 500
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
-        assert tokens > self.max_prompt_size

-        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])

         # The original object has been modified. Verify certain properties
         assert len(chat_messages) < 500
         assert len(chat_messages) > 1
         assert prompt == chat_messages
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
         assert tokens <= self.max_prompt_size

     def test_truncate_message_first_large(self):
@@ -39,18 +36,14 @@ class TestTruncateMessage:
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
         chat_messages.insert(0, big_chat_message)
-        assert len(chat_messages) == 26
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
-        assert tokens > self.max_prompt_size

-        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])

         # The original object has been modified. Verify certain properties
-        assert len(chat_messages) < 26
         assert len(chat_messages) == 1
         assert prompt[0] != copy_big_chat_message
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
         assert tokens <= self.max_prompt_size

     def test_truncate_message_last_large(self):
@@ -60,16 +53,13 @@ class TestTruncateMessage:
         copy_big_chat_message = big_chat_message.copy()
         chat_messages.append(big_chat_message)
-        assert len(chat_messages) == 26
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
-        assert tokens > self.max_prompt_size

-        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])

         # The original object has been modified. Verify certain properties
         assert len(chat_messages) < 26
         assert len(chat_messages) > 1
         assert prompt[0] != copy_big_chat_message
-        assert tokens <= self.max_prompt_size
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        assert tokens < self.max_prompt_size