diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index b384ad7a..15a4970e 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -199,19 +199,26 @@ def truncate_messages(
             f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
         )
 
-    system_message = messages.pop()
-    assert type(system_message.content) == str
-    system_message_tokens = len(encoder.encode(system_message.content))
+    # Extract the system message, if any, from the chat messages
+    system_message = None
+    for idx, message in enumerate(messages):
+        if message.role == "system":
+            system_message = messages.pop(idx)
+            break
+
+    system_message_tokens = (
+        len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
+    )
 
     tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
+
+    # Drop older messages until under max supported prompt size by model
     while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
-        assert type(system_message.content) == str
         tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
 
     # Truncate current message if still over max supported prompt size by model
     if (tokens + system_message_tokens) > max_prompt_size:
-        assert type(system_message.content) == str
         current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
         original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
         original_question = f"\n{original_question}"
@@ -223,7 +230,7 @@
         )
         messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
 
-    return messages + [system_message]
+    return (messages + [system_message]) if system_message else messages
 
 
 def reciprocal_conversation_to_chatml(message_pair):
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index 52db0002..bc8c5315 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -19,49 +19,80 @@ class TestTruncateMessage:
     encoder = tiktoken.encoding_for_model(model_name)
 
     def test_truncate_message_all_small(self):
-        chat_messages = ChatMessageFactory.build_batch(500)
+        # Arrange
+        chat_history = ChatMessageFactory.build_batch(500)
 
-        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
 
+        # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_messages) < 500
-        assert len(chat_messages) > 1
+        assert len(chat_history) < 500
+        assert len(chat_history) > 1
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_first_large(self):
-        chat_messages = ChatMessageFactory.build_batch(25)
+        # Arrange
+        chat_history = ChatMessageFactory.build_batch(25)
         big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
-        chat_messages.insert(0, big_chat_message)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        chat_history.insert(0, big_chat_message)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
 
-        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
 
+        # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_messages) == 1
-        assert prompt[0] != copy_big_chat_message
+        assert len(chat_history) == 1
+        assert truncated_chat_history[0] != copy_big_chat_message
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_last_large(self):
-        chat_messages = ChatMessageFactory.build_batch(25)
+        # Arrange
+        chat_history = ChatMessageFactory.build_batch(25)
+        chat_history[0].role = "system"  # Mark the first message as the system message
         big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
-        chat_messages.insert(0, big_chat_message)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        chat_history.insert(0, big_chat_message)
+        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
 
-        prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
 
+        # Assert
         # The original object has been modified. Verify certain properties.
-        assert len(prompt) == (
-            len(chat_messages) + 1
-        )  # Because the system_prompt is popped off from the chat_messages lsit
-        assert len(prompt) < 26
-        assert len(prompt) > 1
-        assert prompt[0] != copy_big_chat_message
-        assert tokens <= self.max_prompt_size
+        assert len(truncated_chat_history) == (
+            len(chat_history) + 1
+        )  # Because the system prompt is popped off the chat_history list
+        assert len(truncated_chat_history) < 26
+        assert len(truncated_chat_history) > 1
+        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size
+        assert final_tokens <= self.max_prompt_size
+
+    def test_truncate_single_large_non_system_message(self):
+        # Arrange
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
+        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        big_chat_message.role = "user"
+        copy_big_chat_message = big_chat_message.copy()
+        chat_messages = [big_chat_message]
+        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert initial_tokens > self.max_prompt_size
+        assert final_tokens <= self.max_prompt_size
+        assert len(chat_messages) == 1
+        assert truncated_chat_history[0] != copy_big_chat_message
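
Reviewer note: below is a minimal, self-contained sketch (not part of the patch) of the behavior this change introduces. The ChatMessage dataclass and extract_system_message helper are stand-ins for illustration only; what is meant to match the patch is the control flow: the system message is now located by role anywhere in the history, rather than assumed to sit at the end, and it is re-appended to the returned messages only when one was actually found.

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class ChatMessage:
        # Stand-in for the ChatMessage type the module imports
        role: str
        content: str


    def extract_system_message(messages: List[ChatMessage]) -> Optional[ChatMessage]:
        # Pop the first message with role "system", wherever it sits in the history
        for idx, message in enumerate(messages):
            if message.role == "system":
                return messages.pop(idx)
        return None


    history = [
        ChatMessage(role="user", content="Hello"),
        ChatMessage(role="system", content="You are Khoj."),
        ChatMessage(role="assistant", content="Hi there!"),
    ]
    system_message = extract_system_message(history)

    # Mirrors the patched return expression: the system message is re-appended
    # only when one was found, so histories without one pass through unchanged
    truncated = (history + [system_message]) if system_message else history
    assert [m.role for m in truncated] == ["user", "assistant", "system"]

Because the loop breaks on the first match, only one system message is ever extracted; a history with no system message now flows through truncation without tripping the old asserts, which is the case the new test_truncate_single_large_non_system_message test covers.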