Handle truncation of a single long non-system chat message

Previously, the system prompt was assumed to always be passed as the first message, so the message log was expected to contain at least 2 messages. This broke chat actors that query with a single long non-system message. Extracting the system prompt via the message role, rather than its position, is a more robust approach.
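The core of the fix, as a minimal standalone sketch (the `ChatMessage` dataclass here is only a stand-in for the chat message type used in the diff below; the actual patch follows):

from dataclasses import dataclass


@dataclass
class ChatMessage:  # stand-in for the ChatMessage type in the diff below
    content: str
    role: str


def extract_system_message(messages: list[ChatMessage]) -> ChatMessage | None:
    """Pop the system message by its role, wherever it sits in the list."""
    for idx, message in enumerate(messages):
        if message.role == "system":
            return messages.pop(idx)
    return None  # e.g. a chat actor querying with a single user message


# Previously the equivalent of messages.pop() assumed the system prompt's
# position; role-based extraction handles a lone non-system message cleanly.
history = [ChatMessage(content="One very long question...", role="user")]
assert extract_system_message(history) is None
assert len(history) == 1  # nothing was popped off the single-message log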
@@ -199,19 +199,26 @@ def truncate_messages(
             f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
         )
 
-    system_message = messages.pop()
-    assert type(system_message.content) == str
-    system_message_tokens = len(encoder.encode(system_message.content))
+    # Extract system message from messages
+    system_message = None
+    for idx, message in enumerate(messages):
+        if message.role == "system":
+            system_message = messages.pop(idx)
+            break
+
+    system_message_tokens = (
+        len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
+    )
 
     tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
 
+    # Drop older messages until under max supported prompt size by model
     while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
-        assert type(system_message.content) == str
         tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
 
     # Truncate current message if still over max supported prompt size by model
     if (tokens + system_message_tokens) > max_prompt_size:
-        assert type(system_message.content) == str
         current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
         original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
         original_question = f"\n{original_question}"
@@ -223,7 +230,7 @@ def truncate_messages(
         )
         messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
 
-    return messages + [system_message]
+    return messages + [system_message] if system_message else messages
 
 
 def reciprocal_conversation_to_chatml(message_pair):

@@ -19,49 +19,80 @@ class TestTruncateMessage:
     encoder = tiktoken.encoding_for_model(model_name)
 
     def test_truncate_message_all_small(self):
-        chat_messages = ChatMessageFactory.build_batch(500)
+        # Arrange
+        chat_history = ChatMessageFactory.build_batch(500)
 
-        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
 
+        # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_messages) < 500
-        assert len(chat_messages) > 1
+        assert len(chat_history) < 500
+        assert len(chat_history) > 1
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_first_large(self):
-        chat_messages = ChatMessageFactory.build_batch(25)
+        # Arrange
+        chat_history = ChatMessageFactory.build_batch(25)
         big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
-        chat_messages.insert(0, big_chat_message)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        chat_history.insert(0, big_chat_message)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
 
-        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
 
+        # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_messages) == 1
-        assert prompt[0] != copy_big_chat_message
+        assert len(chat_history) == 1
+        assert truncated_chat_history[0] != copy_big_chat_message
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_last_large(self):
-        chat_messages = ChatMessageFactory.build_batch(25)
+        # Arrange
+        chat_history = ChatMessageFactory.build_batch(25)
+        chat_history[0].role = "system"  # Mark the first message as system message
         big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
 
-        chat_messages.insert(0, big_chat_message)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        chat_history.insert(0, big_chat_message)
+        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
 
-        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
 
+        # Assert
         # The original object has been modified. Verify certain properties.
-        assert len(prompt) == (
-            len(chat_messages) + 1
+        assert len(truncated_chat_history) == (
+            len(chat_history) + 1
         )  # Because the system_prompt is popped off from the chat_messages list
-        assert len(prompt) < 26
-        assert len(prompt) > 1
-        assert prompt[0] != copy_big_chat_message
-        assert tokens <= self.max_prompt_size
+        assert len(truncated_chat_history) < 26
+        assert len(truncated_chat_history) > 1
+        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size
+        assert final_tokens <= self.max_prompt_size
+
+    def test_truncate_single_large_non_system_message(self):
+        # Arrange
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
+        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        big_chat_message.role = "user"
+        copy_big_chat_message = big_chat_message.copy()
+        chat_messages = [big_chat_message]
+        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert initial_tokens > self.max_prompt_size
+        assert final_tokens <= self.max_prompt_size
+        assert len(chat_messages) == 1
+        assert truncated_chat_history[0] != copy_big_chat_message
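For reference, a hypothetical driver for the patched function, in the spirit of the new test above. The module path `khoj.processor.conversation.utils`, the langchain `ChatMessage` import, and the model/size values are assumptions for illustration, not taken from this diff:

import tiktoken
from langchain.schema import ChatMessage

from khoj.processor.conversation import utils  # assumed module path

model_name = "gpt-3.5-turbo"  # illustrative; the test class sets its own
max_prompt_size = 4096        # illustrative max context size in tokens
encoder = tiktoken.encoding_for_model(model_name)

# A single long user message with no system prompt: before this commit,
# truncate_messages() popped the lone message off as the "system prompt"
# and then indexed into an empty message list.
long_question = " ".join(["word"] * 20000) + "\n" + "Question?"
messages = [ChatMessage(content=long_question, role="user")]

truncated = utils.truncate_messages(messages, max_prompt_size, model_name)

tokens = sum(len(encoder.encode(message.content)) for message in truncated)
assert tokens <= max_prompt_size  # message now truncated to fit the model
assert truncated[0].content.endswith("Question?")  # original question kept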