mirror of https://github.com/khoaliber/khoj.git
Update truncation logic to handle multi-part message content
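
The tests in this change swap direct len(encoder.encode(...)) calls for a utils.count_tokens helper, since message.content may now be a list of typed content parts (text and image_url dicts) rather than a plain string. As a rough sketch of what such a helper has to handle, assuming OpenAI-style content parts (illustrative only, not the repository's implementation; how non-text parts are costed is an assumption here):

import tiktoken

def count_tokens(content, encoder: tiktoken.Encoding) -> int:
    # Plain-string content: encode and count directly.
    if isinstance(content, str):
        return len(encoder.encode(content))
    # Multi-part content: sum the tokens of each text part.
    total = 0
    for part in content:
        if part.get("type") == "text":
            total += len(encoder.encode(part["text"]))
        # Assumption: non-text parts (e.g. image_url) are skipped here;
        # a real counter might instead charge a fixed budget per image.
    return total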
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 import tiktoken
 from langchain.schema import ChatMessage
 
@@ -5,7 +7,7 @@ from khoj.processor.conversation import utils
 
 
 class TestTruncateMessage:
-    max_prompt_size = 10
+    max_prompt_size = 40
     model_name = "gpt-4o-mini"
    encoder = tiktoken.encoding_for_model(model_name)
 
@@ -15,45 +17,108 @@ class TestTruncateMessage:
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
 
         # Assert
         # The original object has been modified. Verify certain properties
         assert len(chat_history) < 50
-        assert len(chat_history) > 1
+        assert len(chat_history) > 5
         assert tokens <= self.max_prompt_size
 
+    def test_truncate_message_only_oldest_big(self):
+        # Arrange
+        chat_history = generate_chat_history(5)
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
+        chat_history.append(big_chat_message)
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert len(chat_history) == 5
+        assert tokens <= self.max_prompt_size
+
+    def test_truncate_message_with_image(self):
+        # Arrange
+        image_content_item = {"type": "image_url", "image_url": {"url": "placeholder"}}
+        content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
+        content_list += [image_content_item, {"type": "text", "text": "Question?"}]
+        big_chat_message = ChatMessage(role="user", content=content_list)
+        copy_big_chat_message = deepcopy(big_chat_message)
+        chat_history = [big_chat_message]
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
+
+    def test_truncate_message_with_content_list(self):
+        # Arrange
+        chat_history = generate_chat_history(5)
+        content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
+        content_list += [{"type": "text", "text": "Question?"}]
+        big_chat_message = ChatMessage(role="user", content=content_list)
+        copy_big_chat_message = deepcopy(big_chat_message)
+        chat_history.insert(0, big_chat_message)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert (
+            len(chat_history) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert len(truncated_chat_history[0].content) < len(
+            copy_big_chat_message.content
+        ), "message content list should be modified"
+        assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
+
     def test_truncate_message_first_large(self):
         # Arrange
         chat_history = generate_chat_history(5)
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
 
         # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_history) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
-        assert tokens <= self.max_prompt_size
+        assert (
+            len(chat_history) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
 
-    def test_truncate_message_last_large(self):
+    def test_truncate_message_large_system_message_first(self):
         # Arrange
         chat_history = generate_chat_history(5)
         chat_history[0].role = "system"  # Mark the first message as system message
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
 
         chat_history.insert(0, big_chat_message)
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
 
         # Assert
         # The original object has been modified. Verify certain properties.
@@ -62,46 +127,52 @@ class TestTruncateMessage:
         ) # Because the system_prompt is popped off from the chat_messages list
         assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
-        assert truncated_chat_history[0] != copy_big_chat_message
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
 
     def test_truncate_single_large_non_system_message(self):
         # Arrange
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
 
         # Assert
         # The original object has been modified. Verify certain properties
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
-        assert len(chat_messages) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
+        assert (
+            len(chat_messages) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
 
     def test_truncate_single_large_question(self):
         # Arrange
-        big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
+        big_chat_message_content = [{"type": "text", "text": " ".join(["hi"] * (self.max_prompt_size + 1))}]
         big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
 
         # Assert
         # The original object has been modified. Verify certain properties
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
-        assert len(chat_messages) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
+        assert (
+            len(chat_messages) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
 
 
 def test_load_complex_raw_json_string():
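
The assertions above pin down two behaviours of utils.truncate_messages: the system message is popped off the list before trimming and re-attached afterwards, and whole messages are dropped from the old end of the history until the remainder fits. A minimal sketch of that outer loop, assuming the count_tokens sketch from earlier and a most-recent-first message order (which the "only most recent message should be present" assertions imply); this illustrates the behaviour under test, not the actual implementation:

def truncate_messages_sketch(messages, max_prompt_size, encoder):
    # Set the system message aside; it is re-attached after trimming.
    system_message = next((m for m in messages if m.role == "system"), None)
    if system_message is not None:
        messages.remove(system_message)

    total = sum(count_tokens(m.content, encoder) for m in messages)
    # Drop whole messages, oldest first (tail of the list), while over budget.
    while total > max_prompt_size and len(messages) > 1:
        dropped = messages.pop()
        total -= count_tokens(dropped.content, encoder)

    # If the one remaining message is still too large, its content is trimmed
    # part by part while preserving the trailing question (sketched below).
    return ([system_message] if system_message else []) + messages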
@@ -116,12 +187,12 @@ def test_load_complex_raw_json_string():
     assert parsed_json == expeced_json
 
 
-def generate_content(count):
-    return " ".join([f"{index}" for index, _ in enumerate(range(count))])
+def generate_content(count, suffix=""):
+    return [{"type": "text", "text": " ".join([f"{index}" for index, _ in enumerate(range(count))]) + "\n" + suffix}]
 
 
 def generate_chat_history(count):
     return [
-        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
+        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=[{"type": "text", "text": f"{index}"}])
        for index, _ in enumerate(range(count))
     ]
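
Finally, the "Query should be preserved" assertions imply that when a single message must be cut down, truncation works from the front and keeps the tail, so the user's question (the last text part, or the "\nQuestion?" suffix inside one big part) survives. A sketch of that per-message step under the same assumptions as above, again illustrative rather than the repository's code:

def truncate_content_parts(content, max_tokens, encoder):
    # Keep content parts from the tail of the list, where the question lives.
    kept, used = [], 0
    for part in reversed(content):
        if part.get("type") != "text":
            continue  # assumption: non-text parts are dropped first when over budget
        part_tokens = encoder.encode(part["text"])
        if used + len(part_tokens) > max_tokens:
            remaining = max_tokens - used
            if remaining > 0:
                # Token-slice the oversized part, keeping its tail so text
                # ending in "\nQuestion?" still ends in "\nQuestion?".
                kept.insert(0, {"type": "text", "text": encoder.decode(part_tokens[-remaining:])})
            break
        kept.insert(0, part)
        used += len(part_tokens)
    return kept

With max_prompt_size = 40, the 100-part content lists in these tests collapse to the last few parts, which is exactly what the content[-1]["text"] == "Question?" assertions check.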