Update truncation logic to handle multi-part message content

This commit is contained in:
Debanjum
2025-05-10 18:41:16 -06:00
parent a337d9e4b8
commit 2694734d22
5 changed files with 207 additions and 66 deletions

View File

@@ -1,3 +1,5 @@
from copy import deepcopy
import tiktoken
from langchain.schema import ChatMessage
@@ -5,7 +7,7 @@ from khoj.processor.conversation import utils
class TestTruncateMessage:
max_prompt_size = 10
max_prompt_size = 40
model_name = "gpt-4o-mini"
encoder = tiktoken.encoding_for_model(model_name)
@@ -15,45 +17,108 @@ class TestTruncateMessage:
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert len(chat_history) < 50
assert len(chat_history) > 1
assert len(chat_history) > 5
assert tokens <= self.max_prompt_size
def test_truncate_message_only_oldest_big(self):
# Arrange
chat_history = generate_chat_history(5)
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
chat_history.append(big_chat_message)
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert len(chat_history) == 5
assert tokens <= self.max_prompt_size
def test_truncate_message_with_image(self):
# Arrange
image_content_item = {"type": "image_url", "image_url": {"url": "placeholder"}}
content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
content_list += [image_content_item, {"type": "text", "text": "Question?"}]
big_chat_message = ChatMessage(role="user", content=content_list)
copy_big_chat_message = deepcopy(big_chat_message)
chat_history = [big_chat_message]
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
def test_truncate_message_with_content_list(self):
# Arrange
chat_history = generate_chat_history(5)
content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
content_list += [{"type": "text", "text": "Question?"}]
big_chat_message = ChatMessage(role="user", content=content_list)
copy_big_chat_message = deepcopy(big_chat_message)
chat_history.insert(0, big_chat_message)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert (
len(chat_history) == 1
), "Only most recent message should be present as it itself is larger than context size"
assert len(truncated_chat_history[0].content) < len(
copy_big_chat_message.content
), "message content list should be modified"
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
def test_truncate_message_first_large(self):
# Arrange
chat_history = generate_chat_history(5)
big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
copy_big_chat_message = big_chat_message.copy()
chat_history.insert(0, big_chat_message)
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert len(chat_history) == 1
assert truncated_chat_history[0] != copy_big_chat_message
assert tokens <= self.max_prompt_size
assert (
len(chat_history) == 1
), "Only most recent message should be present as it itself is larger than context size"
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
def test_truncate_message_last_large(self):
def test_truncate_message_large_system_message_first(self):
# Arrange
chat_history = generate_chat_history(5)
chat_history[0].role = "system" # Mark the first message as system message
big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
copy_big_chat_message = big_chat_message.copy()
chat_history.insert(0, big_chat_message)
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties.
@@ -62,46 +127,52 @@ class TestTruncateMessage:
) # Because the system_prompt is popped off from the chat_messages list
assert len(truncated_chat_history) < 10
assert len(truncated_chat_history) > 1
assert truncated_chat_history[0] != copy_big_chat_message
assert initial_tokens > self.max_prompt_size
assert final_tokens <= self.max_prompt_size
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
def test_truncate_single_large_non_system_message(self):
# Arrange
big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
copy_big_chat_message = big_chat_message.copy()
chat_messages = [big_chat_message]
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])
# Act
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert initial_tokens > self.max_prompt_size
assert final_tokens <= self.max_prompt_size
assert len(chat_messages) == 1
assert truncated_chat_history[0] != copy_big_chat_message
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
assert (
len(chat_messages) == 1
), "Only most recent message should be present as it itself is larger than context size"
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
def test_truncate_single_large_question(self):
# Arrange
big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
big_chat_message_content = [{"type": "text", "text": " ".join(["hi"] * (self.max_prompt_size + 1))}]
big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
copy_big_chat_message = big_chat_message.copy()
chat_messages = [big_chat_message]
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])
# Act
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
# Assert
# The original object has been modified. Verify certain properties
assert initial_tokens > self.max_prompt_size
assert final_tokens <= self.max_prompt_size
assert len(chat_messages) == 1
assert truncated_chat_history[0] != copy_big_chat_message
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
assert (
len(chat_messages) == 1
), "Only most recent message should be present as it itself is larger than context size"
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
def test_load_complex_raw_json_string():
@@ -116,12 +187,12 @@ def test_load_complex_raw_json_string():
assert parsed_json == expeced_json
def generate_content(count):
return " ".join([f"{index}" for index, _ in enumerate(range(count))])
def generate_content(count, suffix=""):
return [{"type": "text", "text": " ".join([f"{index}" for index, _ in enumerate(range(count))]) + "\n" + suffix}]
def generate_chat_history(count):
return [
ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
ChatMessage(role="user" if index % 2 == 0 else "assistant", content=[{"type": "text", "text": f"{index}"}])
for index, _ in enumerate(range(count))
]