Misc. quality improvements for Llama V2

- Fix download URL -- it was mapping to the q3_K_M quantization instead of the intended q4_K_S
- Use a proper Llama tokenizer for counting tokens when truncating prompts for Llama
- Add additional null checks when running chat checks
sabaimran
2023-07-31 19:11:20 -07:00
parent ca195097d7
commit 2d6c3cd4fa
4 changed files with 17 additions and 9 deletions


@@ -161,7 +161,7 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
     templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
     templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
     prompted_message = templated_system_message + chat_history + templated_user_message
-    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=2000)
+    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000)
     for response in response_iterator:
         g.send(response)
     g.close()
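
For reference, the streamed generation path boils down to the pattern below. This is a minimal sketch against the gpt4all Python bindings with a hand-rolled, simplified Llama-2 prompt; the real code assembles the prompt from prompts.system_prompt_llamav2 and prompts.general_conversation_llamav2 instead.

# Sketch only: simplified Llama-2 chat prompt, not the project's actual templates.
from gpt4all import GPT4All

model = GPT4All("llama-2-7b-chat.ggmlv3.q4_K_S.bin")

system_message = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"
user_message = "What is the capital of France? [/INST]"
prompted_message = system_message + user_message

# streaming=True yields tokens as they are generated; max_tokens caps the response length
for token in model.generate(prompted_message, streaming=True, max_tokens=1000):
    print(token, end="", flush=True)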


@@ -1,3 +1,3 @@
 model_name_to_url = {
-    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q3_K_M.bin"
+    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
 }
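
A rough sketch of how a mapping like this can drive the model download. The download_model helper below is illustrative only, not the project's actual downloader; it assumes the model_name_to_url dict above and a local cache directory.

import os
import requests

def download_model(model_name: str, target_dir: str = "~/.cache/gpt4all") -> str:
    """Fetch the quantized model file if it is not already cached (illustrative helper)."""
    target_dir = os.path.expanduser(target_dir)
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, model_name)
    if not os.path.exists(target_path):
        url = model_name_to_url[model_name]  # mapping from the file above
        with requests.get(url, stream=True, timeout=10) as response:
            response.raise_for_status()
            with open(target_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
    return target_path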


@@ -7,13 +7,15 @@ import tiktoken
 # External packages
 from langchain.schema import ChatMessage
+from transformers import LlamaTokenizerFast

 # Internal Packages
 import queue
 from khoj.utils.helpers import merge_dicts

 logger = logging.getLogger(__name__)

-max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 850}
+max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 2048}
+tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}


 class ThreadedGenerator:
@@ -102,10 +104,12 @@ def generate_chatml_messages_with_context(
 def truncate_messages(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
     try:
-        encoder = tiktoken.encoding_for_model(model_name)
+        if "llama" in model_name:
+            encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
+        else:
+            encoder = tiktoken.encoding_for_model(model_name)
     except KeyError:
         encoder = tiktoken.encoding_for_model("text-davinci-001")

     tokens = sum([len(encoder.encode(message.content)) for message in messages])
     while tokens > max_prompt_size and len(messages) > 1:
         messages.pop()
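
The tokenizer switch matters because OpenAI and Llama models segment text differently, so tiktoken can count the same text to a different number of tokens than the Llama tokenizer. A standalone sketch of counting with each encoder (it downloads the small test tokenizer on first run):

import tiktoken
from transformers import LlamaTokenizerFast

text = "Hello! How can I help you today?"

# OpenAI models: tiktoken BPE vocabulary
openai_encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
openai_token_count = len(openai_encoder.encode(text))

# Llama models: SentencePiece-based tokenizer with different segmentation (and a BOS token)
llama_encoder = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
llama_token_count = len(llama_encoder.encode(text))

print(openai_token_count, llama_token_count)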


@@ -15,9 +15,13 @@ logger = logging.getLogger(__name__)
 def perform_chat_checks():
-    if state.processor_config.conversation and (
-        state.processor_config.conversation.openai_model
-        or state.processor_config.conversation.gpt4all_model.loaded_model
+    if (
+        state.processor_config
+        and state.processor_config.conversation
+        and (
+            state.processor_config.conversation.openai_model
+            or state.processor_config.conversation.gpt4all_model.loaded_model
+        )
     ):
         return
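
The extra guards avoid an AttributeError when no processor is configured, since state.processor_config can be None. A minimal sketch of the same short-circuit pattern with stand-in config classes:

# Stand-in types for illustration; not the project's real config classes.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ConversationConfig:
    openai_model: Optional[str] = None

@dataclass
class ProcessorConfig:
    conversation: Optional[ConversationConfig] = None

processor_config: Optional[ProcessorConfig] = None

# Chained `and` short-circuits before dereferencing a missing attribute,
# so this is safe even when processor_config (or .conversation) is None.
if processor_config and processor_config.conversation and processor_config.conversation.openai_model:
    print("chat is configured")
else:
    print("chat is not configured")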