mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Update truncation logic to handle multi-part message content
This commit is contained in:
@@ -203,7 +203,10 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
|
|||||||
system_prompt = system_prompt or ""
|
system_prompt = system_prompt or ""
|
||||||
for message in messages.copy():
|
for message in messages.copy():
|
||||||
if message.role == "system":
|
if message.role == "system":
|
||||||
system_prompt += message.content
|
if isinstance(message.content, list):
|
||||||
|
system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
|
||||||
|
else:
|
||||||
|
system_prompt += message.content
|
||||||
messages.remove(message)
|
messages.remove(message)
|
||||||
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
||||||
|
|
||||||
|
|||||||
@@ -301,7 +301,10 @@ def format_messages_for_gemini(
|
|||||||
messages = deepcopy(original_messages)
|
messages = deepcopy(original_messages)
|
||||||
for message in messages.copy():
|
for message in messages.copy():
|
||||||
if message.role == "system":
|
if message.role == "system":
|
||||||
system_prompt += message.content
|
if isinstance(message.content, list):
|
||||||
|
system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
|
||||||
|
else:
|
||||||
|
system_prompt += message.content
|
||||||
messages.remove(message)
|
messages.remove(message)
|
||||||
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
||||||
|
|
||||||
|
|||||||
@@ -4,14 +4,12 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import queue
|
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from time import perf_counter
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
@@ -20,8 +18,9 @@ import requests
|
|||||||
import tiktoken
|
import tiktoken
|
||||||
import yaml
|
import yaml
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
from llama_cpp import LlamaTokenizer
|
||||||
from llama_cpp.llama import Llama
|
from llama_cpp.llama import Llama
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters
|
from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import ChatModel, ClientApplication, KhojUser
|
from khoj.database.models import ChatModel, ClientApplication, KhojUser
|
||||||
@@ -382,7 +381,7 @@ def gather_raw_query_files(
|
|||||||
|
|
||||||
def generate_chatml_messages_with_context(
|
def generate_chatml_messages_with_context(
|
||||||
user_message,
|
user_message,
|
||||||
system_message=None,
|
system_message: str = None,
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
model_name="gpt-4o-mini",
|
model_name="gpt-4o-mini",
|
||||||
loaded_model: Optional[Llama] = None,
|
loaded_model: Optional[Llama] = None,
|
||||||
@@ -480,7 +479,7 @@ def generate_chatml_messages_with_context(
|
|||||||
if len(chatml_messages) >= 3 * lookback_turns:
|
if len(chatml_messages) >= 3 * lookback_turns:
|
||||||
break
|
break
|
||||||
|
|
||||||
messages = []
|
messages: list[ChatMessage] = []
|
||||||
|
|
||||||
if not is_none_or_empty(generated_asset_results):
|
if not is_none_or_empty(generated_asset_results):
|
||||||
messages.append(
|
messages.append(
|
||||||
@@ -517,6 +516,11 @@ def generate_chatml_messages_with_context(
|
|||||||
if not is_none_or_empty(system_message):
|
if not is_none_or_empty(system_message):
|
||||||
messages.append(ChatMessage(content=system_message, role="system"))
|
messages.append(ChatMessage(content=system_message, role="system"))
|
||||||
|
|
||||||
|
# Normalize message content to list of chatml dictionaries
|
||||||
|
for message in messages:
|
||||||
|
if isinstance(message.content, str):
|
||||||
|
message.content = [{"type": "text", "text": message.content}]
|
||||||
|
|
||||||
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
||||||
messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
|
messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
|
||||||
|
|
||||||
@@ -524,14 +528,11 @@ def generate_chatml_messages_with_context(
|
|||||||
return messages[::-1]
|
return messages[::-1]
|
||||||
|
|
||||||
|
|
||||||
def truncate_messages(
|
def get_encoder(
|
||||||
messages: list[ChatMessage],
|
|
||||||
max_prompt_size: int,
|
|
||||||
model_name: str,
|
model_name: str,
|
||||||
loaded_model: Optional[Llama] = None,
|
loaded_model: Optional[Llama] = None,
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
) -> list[ChatMessage]:
|
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer:
|
||||||
"""Truncate messages to fit within max prompt size supported by model"""
|
|
||||||
default_tokenizer = "gpt-4o"
|
default_tokenizer = "gpt-4o"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -554,6 +555,48 @@ def truncate_messages(
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing."
|
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing."
|
||||||
)
|
)
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def count_tokens(
|
||||||
|
message_content: str | list[str | dict],
|
||||||
|
encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count the total number of tokens in a list of messages.
|
||||||
|
|
||||||
|
Assumes each images takes 500 tokens for approximation.
|
||||||
|
"""
|
||||||
|
if isinstance(message_content, list):
|
||||||
|
image_count = 0
|
||||||
|
message_content_parts: list[str] = []
|
||||||
|
# Collate message content into single string to ease token counting
|
||||||
|
for part in message_content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
message_content_parts.append(part["text"])
|
||||||
|
elif isinstance(part, dict) and part.get("type") == "image_url":
|
||||||
|
image_count += 1
|
||||||
|
elif isinstance(part, str):
|
||||||
|
message_content_parts.append(part)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unknown message type: {part}. Skipping.")
|
||||||
|
message_content = "\n".join(message_content_parts).rstrip()
|
||||||
|
return len(encoder.encode(message_content)) + image_count * 500
|
||||||
|
elif isinstance(message_content, str):
|
||||||
|
return len(encoder.encode(message_content))
|
||||||
|
else:
|
||||||
|
return len(encoder.encode(json.dumps(message_content)))
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_messages(
|
||||||
|
messages: list[ChatMessage],
|
||||||
|
max_prompt_size: int,
|
||||||
|
model_name: str,
|
||||||
|
loaded_model: Optional[Llama] = None,
|
||||||
|
tokenizer_name=None,
|
||||||
|
) -> list[ChatMessage]:
|
||||||
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
|
encoder = get_encoder(model_name, loaded_model, tokenizer_name)
|
||||||
|
|
||||||
# Extract system message from messages
|
# Extract system message from messages
|
||||||
system_message = None
|
system_message = None
|
||||||
@@ -562,35 +605,55 @@ def truncate_messages(
|
|||||||
system_message = messages.pop(idx)
|
system_message = messages.pop(idx)
|
||||||
break
|
break
|
||||||
|
|
||||||
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
|
|
||||||
system_message_tokens = (
|
|
||||||
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
|
||||||
|
|
||||||
# Drop older messages until under max supported prompt size by model
|
# Drop older messages until under max supported prompt size by model
|
||||||
# Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
|
# Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
|
||||||
while (tokens + system_message_tokens + 4 * len(messages)) > max_prompt_size and len(messages) > 1:
|
system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
|
||||||
messages.pop()
|
tokens = sum([count_tokens(message.content, encoder) for message in messages])
|
||||||
tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
|
total_tokens = tokens + system_message_tokens + 4 * len(messages)
|
||||||
|
|
||||||
|
while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
|
||||||
|
if len(messages[-1].content) > 1:
|
||||||
|
# The oldest content part is earlier in content list. So pop from the front.
|
||||||
|
messages[-1].content.pop(0)
|
||||||
|
else:
|
||||||
|
# The oldest message is the last one. So pop from the back.
|
||||||
|
messages.pop()
|
||||||
|
tokens = sum([count_tokens(message.content, encoder) for message in messages])
|
||||||
|
total_tokens = tokens + system_message_tokens + 4 * len(messages)
|
||||||
|
|
||||||
# Truncate current message if still over max supported prompt size by model
|
# Truncate current message if still over max supported prompt size by model
|
||||||
if (tokens + system_message_tokens) > max_prompt_size:
|
total_tokens = tokens + system_message_tokens + 4 * len(messages)
|
||||||
current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
|
if total_tokens > max_prompt_size:
|
||||||
original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
|
# At this point, a single message with a single content part of type dict should remain
|
||||||
original_question = f"\n{original_question}"
|
assert (
|
||||||
original_question_tokens = len(encoder.encode(original_question))
|
len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict)
|
||||||
|
), "Expected a single message with a single content part remaining at this point in truncation"
|
||||||
|
|
||||||
|
# Collate message content into single string to ease truncation
|
||||||
|
part = messages[0].content[0]
|
||||||
|
message_content: str = part["text"] if part["type"] == "text" else json.dumps(part)
|
||||||
|
message_role = messages[0].role
|
||||||
|
|
||||||
|
remaining_context = "\n".join(message_content.split("\n")[:-1])
|
||||||
|
original_question = "\n" + "\n".join(message_content.split("\n")[-1:])
|
||||||
|
|
||||||
|
original_question_tokens = count_tokens(original_question, encoder)
|
||||||
remaining_tokens = max_prompt_size - system_message_tokens
|
remaining_tokens = max_prompt_size - system_message_tokens
|
||||||
if remaining_tokens > original_question_tokens:
|
if remaining_tokens > original_question_tokens:
|
||||||
remaining_tokens -= original_question_tokens
|
remaining_tokens -= original_question_tokens
|
||||||
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
|
truncated_context = encoder.decode(encoder.encode(remaining_context)[:remaining_tokens]).strip()
|
||||||
messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
|
truncated_content = truncated_context + original_question
|
||||||
else:
|
else:
|
||||||
truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
|
truncated_content = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
|
||||||
messages = [ChatMessage(content=truncated_message, role=messages[0].role)]
|
messages = [ChatMessage(content=[{"type": "text", "text": truncated_content}], role=message_role)]
|
||||||
|
|
||||||
|
truncated_snippet = (
|
||||||
|
f"{truncated_content[:1000]}\n...\n{truncated_content[-1000:]}"
|
||||||
|
if len(truncated_content) > 2000
|
||||||
|
else truncated_content
|
||||||
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message[:1000]}..."
|
f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if system_message:
|
if system_message:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
|
|||||||
|
|
||||||
from apscheduler.schedulers.background import BackgroundScheduler
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
from whisper import Whisper
|
from whisper import Whisper
|
||||||
|
|
||||||
from khoj.database.models import ProcessLock
|
from khoj.database.models import ProcessLock
|
||||||
@@ -40,7 +41,7 @@ khoj_version: str = None
|
|||||||
device = get_device()
|
device = get_device()
|
||||||
chat_on_gpu: bool = True
|
chat_on_gpu: bool = True
|
||||||
anonymous_mode: bool = False
|
anonymous_mode: bool = False
|
||||||
pretrained_tokenizers: Dict[str, Any] = dict()
|
pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
|
||||||
billing_enabled: bool = (
|
billing_enabled: bool = (
|
||||||
os.getenv("STRIPE_API_KEY") is not None
|
os.getenv("STRIPE_API_KEY") is not None
|
||||||
and os.getenv("STRIPE_SIGNING_SECRET") is not None
|
and os.getenv("STRIPE_SIGNING_SECRET") is not None
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
@@ -5,7 +7,7 @@ from khoj.processor.conversation import utils
|
|||||||
|
|
||||||
|
|
||||||
class TestTruncateMessage:
|
class TestTruncateMessage:
|
||||||
max_prompt_size = 10
|
max_prompt_size = 40
|
||||||
model_name = "gpt-4o-mini"
|
model_name = "gpt-4o-mini"
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
|
|
||||||
@@ -15,45 +17,108 @@ class TestTruncateMessage:
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# The original object has been modified. Verify certain properties
|
# The original object has been modified. Verify certain properties
|
||||||
assert len(chat_history) < 50
|
assert len(chat_history) < 50
|
||||||
assert len(chat_history) > 1
|
assert len(chat_history) > 5
|
||||||
assert tokens <= self.max_prompt_size
|
assert tokens <= self.max_prompt_size
|
||||||
|
|
||||||
|
def test_truncate_message_only_oldest_big(self):
|
||||||
|
# Arrange
|
||||||
|
chat_history = generate_chat_history(5)
|
||||||
|
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
|
||||||
|
chat_history.append(big_chat_message)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# The original object has been modified. Verify certain properties
|
||||||
|
assert len(chat_history) == 5
|
||||||
|
assert tokens <= self.max_prompt_size
|
||||||
|
|
||||||
|
def test_truncate_message_with_image(self):
|
||||||
|
# Arrange
|
||||||
|
image_content_item = {"type": "image_url", "image_url": {"url": "placeholder"}}
|
||||||
|
content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
|
||||||
|
content_list += [image_content_item, {"type": "text", "text": "Question?"}]
|
||||||
|
big_chat_message = ChatMessage(role="user", content=content_list)
|
||||||
|
copy_big_chat_message = deepcopy(big_chat_message)
|
||||||
|
chat_history = [big_chat_message]
|
||||||
|
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
|
||||||
|
|
||||||
|
# Act
|
||||||
|
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# The original object has been modified. Verify certain properties
|
||||||
|
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
|
||||||
|
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
|
||||||
|
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
|
||||||
|
|
||||||
|
def test_truncate_message_with_content_list(self):
|
||||||
|
# Arrange
|
||||||
|
chat_history = generate_chat_history(5)
|
||||||
|
content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
|
||||||
|
content_list += [{"type": "text", "text": "Question?"}]
|
||||||
|
big_chat_message = ChatMessage(role="user", content=content_list)
|
||||||
|
copy_big_chat_message = deepcopy(big_chat_message)
|
||||||
|
chat_history.insert(0, big_chat_message)
|
||||||
|
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
|
||||||
|
|
||||||
|
# Act
|
||||||
|
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||||
|
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# The original object has been modified. Verify certain properties
|
||||||
|
assert (
|
||||||
|
len(chat_history) == 1
|
||||||
|
), "Only most recent message should be present as it itself is larger than context size"
|
||||||
|
assert len(truncated_chat_history[0].content) < len(
|
||||||
|
copy_big_chat_message.content
|
||||||
|
), "message content list should be modified"
|
||||||
|
assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
|
||||||
|
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
|
||||||
|
|
||||||
def test_truncate_message_first_large(self):
|
def test_truncate_message_first_large(self):
|
||||||
# Arrange
|
# Arrange
|
||||||
chat_history = generate_chat_history(5)
|
chat_history = generate_chat_history(5)
|
||||||
big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
|
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
|
||||||
copy_big_chat_message = big_chat_message.copy()
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
chat_history.insert(0, big_chat_message)
|
chat_history.insert(0, big_chat_message)
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
|
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||||
tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# The original object has been modified. Verify certain properties
|
# The original object has been modified. Verify certain properties
|
||||||
assert len(chat_history) == 1
|
assert (
|
||||||
assert truncated_chat_history[0] != copy_big_chat_message
|
len(chat_history) == 1
|
||||||
assert tokens <= self.max_prompt_size
|
), "Only most recent message should be present as it itself is larger than context size"
|
||||||
|
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
|
||||||
|
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
|
||||||
|
assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
|
||||||
|
|
||||||
def test_truncate_message_last_large(self):
|
def test_truncate_message_large_system_message_first(self):
|
||||||
# Arrange
|
# Arrange
|
||||||
chat_history = generate_chat_history(5)
|
chat_history = generate_chat_history(5)
|
||||||
chat_history[0].role = "system" # Mark the first message as system message
|
chat_history[0].role = "system" # Mark the first message as system message
|
||||||
big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
|
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
|
||||||
copy_big_chat_message = big_chat_message.copy()
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
|
|
||||||
chat_history.insert(0, big_chat_message)
|
chat_history.insert(0, big_chat_message)
|
||||||
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
|
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
|
||||||
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# The original object has been modified. Verify certain properties.
|
# The original object has been modified. Verify certain properties.
|
||||||
@@ -62,46 +127,52 @@ class TestTruncateMessage:
|
|||||||
) # Because the system_prompt is popped off from the chat_messages list
|
) # Because the system_prompt is popped off from the chat_messages list
|
||||||
assert len(truncated_chat_history) < 10
|
assert len(truncated_chat_history) < 10
|
||||||
assert len(truncated_chat_history) > 1
|
assert len(truncated_chat_history) > 1
|
||||||
assert truncated_chat_history[0] != copy_big_chat_message
|
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
|
||||||
assert initial_tokens > self.max_prompt_size
|
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
|
||||||
assert final_tokens <= self.max_prompt_size
|
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
|
||||||
|
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
|
||||||
|
|
||||||
def test_truncate_single_large_non_system_message(self):
|
def test_truncate_single_large_non_system_message(self):
|
||||||
# Arrange
|
# Arrange
|
||||||
big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
|
big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
|
||||||
copy_big_chat_message = big_chat_message.copy()
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
chat_messages = [big_chat_message]
|
chat_messages = [big_chat_message]
|
||||||
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||||
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# The original object has been modified. Verify certain properties
|
# The original object has been modified. Verify certain properties
|
||||||
assert initial_tokens > self.max_prompt_size
|
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
|
||||||
assert final_tokens <= self.max_prompt_size
|
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
|
||||||
assert len(chat_messages) == 1
|
assert (
|
||||||
assert truncated_chat_history[0] != copy_big_chat_message
|
len(chat_messages) == 1
|
||||||
|
), "Only most recent message should be present as it itself is larger than context size"
|
||||||
|
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
|
||||||
|
assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
|
||||||
|
|
||||||
def test_truncate_single_large_question(self):
|
def test_truncate_single_large_question(self):
|
||||||
# Arrange
|
# Arrange
|
||||||
big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
|
big_chat_message_content = [{"type": "text", "text": " ".join(["hi"] * (self.max_prompt_size + 1))}]
|
||||||
big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
|
big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
|
||||||
copy_big_chat_message = big_chat_message.copy()
|
copy_big_chat_message = big_chat_message.copy()
|
||||||
chat_messages = [big_chat_message]
|
chat_messages = [big_chat_message]
|
||||||
initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
|
initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
|
||||||
final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
|
final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# The original object has been modified. Verify certain properties
|
# The original object has been modified. Verify certain properties
|
||||||
assert initial_tokens > self.max_prompt_size
|
assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
|
||||||
assert final_tokens <= self.max_prompt_size
|
assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
|
||||||
assert len(chat_messages) == 1
|
assert (
|
||||||
assert truncated_chat_history[0] != copy_big_chat_message
|
len(chat_messages) == 1
|
||||||
|
), "Only most recent message should be present as it itself is larger than context size"
|
||||||
|
assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
|
||||||
|
|
||||||
|
|
||||||
def test_load_complex_raw_json_string():
|
def test_load_complex_raw_json_string():
|
||||||
@@ -116,12 +187,12 @@ def test_load_complex_raw_json_string():
|
|||||||
assert parsed_json == expeced_json
|
assert parsed_json == expeced_json
|
||||||
|
|
||||||
|
|
||||||
def generate_content(count):
|
def generate_content(count, suffix=""):
|
||||||
return " ".join([f"{index}" for index, _ in enumerate(range(count))])
|
return [{"type": "text", "text": " ".join([f"{index}" for index, _ in enumerate(range(count))]) + "\n" + suffix}]
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_history(count):
|
def generate_chat_history(count):
|
||||||
return [
|
return [
|
||||||
ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
|
ChatMessage(role="user" if index % 2 == 0 else "assistant", content=[{"type": "text", "text": f"{index}"}])
|
||||||
for index, _ in enumerate(range(count))
|
for index, _ in enumerate(range(count))
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user