mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Replace Falcon 🦅 model with Llama V2 🦙 for offline chat (#352)
* Working example with LlamaV2 running locally on my machine - Download from huggingface - Plug in to GPT4All - Update prompts to fit the llama format * Add appropriate prompts for extracting questions based on a query based on llama format * Rename Falcon to Llama and make some improvements to the extract_questions flow * Do further tuning to extract question prompts and unit tests * Disable extracting questions dynamically from Llama, as results are still unreliable
This commit is contained in:
@@ -16,14 +16,18 @@ from freezegun import freeze_time
|
||||
from gpt4all import GPT4All
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon
|
||||
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, extract_questions_offline, filter_questions
|
||||
from khoj.processor.conversation.gpt4all.utils import download_model
|
||||
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
|
||||
MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loaded_model():
|
||||
return GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")
|
||||
download_model(MODEL_NAME)
|
||||
return GPT4All(MODEL_NAME)
|
||||
|
||||
|
||||
freezegun.configure(extend_ignore_list=["transformers"])
|
||||
@@ -35,24 +39,40 @@ freezegun.configure(extend_ignore_list=["transformers"])
|
||||
@freeze_time("1984-04-02")
|
||||
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_falcon(
|
||||
"Where did I go for dinner yesterday?", loaded_model=loaded_model, run_extraction=True
|
||||
)
|
||||
response = extract_questions_offline("Where did I go for dinner yesterday?", loaded_model=loaded_model)
|
||||
|
||||
assert len(response) >= 1
|
||||
assert response[-1] == "Where did I go for dinner yesterday?"
|
||||
|
||||
assert any(
|
||||
[
|
||||
"dt>='1984-04-01'" in response[0] and "dt<'1984-04-02'" in response[0],
|
||||
"dt>='1984-04-01'" in response[0] and "dt<='1984-04-01'" in response[0],
|
||||
'dt>="1984-04-01"' in response[0] and 'dt<"1984-04-02"' in response[0],
|
||||
'dt>="1984-04-01"' in response[0] and 'dt<="1984-04-01"' in response[0],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02")
|
||||
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_falcon("Which countries did I visit last month?", loaded_model=loaded_model)
|
||||
response = extract_questions_offline("Which countries did I visit last month?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
assert len(response) == 1
|
||||
assert response == ["Which countries did I visit last month?"]
|
||||
assert len(response) >= 1
|
||||
# The user query should be the last question in the response
|
||||
assert response[-1] == ["Which countries did I visit last month?"]
|
||||
assert any(
|
||||
[
|
||||
"dt>='1984-03-01'" in response[0] and "dt<'1984-04-01'" in response[0],
|
||||
"dt>='1984-03-01'" in response[0] and "dt<='1984-03-31'" in response[0],
|
||||
'dt>="1984-03-01"' in response[0] and 'dt<"1984-04-01"' in response[0],
|
||||
'dt>="1984-03-01"' in response[0] and 'dt<="1984-03-31"' in response[0],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@@ -60,9 +80,7 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
|
||||
@freeze_time("1984-04-02")
|
||||
def test_extract_question_with_date_filter_from_relative_year(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_falcon(
|
||||
"Which countries have I visited this year?", loaded_model=loaded_model, run_extraction=True
|
||||
)
|
||||
response = extract_questions_offline("Which countries have I visited this year?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
assert len(response) >= 1
|
||||
@@ -73,25 +91,26 @@ def test_extract_question_with_date_filter_from_relative_year(loaded_model):
|
||||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_explicit_questions_from_message(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_falcon("What is the Sun? What is the Moon?", loaded_model=loaded_model)
|
||||
response = extract_questions_offline("What is the Sun? What is the Moon?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
expected_responses = ["What is the Sun?", "What is the Moon?"]
|
||||
assert len(response) == 2
|
||||
assert expected_responses == response
|
||||
assert len(response) >= 2
|
||||
assert expected_responses[0] == response[-2]
|
||||
assert expected_responses[1] == response[-1]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_implicit_questions_from_message(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_falcon("Is Morpheus taller than Neo?", loaded_model=loaded_model, run_extraction=True)
|
||||
response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
("morpheus", "neo"),
|
||||
("morpheus", "neo", "height", "taller", "shorter"),
|
||||
]
|
||||
assert len(response) == 2
|
||||
assert len(response) == 3
|
||||
assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
|
||||
"Expected two search queries in response but got: " + response[0]
|
||||
)
|
||||
@@ -106,18 +125,19 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_falcon(
|
||||
response = extract_questions_offline(
|
||||
"Does he have any sons?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
run_extraction=True,
|
||||
use_history=True,
|
||||
)
|
||||
|
||||
expected_responses = [
|
||||
"do not have",
|
||||
"clarify",
|
||||
"am sorry",
|
||||
"Vader",
|
||||
"sons",
|
||||
"son",
|
||||
"Darth",
|
||||
"children",
|
||||
]
|
||||
|
||||
# Assert
|
||||
@@ -128,7 +148,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
|
||||
# @pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
|
||||
@pytest.mark.chatquality
|
||||
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
||||
# Arrange
|
||||
@@ -137,17 +157,24 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_falcon(
|
||||
response = extract_questions_offline(
|
||||
"Is she a Jedi?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
run_extraction=True,
|
||||
use_history=True,
|
||||
)
|
||||
|
||||
expected_responses = [
|
||||
"Leia",
|
||||
"Vader",
|
||||
"daughter",
|
||||
]
|
||||
|
||||
# Assert
|
||||
assert len(response) == 1
|
||||
assert "Leia" in response[0]
|
||||
assert len(response) >= 1
|
||||
assert any([expected_response in response[0] for expected_response in expected_responses]), (
|
||||
"Expected chat actor to mention Darth Vader's daughter, but got: " + response[0]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@@ -160,10 +187,9 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_falcon(
|
||||
response = extract_questions_offline(
|
||||
"What was the Pizza place we ate at over there?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
run_extraction=True,
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
|
||||
@@ -185,7 +211,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
|
||||
@pytest.mark.chatquality
|
||||
def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Hello, my name is Testatron. Who are you?",
|
||||
loaded_model=loaded_model,
|
||||
@@ -201,7 +227,6 @@ def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor isn't really good at proper nouns yet.")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
|
||||
"Chat actor needs to use context in previous notes and chat history to answer question"
|
||||
@@ -216,7 +241,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model)
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
@@ -241,7 +266,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
response_gen = converse_offline(
|
||||
references=[
|
||||
"Testatron was born on 1st April 1984 in Testville."
|
||||
], # Assume context retrieved from notes for the user_query
|
||||
@@ -257,7 +282,6 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor is rather liable to lying.")
|
||||
@pytest.mark.chatquality
|
||||
def test_refuse_answering_unanswerable_question(loaded_model):
|
||||
"Chat actor should not try make up answers to unanswerable questions."
|
||||
@@ -268,7 +292,7 @@ def test_refuse_answering_unanswerable_question(loaded_model):
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
@@ -309,7 +333,7 @@ Expenses:Food:Dining 10.00 USD""",
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I have for Dinner today?",
|
||||
loaded_model=loaded_model,
|
||||
@@ -341,7 +365,7 @@ Expenses:Food:Dining 10.00 USD""",
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="How much did I spend on dining this year?",
|
||||
loaded_model=loaded_model,
|
||||
@@ -365,7 +389,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Write a haiku about unit testing in 3 lines",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
@@ -382,7 +406,6 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
|
||||
@pytest.mark.chatquality
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
|
||||
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
|
||||
@@ -397,7 +420,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="How many kids does my older sister have?",
|
||||
loaded_model=loaded_model,
|
||||
@@ -411,6 +434,17 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
||||
)
|
||||
|
||||
|
||||
def test_filter_questions():
|
||||
test_questions = [
|
||||
"I don't know how to answer that",
|
||||
"I cannot answer anything about the nuclear secrets",
|
||||
"Who is on the basketball team?",
|
||||
]
|
||||
filtered_questions = filter_questions(test_questions)
|
||||
assert len(filtered_questions) == 1
|
||||
assert filtered_questions[0] == "Who is on the basketball team?"
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
|
||||
Reference in New Issue
Block a user