mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-03 21:29:08 +00:00
Misc. quality improvements for Llama V2
- Fix download url -- was mapping to q3_K_M, but fixed to use q4_K_S - Use a proper Llama Tokenizer for counting tokens for truncation with Llama - Add additional null checks when running
This commit is contained in:
@@ -161,7 +161,7 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
|
||||
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
|
||||
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
|
||||
prompted_message = templated_system_message + chat_history + templated_user_message
|
||||
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=2000)
|
||||
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000)
|
||||
for response in response_iterator:
|
||||
g.send(response)
|
||||
g.close()
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
model_name_to_url = {
|
||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q3_K_M.bin"
|
||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
||||
}
|
||||
|
||||
@@ -7,13 +7,15 @@ import tiktoken
|
||||
|
||||
# External packages
|
||||
from langchain.schema import ChatMessage
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
# Internal Packages
|
||||
import queue
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 850}
|
||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 2048}
|
||||
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
|
||||
|
||||
|
||||
class ThreadedGenerator:
|
||||
@@ -102,10 +104,12 @@ def generate_chatml_messages_with_context(
|
||||
|
||||
def truncate_messages(messages, max_prompt_size, model_name):
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
try:
|
||||
|
||||
if "llama" in model_name:
|
||||
encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
|
||||
else:
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
encoder = tiktoken.encoding_for_model("text-davinci-001")
|
||||
|
||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||
while tokens > max_prompt_size and len(messages) > 1:
|
||||
messages.pop()
|
||||
|
||||
@@ -15,9 +15,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def perform_chat_checks():
|
||||
if state.processor_config.conversation and (
|
||||
state.processor_config.conversation.openai_model
|
||||
or state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
if (
|
||||
state.processor_config
|
||||
and state.processor_config.conversation
|
||||
and (
|
||||
state.processor_config.conversation.openai_model
|
||||
or state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
)
|
||||
):
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user