Replace Falcon 🦅 model with Llama V2 🦙 for offline chat (#352)

* Working example with LlamaV2 running locally on my machine

- Download from huggingface
- Plug in to GPT4All
- Update prompts to fit the llama format

* Add appropriate prompts for extracting questions from a query, following the llama format

* Rename Falcon to Llama and make some improvements to the extract_questions flow

* Do further tuning to extract question prompts and unit tests

* Disable extracting questions dynamically from Llama, as results are still unreliable
This commit is contained in:
sabaimran
2023-07-28 03:51:20 +00:00
committed by GitHub
parent 55965eea7d
commit 124d97c26d
11 changed files with 248 additions and 141 deletions

View File

@@ -16,14 +16,18 @@ from freezegun import freeze_time
from gpt4all import GPT4All
# Internal Packages
from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, extract_questions_offline, filter_questions
from khoj.processor.conversation.gpt4all.utils import download_model
from khoj.processor.conversation.utils import message_to_log
MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
@pytest.fixture(scope="session")
def loaded_model():
return GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")
download_model(MODEL_NAME)
return GPT4All(MODEL_NAME)
freezegun.configure(extend_ignore_list=["transformers"])
@@ -35,24 +39,40 @@ freezegun.configure(extend_ignore_list=["transformers"])
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
# Act
response = extract_questions_falcon(
"Where did I go for dinner yesterday?", loaded_model=loaded_model, run_extraction=True
)
response = extract_questions_offline("Where did I go for dinner yesterday?", loaded_model=loaded_model)
assert len(response) >= 1
assert response[-1] == "Where did I go for dinner yesterday?"
assert any(
[
"dt>='1984-04-01'" in response[0] and "dt<'1984-04-02'" in response[0],
"dt>='1984-04-01'" in response[0] and "dt<='1984-04-01'" in response[0],
'dt>="1984-04-01"' in response[0] and 'dt<"1984-04-02"' in response[0],
'dt>="1984-04-01"' in response[0] and 'dt<="1984-04-01"' in response[0],
]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
# Act
response = extract_questions_falcon("Which countries did I visit last month?", loaded_model=loaded_model)
response = extract_questions_offline("Which countries did I visit last month?", loaded_model=loaded_model)
# Assert
assert len(response) == 1
assert response == ["Which countries did I visit last month?"]
assert len(response) >= 1
# The user query should be the last question in the response
assert response[-1] == ["Which countries did I visit last month?"]
assert any(
[
"dt>='1984-03-01'" in response[0] and "dt<'1984-04-01'" in response[0],
"dt>='1984-03-01'" in response[0] and "dt<='1984-03-31'" in response[0],
'dt>="1984-03-01"' in response[0] and 'dt<"1984-04-01"' in response[0],
'dt>="1984-03-01"' in response[0] and 'dt<="1984-03-31"' in response[0],
]
)
# ----------------------------------------------------------------------------------------------------
@@ -60,9 +80,7 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_year(loaded_model):
# Act
response = extract_questions_falcon(
"Which countries have I visited this year?", loaded_model=loaded_model, run_extraction=True
)
response = extract_questions_offline("Which countries have I visited this year?", loaded_model=loaded_model)
# Assert
assert len(response) >= 1
@@ -73,25 +91,26 @@ def test_extract_question_with_date_filter_from_relative_year(loaded_model):
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message(loaded_model):
# Act
response = extract_questions_falcon("What is the Sun? What is the Moon?", loaded_model=loaded_model)
response = extract_questions_offline("What is the Sun? What is the Moon?", loaded_model=loaded_model)
# Assert
expected_responses = ["What is the Sun?", "What is the Moon?"]
assert len(response) == 2
assert expected_responses == response
assert len(response) >= 2
assert expected_responses[0] == response[-2]
assert expected_responses[1] == response[-1]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model):
# Act
response = extract_questions_falcon("Is Morpheus taller than Neo?", loaded_model=loaded_model, run_extraction=True)
response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
# Assert
expected_responses = [
("morpheus", "neo"),
("morpheus", "neo", "height", "taller", "shorter"),
]
assert len(response) == 2
assert len(response) == 3
assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
"Expected two search queries in response but got: " + response[0]
)
@@ -106,18 +125,19 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
]
# Act
response = extract_questions_falcon(
response = extract_questions_offline(
"Does he have any sons?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
run_extraction=True,
use_history=True,
)
expected_responses = [
"do not have",
"clarify",
"am sorry",
"Vader",
"sons",
"son",
"Darth",
"children",
]
# Assert
@@ -128,7 +148,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
# @pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
# Arrange
@@ -137,17 +157,24 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
]
# Act
response = extract_questions_falcon(
response = extract_questions_offline(
"Is she a Jedi?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
run_extraction=True,
use_history=True,
)
expected_responses = [
"Leia",
"Vader",
"daughter",
]
# Assert
assert len(response) == 1
assert "Leia" in response[0]
assert len(response) >= 1
assert any([expected_response in response[0] for expected_response in expected_responses]), (
"Expected chat actor to mention Darth Vader's daughter, but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@@ -160,10 +187,9 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
]
# Act
response = extract_questions_falcon(
response = extract_questions_offline(
"What was the Pizza place we ate at over there?",
conversation_log=populate_chat_history(message_list),
run_extraction=True,
loaded_model=loaded_model,
)
@@ -185,7 +211,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Hello, my name is Testatron. Who are you?",
loaded_model=loaded_model,
@@ -201,7 +227,6 @@ def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor isn't really good at proper nouns yet.")
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
"Chat actor needs to use context in previous notes and chat history to answer question"
@@ -216,7 +241,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model)
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
@@ -241,7 +266,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=[
"Testatron was born on 1st April 1984 in Testville."
], # Assume context retrieved from notes for the user_query
@@ -257,7 +282,6 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor is rather liable to lying.")
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question(loaded_model):
"Chat actor should not try make up answers to unanswerable questions."
@@ -268,7 +292,7 @@ def test_refuse_answering_unanswerable_question(loaded_model):
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
@@ -309,7 +333,7 @@ Expenses:Food:Dining 10.00 USD""",
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I have for Dinner today?",
loaded_model=loaded_model,
@@ -341,7 +365,7 @@ Expenses:Food:Dining 10.00 USD""",
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="How much did I spend on dining this year?",
loaded_model=loaded_model,
@@ -365,7 +389,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines",
conversation_log=populate_chat_history(message_list),
@@ -382,7 +406,6 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
@@ -397,7 +420,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="How many kids does my older sister have?",
loaded_model=loaded_model,
@@ -411,6 +434,17 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
)
def test_filter_questions():
test_questions = [
"I don't know how to answer that",
"I cannot answer anything about the nuclear secrets",
"Who is on the basketball team?",
]
filtered_questions = filter_questions(test_questions)
assert len(filtered_questions) == 1
assert filtered_questions[0] == "Who is on the basketball team?"
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):