diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index c9e33c6a..e02b8dfc 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -161,7 +161,7 @@ def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
     templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
     templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
     prompted_message = templated_system_message + chat_history + templated_user_message
-    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=2000)
+    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=1000)
     for response in response_iterator:
         g.send(response)
     g.close()
diff --git a/src/khoj/processor/conversation/gpt4all/model_metadata.py b/src/khoj/processor/conversation/gpt4all/model_metadata.py
index 7d99a6be..065e3720 100644
--- a/src/khoj/processor/conversation/gpt4all/model_metadata.py
+++ b/src/khoj/processor/conversation/gpt4all/model_metadata.py
@@ -1,3 +1,3 @@
 model_name_to_url = {
-    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q3_K_M.bin"
+    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
 }
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 5be8e8f7..b739217c 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -7,13 +7,15 @@ import tiktoken
 
 # External packages
 from langchain.schema import ChatMessage
+from transformers import LlamaTokenizerFast
 
 # Internal Packages
 import queue
 from khoj.utils.helpers import merge_dicts
 
 logger = logging.getLogger(__name__)
-max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 850}
+max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 2048}
+tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
 
 
 class ThreadedGenerator:
@@ -102,10 +104,12 @@ def generate_chatml_messages_with_context(
 
 def truncate_messages(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
-    try:
+
+    if "llama" in model_name:
+        encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
+    else:
         encoder = tiktoken.encoding_for_model(model_name)
-    except KeyError:
-        encoder = tiktoken.encoding_for_model("text-davinci-001")
+
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
     while tokens > max_prompt_size and len(messages) > 1:
         messages.pop()
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index e8516c38..3c95448a 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -15,9 +15,13 @@ logger = logging.getLogger(__name__)
 
 
 def perform_chat_checks():
-    if state.processor_config.conversation and (
-        state.processor_config.conversation.openai_model
-        or state.processor_config.conversation.gpt4all_model.loaded_model
+    if (
+        state.processor_config
+        and state.processor_config.conversation
+        and (
+            state.processor_config.conversation.openai_model
+            or state.processor_config.conversation.gpt4all_model.loaded_model
+        )
     ):
         return
 
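
Note for reviewers: the utils.py change above switches token counting to a Llama tokenizer for the offline chat model, since tiktoken does not recognize the GGML Llama model name. Below is a minimal, self-contained sketch (not Khoj code) of how that tokenizer selection behaves; the example message list is made up, and it assumes the transformers and tiktoken packages are installed and that the hf-internal-testing/llama-tokenizer files can be downloaded.

# Illustrative sketch only: mirrors the tokenizer selection added in
# src/khoj/processor/conversation/utils.py.
import tiktoken
from langchain.schema import ChatMessage
from transformers import LlamaTokenizerFast

# Same maps as in the patched utils.py
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 2048}
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}


def count_tokens(messages, model_name):
    # GGML Llama model names are unknown to tiktoken, so use the HF Llama tokenizer for them
    if "llama" in model_name:
        encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
    else:
        encoder = tiktoken.encoding_for_model(model_name)
    return sum(len(encoder.encode(message.content)) for message in messages)


# Hypothetical message list, just to exercise both branches
messages = [ChatMessage(role="user", content="What is the capital of France?")]
print(count_tokens(messages, "llama-2-7b-chat.ggmlv3.q4_K_S.bin"))  # Llama tokenizer path
print(count_tokens(messages, "gpt-3.5-turbo"))  # tiktoken path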