Reduce logical complexity of constructing context from chat history

- Process chat history in default order instead of processing it in
  reverse. Improve legibility of context construction for minor
  performance hit in dropping message from front of list.
- Handle multiple system messages by collating them into list
- Remove logic to drop system role for gemma-2, o1 models. Better to
  make code more readable than support old models.
This commit is contained in:
Debanjum
2025-08-25 00:33:19 -07:00
parent 1e81b51abc
commit 8a16f5a2af
2 changed files with 60 additions and 58 deletions

View File

@@ -29,7 +29,7 @@ class TestTruncateMessage:
# Arrange
chat_history = generate_chat_history(5)
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
chat_history.append(big_chat_message)
chat_history.insert(0, big_chat_message)
# Act
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
@@ -68,7 +68,7 @@ class TestTruncateMessage:
content_list += [{"type": "text", "text": "Question?"}]
big_chat_message = ChatMessage(role="user", content=content_list)
copy_big_chat_message = deepcopy(big_chat_message)
chat_history.insert(0, big_chat_message)
chat_history.append(big_chat_message)
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
@@ -92,7 +92,7 @@ class TestTruncateMessage:
chat_history = generate_chat_history(5)
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
copy_big_chat_message = big_chat_message.model_copy()
chat_history.insert(0, big_chat_message)
chat_history.append(big_chat_message)
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
@@ -104,8 +104,8 @@ class TestTruncateMessage:
assert len(chat_history) == 1, (
"Only most recent message should be present as it itself is larger than context size"
)
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
assert truncated_chat_history[-1] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[-1].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
@@ -116,7 +116,7 @@ class TestTruncateMessage:
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
copy_big_chat_message = big_chat_message.model_copy()
chat_history.insert(0, big_chat_message)
chat_history.append(big_chat_message)
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
# Act
@@ -130,8 +130,8 @@ class TestTruncateMessage:
) # Because the system_prompt is popped off from the chat_messages list
assert len(truncated_chat_history) < 10
assert len(truncated_chat_history) > 1
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
assert truncated_chat_history[-1] != copy_big_chat_message, "Original message should be modified"
assert truncated_chat_history[-1].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"