mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 21:29:12 +00:00
Misc. quality improvements for Llama V2
- Fix download url -- was mapping to q3_K_M, but fixed to use q4_K_S - Use a proper Llama Tokenizer for counting tokens for truncation with Llama - Add additional null checks when running
This commit is contained in:
@@ -161,7 +161,7 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
|
|||||||
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
|
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
|
||||||
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
|
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
|
||||||
prompted_message = templated_system_message + chat_history + templated_user_message
|
prompted_message = templated_system_message + chat_history + templated_user_message
|
||||||
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=2000)
|
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000)
|
||||||
for response in response_iterator:
|
for response in response_iterator:
|
||||||
g.send(response)
|
g.send(response)
|
||||||
g.close()
|
g.close()
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
model_name_to_url = {
|
model_name_to_url = {
|
||||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q3_K_M.bin"
|
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,13 +7,15 @@ import tiktoken
|
|||||||
|
|
||||||
# External packages
|
# External packages
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
from transformers import LlamaTokenizerFast
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
import queue
|
import queue
|
||||||
from khoj.utils.helpers import merge_dicts
|
from khoj.utils.helpers import merge_dicts
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 850}
|
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 2048}
|
||||||
|
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
|
||||||
|
|
||||||
|
|
||||||
class ThreadedGenerator:
|
class ThreadedGenerator:
|
||||||
@@ -102,10 +104,12 @@ def generate_chatml_messages_with_context(
|
|||||||
|
|
||||||
def truncate_messages(messages, max_prompt_size, model_name):
|
def truncate_messages(messages, max_prompt_size, model_name):
|
||||||
"""Truncate messages to fit within max prompt size supported by model"""
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
try:
|
|
||||||
|
if "llama" in model_name:
|
||||||
|
encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
|
||||||
|
else:
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
except KeyError:
|
|
||||||
encoder = tiktoken.encoding_for_model("text-davinci-001")
|
|
||||||
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
tokens = sum([len(encoder.encode(message.content)) for message in messages])
|
||||||
while tokens > max_prompt_size and len(messages) > 1:
|
while tokens > max_prompt_size and len(messages) > 1:
|
||||||
messages.pop()
|
messages.pop()
|
||||||
|
|||||||
@@ -15,9 +15,13 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def perform_chat_checks():
|
def perform_chat_checks():
|
||||||
if state.processor_config.conversation and (
|
if (
|
||||||
state.processor_config.conversation.openai_model
|
state.processor_config
|
||||||
or state.processor_config.conversation.gpt4all_model.loaded_model
|
and state.processor_config.conversation
|
||||||
|
and (
|
||||||
|
state.processor_config.conversation.openai_model
|
||||||
|
or state.processor_config.conversation.gpt4all_model.loaded_model
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user