From 5d5ebcbf7ca9377fced10d902c5072203a2004e7 Mon Sep 17 00:00:00 2001
From: Saba
Date: Tue, 6 Jun 2023 23:25:43 -0700
Subject: [PATCH] Rename truncate messages method and update unit tests to
 simplify assertion logic

---
 src/khoj/processor/conversation/utils.py |  4 ++--
 tests/test_conversation_utils.py         | 24 +++++++-----------------
 2 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index a63a09c0..a3901d02 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -97,13 +97,13 @@ def generate_chatml_messages_with_context(
     messages = user_chatml_message + rest_backnforths + system_chatml_message
 
     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_message(messages, max_prompt_size[model_name], model_name)
+    messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
 
     # Return message in chronological order
     return messages[::-1]
 
 
-def truncate_message(messages, max_prompt_size, model_name):
+def truncate_messages(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
     encoder = tiktoken.encoding_for_model(model_name)
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index 43f68884..06a507c5 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -19,18 +19,15 @@ class TestTruncateMessage:
 
     def test_truncate_message_all_small(self):
         chat_messages = ChatMessageFactory.build_batch(500)
-        assert len(chat_messages) == 500
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
-        assert tokens > self.max_prompt_size
-        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
 
         # The original object has been modified. Verify certain properties
         assert len(chat_messages) < 500
         assert len(chat_messages) > 1
         assert prompt == chat_messages
-
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_first_large(self):
@@ -39,18 +36,14 @@ class TestTruncateMessage:
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
         chat_messages.insert(0, big_chat_message)
-        assert len(chat_messages) == 26
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
-        assert tokens > self.max_prompt_size
-        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
 
         # The original object has been modified. Verify certain properties
-        assert len(chat_messages) < 26
         assert len(chat_messages) == 1
         assert prompt[0] != copy_big_chat_message
-
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_last_large(self):
@@ -60,16 +53,13 @@ class TestTruncateMessage:
         copy_big_chat_message = big_chat_message.copy()
         chat_messages.append(big_chat_message)
-        assert len(chat_messages) == 26
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
-        assert tokens > self.max_prompt_size
-        prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
+        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
 
         # The original object has been modified. Verify certain properties
         assert len(chat_messages) < 26
         assert len(chat_messages) > 1
         assert prompt[0] != copy_big_chat_message
-
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
-        assert tokens < self.max_prompt_size
+        assert tokens <= self.max_prompt_size