Use llama.cpp for offline chat models

- Benefits of moving to llama-cpp-python from gpt4all:
  - Support for all GGUF format chat models
  - Support for AMD, Nvidia, Mac, Vulcan GPU machines (instead of just Vulcan, Mac)
  - Supports models with more capabilities like tools, schema
    enforcement, speculative ddecoding, image gen etc.
- Upgrade default chat model, prompt size, tokenizer for new supported
  chat models

- Load offline chat model when present on disk without requiring internet
  - Load model onto GPU if not disabled and device has GPU
  - Load model onto CPU if loading model onto GPU fails
  - Create helper function to check and load model from disk, when model
    glob is present on disk.

    `Llama.from_pretrained' needs internet to get repo info from
    HuggingFace. This isn't required, if the model is already downloaded

    Didn't find any existing HF or llama.cpp method that looked for model
    glob on disk without internet
This commit is contained in:
Debanjum Singh Solanky
2024-03-16 01:49:44 +05:30
parent 0a7392f6ec
commit 8ca39a436c
12 changed files with 146 additions and 164 deletions

View File

@@ -40,9 +40,9 @@ class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
class Meta:
model = ChatModelOptions
max_prompt_size = 2000
max_prompt_size = 3500
tokenizer = None
chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
model_type = "offline"

View File

@@ -5,18 +5,12 @@ import pytest
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
reason="Disable in CI to avoid long test runs.",
)
import freezegun
from freezegun import freeze_time
try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:
print("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
from khoj.processor.conversation.offline.chat_model import (
converse_offline,
extract_questions_offline,
@@ -25,14 +19,12 @@ from khoj.processor.conversation.offline.chat_model import (
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_output_modes
MODEL_NAME = "mistral-7b-instruct-v0.1.Q4_0.gguf"
from khoj.utils.constants import default_offline_chat_model
@pytest.fixture(scope="session")
def loaded_model():
download_model(MODEL_NAME)
return GPT4All(MODEL_NAME)
return download_model(default_offline_chat_model)
freezegun.configure(extend_ignore_list=["transformers"])
@@ -40,7 +32,6 @@ freezegun.configure(extend_ignore_list=["transformers"])
# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Search actor isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
@@ -149,20 +140,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
message_list = [
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
]
query = "Does he have any sons?"
# Act
response = extract_questions_offline(
"Does he have any sons?",
query,
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
)
all_expected_in_response = [
"Anderson",
any_expected_with_barbara = [
"sibling",
"brother",
]
any_expected_in_response = [
any_expected_with_anderson = [
"son",
"sons",
"children",
@@ -170,12 +163,21 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
# Assert
assert len(response) >= 1
assert all([expected_response in response[0] for expected_response in all_expected_in_response]), (
"Expected chat actor to ask for clarification in response, but got: " + response[0]
)
assert any([expected_response in response[0] for expected_response in any_expected_in_response]), (
"Expected chat actor to ask for clarification in response, but got: " + response[0]
)
assert response[-1] == query, "Expected last question to be the user query, but got: " + response[-1]
# Ensure the remaining generated search queries use proper nouns and chat history context
for question in response[:-1]:
if "Barbara" in question:
assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
elif "Anderson" in question:
assert any([expected_response in question for expected_response in any_expected_with_anderson]), (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
else:
assert False, (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
# ----------------------------------------------------------------------------------------------------
@@ -312,6 +314,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer")
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question(loaded_model):
"Chat actor should not try make up answers to unanswerable questions."
@@ -436,7 +439,6 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor doesn't ask clarifying questions when context is insufficient")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"

View File

@@ -14,7 +14,7 @@ from tests.helpers import ConversationFactory
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
reason="Disable in CI to avoid long test runs.",
)
fake = Faker()
@@ -47,7 +47,7 @@ def populate_chat_history(message_list, user):
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat):
# Act
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
response_message = response.content.decode("utf-8")
@@ -338,7 +338,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off
# Assert
assert response.status_code == 200
assert "23" in response_message
assert "26" in response_message
# ----------------------------------------------------------------------------------------------------
@@ -514,7 +514,7 @@ async def test_get_correct_tools_general(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_with_chat_history(client_offline_chat):
async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
chat_log = [
@@ -525,7 +525,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat):
),
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []),
]
chat_history = populate_chat_history(chat_log)
chat_history = populate_chat_history(chat_log, default_user2)
# Act
tools = await aget_relevant_information_sources(user_query, chat_history)