mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Use AutoTokenizer to support more tokenizers
This commit is contained in:
@@ -7,7 +7,7 @@ import tiktoken
|
|||||||
|
|
||||||
# External packages
|
# External packages
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
from transformers import LlamaTokenizerFast
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
import queue
|
import queue
|
||||||
@@ -115,15 +115,13 @@ def generate_chatml_messages_with_context(
|
|||||||
return messages[::-1]
|
return messages[::-1]
|
||||||
|
|
||||||
|
|
||||||
def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
|
def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name: str) -> list[ChatMessage]:
|
||||||
"""Truncate messages to fit within max prompt size supported by model"""
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
|
|
||||||
if "llama" in model_name:
|
if model_name.startswith("gpt-"):
|
||||||
encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
|
|
||||||
elif "gpt" in model_name:
|
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
else:
|
else:
|
||||||
encoder = LlamaTokenizerFast.from_pretrained(tokenizer["default"])
|
encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"]))
|
||||||
|
|
||||||
system_message = messages.pop()
|
system_message = messages.pop()
|
||||||
system_message_tokens = len(encoder.encode(system_message.content))
|
system_message_tokens = len(encoder.encode(system_message.content))
|
||||||
|
|||||||
Reference in New Issue
Block a user