Merge branch 'features/customize-chat-with-agents' of github.com:khoj-ai/khoj into features/add-agents-ui

This commit is contained in:
sabaimran
2024-03-23 11:20:57 +05:30
14 changed files with 162 additions and 76 deletions

View File

@@ -19,49 +19,80 @@ class TestTruncateMessage:
encoder = tiktoken.encoding_for_model(model_name)
def test_truncate_message_all_small(self):
chat_messages = ChatMessageFactory.build_batch(500)
# Arrange
chat_history = ChatMessageFactory.build_batch(500)
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert len(chat_messages) < 500
assert len(chat_messages) > 1
assert len(chat_history) < 500
assert len(chat_history) > 1
assert tokens <= self.max_prompt_size
def test_truncate_message_first_large(self):
chat_messages = ChatMessageFactory.build_batch(25)
# Arrange
chat_history = ChatMessageFactory.build_batch(25)
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
chat_messages.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
chat_history.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert len(chat_messages) == 1
assert prompt[0] != copy_big_chat_message
assert len(chat_history) == 1
assert truncated_chat_history[0] != copy_big_chat_message
assert tokens <= self.max_prompt_size
def test_truncate_message_last_large(self):
chat_messages = ChatMessageFactory.build_batch(25)
# Arrange
chat_history = ChatMessageFactory.build_batch(25)
chat_history[0].role = "system" # Mark the first message as system message
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
chat_messages.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
chat_history.insert(0, big_chat_message)
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties.
assert len(prompt) == (
len(chat_messages) + 1
assert len(truncated_chat_history) == (
len(chat_history) + 1
) # Because the system_prompt is popped off from the chat_messages lsit
assert len(prompt) < 26
assert len(prompt) > 1
assert prompt[0] != copy_big_chat_message
assert tokens <= self.max_prompt_size
assert len(truncated_chat_history) < 26
assert len(truncated_chat_history) > 1
assert truncated_chat_history[0] != copy_big_chat_message
assert initial_tokens > self.max_prompt_size
assert final_tokens <= self.max_prompt_size
def test_truncate_single_large_non_system_message(self):
# Arrange
big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
big_chat_message.role = "user"
copy_big_chat_message = big_chat_message.copy()
chat_messages = [big_chat_message]
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
# Act
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert initial_tokens > self.max_prompt_size
assert final_tokens <= self.max_prompt_size
assert len(chat_messages) == 1
assert truncated_chat_history[0] != copy_big_chat_message