diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 96c4c1c8..7bb86887 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -7,7 +7,7 @@ import tiktoken
 
 # External packages
 from langchain.schema import ChatMessage
-from transformers import LlamaTokenizerFast
+from transformers import AutoTokenizer
 
 # Internal Packages
 import queue
@@ -115,15 +115,13 @@ def generate_chatml_messages_with_context(
     return messages[::-1]
 
 
-def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
+def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name: str) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
 
-    if "llama" in model_name:
-        encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
-    elif "gpt" in model_name:
+    if model_name.startswith("gpt-"):
         encoder = tiktoken.encoding_for_model(model_name)
     else:
-        encoder = LlamaTokenizerFast.from_pretrained(tokenizer["default"])
+        encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"]))
 
     system_message = messages.pop()
     system_message_tokens = len(encoder.encode(system_message.content))
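
For reviewers, here is a minimal standalone sketch of the tokenizer dispatch this patch introduces: OpenAI `gpt-*` models go through tiktoken, while every other model falls back to a Hugging Face `AutoTokenizer`, keyed by the module-level `tokenizer` mapping. The `tokenizer` dict contents and the `token_count` helper below are hypothetical stand-ins for illustration, not part of this diff; only the branching logic mirrors the patched `truncate_messages`.

```python
import tiktoken
from transformers import AutoTokenizer

# Hypothetical stand-in for the module-level `tokenizer` mapping the diff assumes:
# model name -> Hugging Face tokenizer id, with a "default" fallback entry.
tokenizer = {
    "default": "hf-internal-testing/llama-tokenizer",  # placeholder id for illustration
}


def token_count(text: str, model_name: str) -> int:
    """Count tokens using the same encoder selection as the patched truncate_messages."""
    if model_name.startswith("gpt-"):
        # OpenAI models: tiktoken knows their encodings natively.
        encoder = tiktoken.encoding_for_model(model_name)
    else:
        # Any other model: load its Hugging Face tokenizer, falling back to
        # the "default" entry when the model has no explicit mapping.
        encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"]))
    # Both encoder types expose .encode() returning a list of token ids,
    # so downstream truncation code stays branch-agnostic.
    return len(encoder.encode(text))


print(token_count("Hello, world!", "gpt-4"))       # tiktoken path
print(token_count("Hello, world!", "my-model"))    # AutoTokenizer fallback path
```

One effect of folding the old `"llama" in model_name` branch into the `dict.get` fallback is that any Hugging Face model, not just Llama variants, can now supply its own tokenizer via the `tokenizer` mapping.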