From 2694734d22c813a8cfd0ba5656e00c83e3a8f10b Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Sat, 10 May 2025 18:41:16 -0600
Subject: [PATCH] Update truncation logic to handle multi-part message content

---
 .../processor/conversation/anthropic/utils.py |   5 +-
 .../processor/conversation/google/utils.py    |   5 +-
 src/khoj/processor/conversation/utils.py      | 123 ++++++++++++----
 src/khoj/utils/state.py                       |   3 +-
 tests/test_conversation_utils.py              | 137 +++++++++++++-----
 5 files changed, 207 insertions(+), 66 deletions(-)

diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py
index baf8fade..e436ecda 100644
--- a/src/khoj/processor/conversation/anthropic/utils.py
+++ b/src/khoj/processor/conversation/anthropic/utils.py
@@ -203,7 +203,10 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: st
     system_prompt = system_prompt or ""
     for message in messages.copy():
         if message.role == "system":
-            system_prompt += message.content
+            if isinstance(message.content, list):
+                system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
+            else:
+                system_prompt += message.content
             messages.remove(message)
 
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 9f2be46c..ed37a0b3 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -301,7 +301,10 @@ def format_messages_for_gemini(
     messages = deepcopy(original_messages)
     for message in messages.copy():
         if message.role == "system":
-            system_prompt += message.content
+            if isinstance(message.content, list):
+                system_prompt += "\n".join([part["text"] for part in message.content if part["type"] == "text"])
+            else:
+                system_prompt += message.content
             messages.remove(message)
 
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 6e4b62ab..6aaa48e9 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -4,14 +4,12 @@ import logging
 import math
 import mimetypes
 import os
-import queue
 import re
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
 from io import BytesIO
-from time import perf_counter
 from typing import Any, Callable, Dict, List, Optional
 
 import PIL.Image
@@ -20,8 +18,9 @@ import requests
 import tiktoken
 import yaml
 from langchain.schema import ChatMessage
+from llama_cpp import LlamaTokenizer
 from llama_cpp.llama import Llama
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ChatModel, ClientApplication, KhojUser
@@ -382,7 +381,7 @@ def gather_raw_query_files(
 
 def generate_chatml_messages_with_context(
     user_message,
-    system_message=None,
+    system_message: str = None,
     conversation_log={},
     model_name="gpt-4o-mini",
     loaded_model: Optional[Llama] = None,
@@ -480,7 +479,7 @@ def generate_chatml_messages_with_context(
         if len(chatml_messages) >= 3 * lookback_turns:
             break
 
-    messages = []
+    messages: list[ChatMessage] = []
 
     if not is_none_or_empty(generated_asset_results):
         messages.append(
@@ -517,6 +516,11 @@ def generate_chatml_messages_with_context(
     if not is_none_or_empty(system_message):
        messages.append(ChatMessage(content=system_message, role="system"))
 
+    # Normalize message content to list of chatml dictionaries
+    for message in messages:
+        if isinstance(message.content, str):
+            message.content = [{"type": "text", "text": message.content}]
+
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
 
@@ -524,14 +528,11 @@ def generate_chatml_messages_with_context(
     return messages[::-1]
 
 
-def truncate_messages(
-    messages: list[ChatMessage],
-    max_prompt_size: int,
+def get_encoder(
     model_name: str,
     loaded_model: Optional[Llama] = None,
     tokenizer_name=None,
-) -> list[ChatMessage]:
-    """Truncate messages to fit within max prompt size supported by model"""
+) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer:
     default_tokenizer = "gpt-4o"
 
     try:
@@ -554,6 +555,48 @@ def truncate_messages(
         logger.debug(
             f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for model: {model_name} in Khoj settings to improve context stuffing."
         )
+    return encoder
+
+
+def count_tokens(
+    message_content: str | list[str | dict],
+    encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding,
+) -> int:
+    """
+    Count the total number of tokens in the message content.
+
+    Assumes each image takes 500 tokens as an approximation.
+    """
+    if isinstance(message_content, list):
+        image_count = 0
+        message_content_parts: list[str] = []
+        # Collate message content into single string to ease token counting
+        for part in message_content:
+            if isinstance(part, dict) and part.get("type") == "text":
+                message_content_parts.append(part["text"])
+            elif isinstance(part, dict) and part.get("type") == "image_url":
+                image_count += 1
+            elif isinstance(part, str):
+                message_content_parts.append(part)
+            else:
+                logger.warning(f"Unknown message type: {part}. Skipping.")
+        message_content = "\n".join(message_content_parts).rstrip()
+        return len(encoder.encode(message_content)) + image_count * 500
+    elif isinstance(message_content, str):
+        return len(encoder.encode(message_content))
+    else:
+        return len(encoder.encode(json.dumps(message_content)))
+
+
+def truncate_messages(
+    messages: list[ChatMessage],
+    max_prompt_size: int,
+    model_name: str,
+    loaded_model: Optional[Llama] = None,
+    tokenizer_name=None,
+) -> list[ChatMessage]:
+    """Truncate messages to fit within max prompt size supported by model"""
+    encoder = get_encoder(model_name, loaded_model, tokenizer_name)
 
     # Extract system message from messages
     system_message = None
@@ -562,35 +605,55 @@ def truncate_messages(
             system_message = messages.pop(idx)
             break
 
-    # TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
-    system_message_tokens = (
-        len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
-    )
-
-    tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
-
     # Drop older messages until under max supported prompt size by model
     # Reserves 4 tokens to demarcate each message (e.g <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
-    while (tokens + system_message_tokens + 4 * len(messages)) > max_prompt_size and len(messages) > 1:
-        messages.pop()
-        tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
+    system_message_tokens = count_tokens(system_message.content, encoder) if system_message else 0
+    tokens = sum([count_tokens(message.content, encoder) for message in messages])
+    total_tokens = tokens + system_message_tokens + 4 * len(messages)
+
+    while total_tokens > max_prompt_size and (len(messages) > 1 or len(messages[0].content) > 1):
+        if len(messages[-1].content) > 1:
+            # The oldest content part is earlier in content list. So pop from the front.
+            messages[-1].content.pop(0)
+        else:
+            # The oldest message is the last one. So pop from the back.
+            messages.pop()
+        tokens = sum([count_tokens(message.content, encoder) for message in messages])
+        total_tokens = tokens + system_message_tokens + 4 * len(messages)
 
     # Truncate current message if still over max supported prompt size by model
-    if (tokens + system_message_tokens) > max_prompt_size:
-        current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
-        original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
-        original_question = f"\n{original_question}"
-        original_question_tokens = len(encoder.encode(original_question))
+    total_tokens = tokens + system_message_tokens + 4 * len(messages)
+    if total_tokens > max_prompt_size:
+        # At this point, a single message with a single content part of type dict should remain
+        assert (
+            len(messages) == 1 and len(messages[0].content) == 1 and isinstance(messages[0].content[0], dict)
+        ), "Expected a single message with a single content part remaining at this point in truncation"
+
+        # Collate message content into single string to ease truncation
+        part = messages[0].content[0]
+        message_content: str = part["text"] if part["type"] == "text" else json.dumps(part)
+        message_role = messages[0].role
+
+        remaining_context = "\n".join(message_content.split("\n")[:-1])
+        original_question = "\n" + "\n".join(message_content.split("\n")[-1:])
+
+        original_question_tokens = count_tokens(original_question, encoder)
         remaining_tokens = max_prompt_size - system_message_tokens
         if remaining_tokens > original_question_tokens:
             remaining_tokens -= original_question_tokens
-            truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
-            messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
+            truncated_context = encoder.decode(encoder.encode(remaining_context)[:remaining_tokens]).strip()
+            truncated_content = truncated_context + original_question
         else:
-            truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
-            messages = [ChatMessage(content=truncated_message, role=messages[0].role)]
+            truncated_content = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
+        messages = [ChatMessage(content=[{"type": "text", "text": truncated_content}], role=message_role)]
+
+        truncated_snippet = (
+            f"{truncated_content[:1000]}\n...\n{truncated_content[-1000:]}"
+            if len(truncated_content) > 2000
+            else truncated_content
+        )
         logger.debug(
-            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message[:1000]}..."
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_snippet}"
        )
 
     if system_message:
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index 1673dbe3..f96409c2 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
 
 from apscheduler.schedulers.background import BackgroundScheduler
 from openai import OpenAI
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from whisper import Whisper
 
 from khoj.database.models import ProcessLock
@@ -40,7 +41,7 @@ khoj_version: str = None
 device = get_device()
 chat_on_gpu: bool = True
 anonymous_mode: bool = False
-pretrained_tokenizers: Dict[str, Any] = dict()
+pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
 billing_enabled: bool = (
     os.getenv("STRIPE_API_KEY") is not None
     and os.getenv("STRIPE_SIGNING_SECRET") is not None
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index 54fe2a7f..43f805b2 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 import tiktoken
 from langchain.schema import ChatMessage
 
@@ -5,7 +7,7 @@ from khoj.processor.conversation import utils
 
 
 class TestTruncateMessage:
-    max_prompt_size = 10
+    max_prompt_size = 40
     model_name = "gpt-4o-mini"
     encoder = tiktoken.encoding_for_model(model_name)
 
@@ -15,45 +17,108 @@ class TestTruncateMessage:
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
 
         # Assert
         # The original object has been modified. Verify certain properties
         assert len(chat_history) < 50
-        assert len(chat_history) > 1
+        assert len(chat_history) > 5
         assert tokens <= self.max_prompt_size
 
+    def test_truncate_message_only_oldest_big(self):
+        # Arrange
+        chat_history = generate_chat_history(5)
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
+        chat_history.append(big_chat_message)
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert len(chat_history) == 5
+        assert tokens <= self.max_prompt_size
+
+    def test_truncate_message_with_image(self):
+        # Arrange
+        image_content_item = {"type": "image_url", "image_url": {"url": "placeholder"}}
+        content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
+        content_list += [image_content_item, {"type": "text", "text": "Question?"}]
+        big_chat_message = ChatMessage(role="user", content=content_list)
+        copy_big_chat_message = deepcopy(big_chat_message)
+        chat_history = [big_chat_message]
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
+
+    def test_truncate_message_with_content_list(self):
+        # Arrange
+        chat_history = generate_chat_history(5)
+        content_list = [{"type": "text", "text": f"{index}"} for index in range(100)]
+        content_list += [{"type": "text", "text": "Question?"}]
+        big_chat_message = ChatMessage(role="user", content=content_list)
+        copy_big_chat_message = deepcopy(big_chat_message)
+        chat_history.insert(0, big_chat_message)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert (
+            len(chat_history) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert len(truncated_chat_history[0].content) < len(
+            copy_big_chat_message.content
+        ), "message content list should be modified"
+        assert truncated_chat_history[0].content[-1]["text"] == "Question?", "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
+
     def test_truncate_message_first_large(self):
         # Arrange
         chat_history = generate_chat_history(5)
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
 
         # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_history) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
-        assert tokens <= self.max_prompt_size
+        assert (
+            len(chat_history) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert tokens <= self.max_prompt_size, "Truncated message should be within max prompt size"
 
-    def test_truncate_message_last_large(self):
+    def test_truncate_message_large_system_message_first(self):
         # Arrange
         chat_history = generate_chat_history(5)
         chat_history[0].role = "system"  # Mark the first message as system message
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_history])
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
 
         # Assert
         # The original object has been modified. Verify certain properties.
@@ -62,46 +127,52 @@ class TestTruncateMessage:
         )  # Because the system_prompt is popped off from the chat_messages list
         assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
-        assert truncated_chat_history[0] != copy_big_chat_message
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
 
     def test_truncate_single_large_non_system_message(self):
         # Arrange
-        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
+        big_chat_message = ChatMessage(role="user", content=generate_content(100, suffix="Question?"))
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
 
         # Assert
         # The original object has been modified. Verify certain properties
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
-        assert len(chat_messages) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
+        assert (
+            len(chat_messages) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
+        assert truncated_chat_history[0].content[0]["text"].endswith("\nQuestion?"), "Query should be preserved"
 
     def test_truncate_single_large_question(self):
         # Arrange
-        big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
+        big_chat_message_content = [{"type": "text", "text": " ".join(["hi"] * (self.max_prompt_size + 1))}]
         big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
-        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+        initial_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in chat_messages])
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
-        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+        final_tokens = sum([utils.count_tokens(message.content, self.encoder) for message in truncated_chat_history])
 
         # Assert
         # The original object has been modified. Verify certain properties
-        assert initial_tokens > self.max_prompt_size
-        assert final_tokens <= self.max_prompt_size
-        assert len(chat_messages) == 1
-        assert truncated_chat_history[0] != copy_big_chat_message
+        assert initial_tokens > self.max_prompt_size, "Initial tokens should be greater than max prompt size"
+        assert final_tokens <= self.max_prompt_size, "Final tokens should be within max prompt size"
+        assert (
+            len(chat_messages) == 1
+        ), "Only most recent message should be present as it itself is larger than context size"
+        assert truncated_chat_history[0] != copy_big_chat_message, "Original message should be modified"
 
 
 def test_load_complex_raw_json_string():
@@ -116,12 +187,12 @@ def test_load_complex_raw_json_string():
     assert parsed_json == expeced_json
 
 
-def generate_content(count):
-    return " ".join([f"{index}" for index, _ in enumerate(range(count))])
+def generate_content(count, suffix=""):
+    return [{"type": "text", "text": " ".join([f"{index}" for index, _ in enumerate(range(count))]) + "\n" + suffix}]
 
 
 def generate_chat_history(count):
     return [
-        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
+        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=[{"type": "text", "text": f"{index}"}])
         for index, _ in enumerate(range(count))
     ]
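A minimal usage sketch of the updated truncation path, not part of the patch itself; it mirrors the added tests, and the max_prompt_size and model_name values below are assumptions chosen for illustration:

# Illustrative only: exercises utils.truncate_messages on list-style (multi-part) message
# content, as the updated tests do. max_prompt_size=40 and model_name="gpt-4o-mini" are
# assumed example values, not requirements of the patch.
from langchain.schema import ChatMessage

from khoj.processor.conversation import utils

# Build a message whose content is a list of chatml-style parts, ending with the latest question
content = [{"type": "text", "text": f"{index}"} for index in range(100)]
content += [{"type": "text", "text": "Question?"}]
messages = [ChatMessage(role="user", content=content)]

# Oldest content parts are dropped from the front of the list until the message fits the budget
truncated = utils.truncate_messages(messages, max_prompt_size=40, model_name="gpt-4o-mini")
assert truncated[0].content[-1]["text"] == "Question?"  # the latest question is preserved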