mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Rename truncate messages method and update unit tests to simplify assertion logic
This commit is contained in:
@@ -97,13 +97,13 @@ def generate_chatml_messages_with_context(
|
|||||||
messages = user_chatml_message + rest_backnforths + system_chatml_message
|
messages = user_chatml_message + rest_backnforths + system_chatml_message
|
||||||
|
|
||||||
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
||||||
messages = truncate_message(messages, max_prompt_size[model_name], model_name)
|
messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
|
||||||
|
|
||||||
# Return message in chronological order
|
# Return message in chronological order
|
||||||
return messages[::-1]
|
return messages[::-1]
|
||||||
|
|
||||||
|
|
||||||
def truncate_message(messages, max_prompt_size, model_name):
|
def truncate_messages(messages, max_prompt_size, model_name):
|
||||||
"""Truncate messages to fit within max prompt size supported by model"""
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||||
|
|||||||
@@ -19,18 +19,15 @@ class TestTruncateMessage:
|
|||||||
|
|
||||||
def test_truncate_message_all_small(self):
|
def test_truncate_message_all_small(self):
|
||||||
chat_messages = ChatMessageFactory.build_batch(500)
|
chat_messages = ChatMessageFactory.build_batch(500)
|
||||||
assert len(chat_messages) == 500
|
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||||
assert tokens > self.max_prompt_size
|
|
||||||
|
|
||||||
prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
|
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||||
|
|
||||||
# The original object has been modified. Verify certain properties
|
# The original object has been modified. Verify certain properties
|
||||||
assert len(chat_messages) < 500
|
assert len(chat_messages) < 500
|
||||||
assert len(chat_messages) > 1
|
assert len(chat_messages) > 1
|
||||||
assert prompt == chat_messages
|
assert prompt == chat_messages
|
||||||
|
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
|
||||||
assert tokens <= self.max_prompt_size
|
assert tokens <= self.max_prompt_size
|
||||||
|
|
||||||
def test_truncate_message_first_large(self):
|
def test_truncate_message_first_large(self):
|
||||||
@@ -39,18 +36,14 @@ class TestTruncateMessage:
|
|||||||
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
|
||||||
copy_big_chat_message = big_chat_message.copy()
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
chat_messages.insert(0, big_chat_message)
|
chat_messages.insert(0, big_chat_message)
|
||||||
assert len(chat_messages) == 26
|
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||||
assert tokens > self.max_prompt_size
|
|
||||||
|
|
||||||
prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
|
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||||
|
|
||||||
# The original object has been modified. Verify certain properties
|
# The original object has been modified. Verify certain properties
|
||||||
assert len(chat_messages) < 26
|
|
||||||
assert len(chat_messages) == 1
|
assert len(chat_messages) == 1
|
||||||
assert prompt[0] != copy_big_chat_message
|
assert prompt[0] != copy_big_chat_message
|
||||||
|
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
|
||||||
assert tokens <= self.max_prompt_size
|
assert tokens <= self.max_prompt_size
|
||||||
|
|
||||||
def test_truncate_message_last_large(self):
|
def test_truncate_message_last_large(self):
|
||||||
@@ -60,16 +53,13 @@ class TestTruncateMessage:
|
|||||||
copy_big_chat_message = big_chat_message.copy()
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
|
|
||||||
chat_messages.append(big_chat_message)
|
chat_messages.append(big_chat_message)
|
||||||
assert len(chat_messages) == 26
|
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
||||||
assert tokens > self.max_prompt_size
|
|
||||||
|
|
||||||
prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
|
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
||||||
|
|
||||||
# The original object has been modified. Verify certain properties
|
# The original object has been modified. Verify certain properties
|
||||||
assert len(chat_messages) < 26
|
assert len(chat_messages) < 26
|
||||||
assert len(chat_messages) > 1
|
assert len(chat_messages) > 1
|
||||||
assert prompt[0] != copy_big_chat_message
|
assert prompt[0] != copy_big_chat_message
|
||||||
|
assert tokens <= self.max_prompt_size
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
|
|
||||||
assert tokens < self.max_prompt_size
|
|
||||||
|
|||||||
Reference in New Issue
Block a user