Use AutoTokenizer to support more tokenizers

Author: Debanjum Singh Solanky
Date: 2023-10-14 16:54:52 -07:00
parent 1ad8b150e8
commit 247e75595c


@@ -7,7 +7,7 @@ import tiktoken
 
 # External packages
 from langchain.schema import ChatMessage
-from transformers import LlamaTokenizerFast
+from transformers import AutoTokenizer
 
 # Internal Packages
 import queue
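
For context on the import swap above: AutoTokenizer inspects a model repo's configuration and returns the matching tokenizer class, so one call site can serve Llama, GPT-2, Mistral, and other architectures instead of hard-coding LlamaTokenizerFast. A minimal sketch, using the publicly hosted "gpt2" repo purely as an example model id:

from transformers import AutoTokenizer

# AutoTokenizer resolves the concrete tokenizer class from the repo's config,
# so the caller does not need to know the architecture in advance.
tok = AutoTokenizer.from_pretrained("gpt2")
print(type(tok).__name__)                # e.g. GPT2TokenizerFast
print(len(tok.encode("Hello, world!")))  # token count for a sample string
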
@@ -115,15 +115,13 @@ def generate_chatml_messages_with_context(
     return messages[::-1]
 
 
-def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
+def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name: str) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
 
-    if "llama" in model_name:
-        encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
-    elif "gpt" in model_name:
+    if model_name.startswith("gpt-"):
         encoder = tiktoken.encoding_for_model(model_name)
     else:
-        encoder = LlamaTokenizerFast.from_pretrained(tokenizer["default"])
+        encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"]))
 
     system_message = messages.pop()
     system_message_tokens = len(encoder.encode(system_message.content))
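
The second hunk keeps tiktoken for OpenAI models and routes everything else through AutoTokenizer via a dict lookup with a default fallback, which is what lets new model families work without further code changes. A minimal sketch of that dispatch pattern, where the `tokenizer` mapping and the model names are illustrative placeholders rather than the project's actual values:

import tiktoken
from transformers import AutoTokenizer

# Illustrative mapping from chat model names to Hugging Face tokenizer ids;
# the real mapping lives in the module this diff patches.
tokenizer = {
    "llama-2-7b-chat": "hf-internal-testing/llama-tokenizer",
    "default": "hf-internal-testing/llama-tokenizer",
}

def get_encoder(model_name: str):
    if model_name.startswith("gpt-"):
        # OpenAI models use tiktoken's byte-pair encodings
        return tiktoken.encoding_for_model(model_name)
    # Any other model falls back to its mapped tokenizer, or the default one
    return AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"]))

# Both encoder types expose .encode(), so token counting stays uniform:
print(len(get_encoder("gpt-4").encode("Hello, world!")))
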