Handle truncation of a single long non-system chat message

Previously the system prompt was assumed to always be passed as the
first message, so the message logs were expected to contain at least
two messages.

This broke chat actors querying with a single long non-system message.

A more robust way to extract the system prompt is via the message role
instead.
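
For context, a minimal sketch of the failure mode and the role-based fix.
This is illustrative code, not part of the commit; it assumes langchain-style
ChatMessage objects with .role and .content, as used in the tests below.

    from langchain.schema import ChatMessage

    messages = [ChatMessage(role="user", content="Long retrieved context...\nQuestion?")]

    # Old approach: pop the last message and treat it as the system prompt.
    # With a single user message this empties the list, so the later
    # messages[0] truncation step fails.
    # system_message = messages.pop()

    # New approach: look the system prompt up by its role; None when absent.
    system_message = next((m for m in messages if m.role == "system"), None)
    assert system_message is None and len(messages) == 1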
Debanjum Singh Solanky
2024-03-15 14:52:29 +05:30
parent ec0c35b7ed
commit ecddf98430
2 changed files with 67 additions and 29 deletions

@@ -199,19 +199,26 @@ def truncate_messages(
             f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
         )

-    system_message = messages.pop()
-    assert type(system_message.content) == str
-    system_message_tokens = len(encoder.encode(system_message.content))
+    # Extract system message from messages
+    system_message = None
+    for idx, message in enumerate(messages):
+        if message.role == "system":
+            system_message = messages.pop(idx)
+            break
+
+    system_message_tokens = (
+        len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
+    )

     tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])

     # Drop older messages until under max supported prompt size by model
     while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
-        assert type(system_message.content) == str
         tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])

     # Truncate current message if still over max supported prompt size by model
     if (tokens + system_message_tokens) > max_prompt_size:
-        assert type(system_message.content) == str
         current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
         original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
         original_question = f"\n{original_question}"
@@ -223,7 +230,7 @@ def truncate_messages(
         )
         messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]

-    return messages + [system_message]
+    return messages + [system_message] if system_message else messages


 def reciprocal_conversation_to_chatml(message_pair):
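
A hedged usage sketch of the fixed behavior (the module path and model name
are assumptions, not shown in this diff): a conversation consisting of one
long user message and no system prompt is now truncated in place, keeping the
trailing question, instead of raising.

    from langchain.schema import ChatMessage

    from khoj.processor.conversation import utils  # assumed location of truncate_messages

    # One long user message, no system prompt anywhere in the conversation.
    long_message = ChatMessage(role="user", content=("lorem ipsum " * 20000) + "\nQuestion?")

    truncated = utils.truncate_messages([long_message], max_prompt_size=4096, model_name="gpt-3.5-turbo")
    assert len(truncated) == 1                         # no system message appended
    assert truncated[0].role == "user"                 # role preserved
    assert truncated[0].content.endswith("Question?")  # trailing question kept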

@@ -19,49 +19,80 @@ class TestTruncateMessage:
     encoder = tiktoken.encoding_for_model(model_name)

     def test_truncate_message_all_small(self):
-        chat_messages = ChatMessageFactory.build_batch(500)
+        # Arrange
+        chat_history = ChatMessageFactory.build_batch(500)

-        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])

+        # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_messages) < 500
-        assert len(chat_messages) > 1
+        assert len(chat_history) < 500
+        assert len(chat_history) > 1
         assert tokens <= self.max_prompt_size

     def test_truncate_message_first_large(self):
-        chat_messages = ChatMessageFactory.build_batch(25)
+        # Arrange
+        chat_history = ChatMessageFactory.build_batch(25)
         big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
-        chat_messages.insert(0, big_chat_message)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        chat_history.insert(0, big_chat_message)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])

-        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])

+        # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_messages) == 1
-        assert prompt[0] != copy_big_chat_message
+        assert len(chat_history) == 1
+        assert truncated_chat_history[0] != copy_big_chat_message
         assert tokens <= self.max_prompt_size

     def test_truncate_message_last_large(self):
-        chat_messages = ChatMessageFactory.build_batch(25)
+        # Arrange
+        chat_history = ChatMessageFactory.build_batch(25)
+        chat_history[0].role = "system"  # Mark the first message as system message
         big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
-        chat_messages.insert(0, big_chat_message)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        chat_history.insert(0, big_chat_message)
+        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])

-        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])

+        # Assert
         # The original object has been modified. Verify certain properties.
-        assert len(prompt) == (
-            len(chat_messages) + 1
-        )  # Because the system_prompt is popped off from the chat_messages lsit
-        assert len(prompt) < 26
-        assert len(prompt) > 1
-        assert prompt[0] != copy_big_chat_message
-        assert tokens <= self.max_prompt_size
+        assert len(truncated_chat_history) == (
+            len(chat_history) + 1
+        )  # Because the system_prompt is popped off from the chat_history list
+        assert len(truncated_chat_history) < 26
+        assert len(truncated_chat_history) > 1
+        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size
+        assert final_tokens <= self.max_prompt_size

+    def test_truncate_single_large_non_system_message(self):
+        # Arrange
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
+        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        big_chat_message.role = "user"
+        copy_big_chat_message = big_chat_message.copy()
+        chat_messages = [big_chat_message]
+        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert initial_tokens > self.max_prompt_size
+        assert final_tokens <= self.max_prompt_size
+        assert len(chat_messages) == 1
+        assert truncated_chat_history[0] != copy_big_chat_message