Update truncation logic to handle multi-part message content

Debanjum
2025-05-10 18:41:16 -06:00
parent a337d9e4b8
commit 2694734d22
5 changed files with 207 additions and 66 deletions
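In sketch, this change normalizes every ChatMessage.content into a list of chatml-style parts before truncation, so messages can interleave text and image parts. A minimal illustration of the content shape the updated logic handles (the values are made up; the 500-token-per-image estimate comes from the new count_tokens helper in the diff below):

from langchain.schema import ChatMessage

# Multi-part content: a list of chatml-style parts instead of a plain string.
multi_part_message = ChatMessage(
    role="user",
    content=[
        {"type": "text", "text": "Here is the retrieved context..."},
        {"type": "image_url", "image_url": {"url": "https://example.org/chart.png"}},  # approximated as 500 tokens
        {"type": "text", "text": "What does the chart show?"},
    ],
)

# Plain string content still works; generate_chatml_messages_with_context now
# normalizes it to [{"type": "text", "text": ...}] before truncating.
plain_message = ChatMessage(role="user", content="Hello!")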


@@ -203,7 +203,10 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
     system_prompt = system_prompt or ""
     for message in messages.copy():
         if message.role == "system":
-            system_prompt += message.content
+            if isinstance(message.content, list):
+                system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
+            else:
+                system_prompt += message.content
             messages.remove(message)
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt


@@ -301,7 +301,10 @@ def format_messages_for_gemini(
     messages = deepcopy(original_messages)
     for message in messages.copy():
         if message.role == "system":
-            system_prompt += message.content
+            if isinstance(message.content, list):
+                system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
+            else:
+                system_prompt += message.content
             messages.remove(message)
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
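Both the Anthropic and Gemini formatters now collate list-type system message content into the system prompt string, keeping only the text parts. A rough sketch of the collation, using hypothetical values:

system_content = [
    {"type": "text", "text": "You are Khoj, a helpful assistant."},
    {"type": "image_url", "image_url": {"url": "ignored"}},  # non-text parts are skipped
    {"type": "text", "text": "Answer concisely."},
]
system_prompt = "\n".join([part["text"] for part in system_content if part["type"] == "text"])
# system_prompt == "You are Khoj, a helpful assistant.\nAnswer concisely."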


@@ -4,14 +4,12 @@ import logging
 import math
 import mimetypes
 import os
-import queue
 import re
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
 from io import BytesIO
-from time import perf_counter
 from typing import Any, Callable, Dict, List, Optional

 import PIL.Image
@@ -20,8 +18,9 @@ import requests
 import tiktoken
 import yaml
 from langchain.schema import ChatMessage
+from llama_cpp import LlamaTokenizer
 from llama_cpp.llama import Llama
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ChatModel, ClientApplication, KhojUser
@@ -382,7 +381,7 @@ def gather_raw_query_files(
 def generate_chatml_messages_with_context(
     user_message,
-    system_message=None,
+    system_message: str = None,
     conversation_log={},
     model_name="gpt-4o-mini",
     loaded_model: Optional[Llama] = None,
@@ -480,7 +479,7 @@ def generate_chatml_messages_with_context(
         if len(chatml_messages) >= 3 * lookback_turns:
             break

-    messages = []
+    messages: list[ChatMessage] = []
     if not is_none_or_empty(generated_asset_results):
         messages.append(
@@ -517,6 +516,11 @@ def generate_chatml_messages_with_context(
     if not is_none_or_empty(system_message):
         messages.append(ChatMessage(content=system_message, role="system"))

+    # Normalize message content to list of chatml dictionaries
+    for message in messages:
+        if isinstance(message.content, str):
+            message.content = [{"type": "text", "text": message.content}]
+
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
@@ -524,14 +528,11 @@ def generate_chatml_messages_with_context(
     return messages[::-1]


-def truncate_messages(
-    messages: list[ChatMessage],
-    max_prompt_size: int,
+def get_encoder(
     model_name: str,
     loaded_model: Optional[Llama] = None,
     tokenizer_name=None,
-) -> list[ChatMessage]:
-    """Truncate messages to fit within max prompt size supported by model"""
+) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer:
     default_tokenizer = "gpt-4o"

     try:
@@ -554,6 +555,48 @@ def truncate_messages(
         logger.debug(
             f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing."
         )

+    return encoder
+
+
+def count_tokens(
+    message_content: str | list[str | dict],
+    encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding,
+) -> int:
+    """
+    Count the total number of tokens in a message's content.
+    Assumes each image takes 500 tokens as an approximation.
+    """
+    if isinstance(message_content, list):
+        image_count = 0
+        message_content_parts: list[str] = []
+        # Collate message content into single string to ease token counting
+        for part in message_content:
+            if isinstance(part, dict) and part.get("type") == "text":
+                message_content_parts.append(part["text"])
+            elif isinstance(part, dict) and part.get("type") == "image_url":
+                image_count += 1
+            elif isinstance(part, str):
+                message_content_parts.append(part)
+            else:
+                logger.warning(f"Unknown message type: {part}. Skipping.")
+        message_content = "\n".join(message_content_parts).rstrip()
+        return len(encoder.encode(message_content)) + image_count * 500
+    elif isinstance(message_content, str):
+        return len(encoder.encode(message_content))
+    else:
+        return len(encoder.encode(json.dumps(message_content)))
+
+
+def truncate_messages(
+    messages: list[ChatMessage],
+    max_prompt_size: int,
+    model_name: str,
+    loaded_model: Optional[Llama] = None,
+    tokenizer_name=None,
+) -> list[ChatMessage]:
+    """Truncate messages to fit within max prompt size supported by model"""
+    encoder = get_encoder(model_name, loaded_model, tokenizer_name)
+
     # Extract system message from messages
     system_message = None
@@ -562,35 +605,55 @@ def truncate_messages(
             system_message = messages.pop(idx)
             break

-    # TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
-    system_message_tokens = (
-        len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
-    )
-    tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
-
     # Drop older messages until under max supported prompt size by model
     # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
-    while (tokens + system_message_tokens + 4 * len(messages)) > max_prompt_size and len(messages) > 1:
-        messages.pop()
-        tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
+    system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
+    tokens = sum([count_tokens(message.content, encoder) for message in messages])
+    total_tokens = tokens + system_message_tokens + 4 * len(messages)
+    while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
+        if len(messages[-1].content) > 1:
+            # The oldest content part is earlier in content list. So pop from the front.
+            messages[-1].content.pop(0)
+        else:
+            # The oldest message is the last one. So pop from the back.
+            messages.pop()
+        tokens = sum([count_tokens(message.content, encoder) for message in messages])
+        total_tokens = tokens + system_message_tokens + 4 * len(messages)

     # Truncate current message if still over max supported prompt size by model
-    if (tokens + system_message_tokens) > max_prompt_size:
-        current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
-        original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
-        original_question = f"\n{original_question}"
-        original_question_tokens = len(encoder.encode(original_question))
+    total_tokens = tokens + system_message_tokens + 4 * len(messages)
+    if total_tokens > max_prompt_size:
+        # At this point, a single message with a single content part of type dict should remain
+        assert (
+            len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict)
+        ), "Expected a single message with a single content part remaining at this point in truncation"
+
+        # Collate message content into single string to ease truncation
+        part = messages[0].content[0]
+        message_content: str = part["text"] if part["type"] == "text" else json.dumps(part)
+        message_role = messages[0].role
+
+        remaining_context = "\n".join(message_content.split("\n")[:-1])
+        original_question = "\n" + "\n".join(message_content.split("\n")[-1:])
+        original_question_tokens = count_tokens(original_question, encoder)
         remaining_tokens = max_prompt_size - system_message_tokens
         if remaining_tokens > original_question_tokens:
             remaining_tokens -= original_question_tokens
-            truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
-            messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
+            truncated_context = encoder.decode(encoder.encode(remaining_context)[:remaining_tokens]).strip()
+            truncated_content = truncated_context + original_question
         else:
-            truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
-            messages = [ChatMessage(content=truncated_message, role=messages[0].role)]
+            truncated_content = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
+        messages = [ChatMessage(content=[{"type": "text", "text": truncated_content}], role=message_role)]
+        truncated_snippet = (
+            f"{truncated_content[:1000]}\n...\n{truncated_content[-1000:]}"
+            if len(truncated_content) > 2000
+            else truncated_content
+        )
         logger.debug(
-            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message[:1000]}..."
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
         )

     if system_message:
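A hedged usage sketch of the new truncation order (assuming the khoj package is importable; the messages and sizes are made up): whole messages are dropped oldest-first, then the oldest content parts are popped from the front of the one remaining message, and only as a last resort is its final text part hard-truncated while preserving the trailing question line.

import tiktoken
from langchain.schema import ChatMessage

from khoj.processor.conversation import utils

# Messages are ordered newest first; the oldest message sits at the end of the list.
content_parts = [{"type": "text", "text": f"context chunk {index}"} for index in range(100)]
content_parts += [{"type": "text", "text": "Question?"}]
messages = [
    ChatMessage(role="user", content=content_parts),  # current, oversized message
    ChatMessage(role="assistant", content=[{"type": "text", "text": "An earlier answer"}]),
    ChatMessage(role="user", content=[{"type": "text", "text": "An earlier question"}]),
]

truncated = utils.truncate_messages(messages, max_prompt_size=40, model_name="gpt-4o-mini")

encoder = tiktoken.encoding_for_model("gpt-4o-mini")
total = sum(utils.count_tokens(message.content, encoder) for message in truncated)
assert total <= 40, "Older messages and oldest content parts are dropped first"
assert truncated[0].content[-1]["text"] == "Question?", "The trailing question is preserved"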


@@ -6,6 +6,7 @@ from typing import Any, Dict, List
 from apscheduler.schedulers.background import BackgroundScheduler
 from openai import OpenAI
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from whisper import Whisper

 from khoj.database.models import ProcessLock
@@ -40,7 +41,7 @@ khoj_version: str = None
 device = get_device()
 chat_on_gpu: bool = True
 anonymous_mode: bool = False
-pretrained_tokenizers: Dict[str, Any] = dict()
+pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
 billing_enabled: bool = (
     os.getenv("STRIPE_API_KEY") is not None
     and os.getenv("STRIPE_SIGNING_SECRET") is not None


@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 import tiktoken
 from langchain.schema import ChatMessage
@@ -5,7 +7,7 @@ from khoj.processor.conversation import utils
 class TestTruncateMessage:
-    max_prompt_size = 10
+    max_prompt_size = 40
     model_name = "gpt-4o-mini"
     encoder = tiktoken.encoding_for_model(model_name)
@@ -15,45 +17,108 @@ class TestTruncateMessage:
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])

         # Assert
         # The original object has been modified. Verify certain properties
         assert len(chat_history) < 50
-        assert len(chat_history) > 1
+        assert len(chat_history) > 5
         assert tokens <= self.max_prompt_size

+    def test_truncate_message_only_oldest_big(self):
+        # Arrange
+        chat_history = generate_chat_history(5)
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
+        chat_history.append(big_chat_message)
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert len(chat_history) == 5
+        assert tokens <= self.max_prompt_size
+
+    def test_truncate_message_with_image(self):
+        # Arrange
+        image_content_item = {"type": "image_url", "image_url": {"url": "placeholder"}}
+        content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
+        content_list += [image_content_item, {"type": "text", "text": "Question?"}]
+        big_chat_message = ChatMessage(role="user", content=content_list)
+        copy_big_chat_message = deepcopy(big_chat_message)
+        chat_history = [big_chat_message]
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
+
+    def test_truncate_message_with_content_list(self):
+        # Arrange
+        chat_history = generate_chat_history(5)
+        content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
+        content_list += [{"type": "text", "text": "Question?"}]
+        big_chat_message = ChatMessage(role="user", content=content_list)
+        copy_big_chat_message = deepcopy(big_chat_message)
+        chat_history.insert(0, big_chat_message)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert (
+            len(chat_history) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert len(truncated_chat_history[0].content) < len(
+            copy_big_chat_message.content
+        ), "message content list should be modified"
+        assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
+
     def test_truncate_message_first_large(self):
         # Arrange
         chat_history = generate_chat_history(5)
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])

         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])

         # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_history) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
-        assert tokens <= self.max_prompt_size
+        assert (
+            len(chat_history) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"

-    def test_truncate_message_last_large(self):
+    def test_truncate_message_large_system_message_first(self):
         # Arrange
         chat_history = generate_chat_history(5)
         chat_history[0].role = "system"  # Mark the first message as system message
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])

         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])

         # Assert
         # The original object has been modified. Verify certain properties.
@@ -62,46 +127,52 @@ class TestTruncateMessage:
         )  # Because the system_prompt is popped off from the chat_messages list
         assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
-        assert truncated_chat_history[0] != copy_big_chat_message
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"

     def test_truncate_single_large_non_system_message(self):
         # Arrange
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])

         # Act
         truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])

         # Assert
         # The original object has been modified. Verify certain properties
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
-        assert len(chat_messages) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
+        assert (
+            len(chat_messages) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"

     def test_truncate_single_large_question(self):
         # Arrange
-        big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
+        big_chat_message_content = [{"type": "text", "text": " ".join(["hi"] * (self.max_prompt_size + 1))}]
         big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])

         # Act
         truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])

         # Assert
         # The original object has been modified. Verify certain properties
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
-        assert len(chat_messages) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
+        assert (
+            len(chat_messages) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"


 def test_load_complex_raw_json_string():
@@ -116,12 +187,12 @@ def test_load_complex_raw_json_string():
     assert parsed_json == expeced_json


-def generate_content(count):
-    return " ".join([f"{index}" for index, _ in enumerate(range(count))])
+def generate_content(count, suffix=""):
+    return [{"type": "text", "text": " ".join([f"{index}" for index, _ in enumerate(range(count))]) + "\n" + suffix}]


 def generate_chat_history(count):
     return [
-        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
+        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=[{"type": "text", "text": f"{index}"}])
         for index, _ in enumerate(range(count))
     ]