mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Drop native offline chat support with llama-cpp-python
It is recommended to chat with open-source models by running an open-source server like Ollama, Llama.cpp on your GPU powered machine or use a commercial provider of open-source models like DeepInfra or OpenRouter. These chat model serving options provide a mature Openai compatible API that already works with Khoj. Directly using offline chat models only worked reasonably with pip install on a machine with GPU. Docker setup of khoj had trouble with accessing GPU. And without GPU access offline chat is too slow. Deprecating support for an offline chat provider directly from within Khoj will reduce code complexity and increase developement velocity. Offline models are subsumed to use existing Openai ai model provider.
This commit is contained in:
@@ -196,17 +196,6 @@ def default_openai_chat_model_option():
|
||||
return chat_model
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def offline_agent():
|
||||
chat_model = ChatModelFactory()
|
||||
return Agent.objects.create(
|
||||
name="Accountant",
|
||||
chat_model=chat_model,
|
||||
personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def openai_agent():
|
||||
@@ -516,40 +505,6 @@ def client(
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
|
||||
# Initialize app state
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types()
|
||||
|
||||
LocalMarkdownConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/markdown/*.markdown"],
|
||||
user=default_user2,
|
||||
)
|
||||
|
||||
all_files = fs_syncer.collect_files(user=default_user2)
|
||||
configure_content(default_user2, all_files)
|
||||
|
||||
# Initialize Processor from Config
|
||||
ChatModelFactory(
|
||||
name="bartowski/Meta-Llama-3.1-3B-Instruct-GGUF",
|
||||
tokenizer=None,
|
||||
max_prompt_size=None,
|
||||
model_type="offline",
|
||||
)
|
||||
UserConversationProcessorConfigFactory(user=default_user2)
|
||||
|
||||
state.anonymous_mode = True
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def new_org_file(default_user: KhojUser, content_config: ContentConfig):
|
||||
# Setup
|
||||
|
||||
@@ -19,7 +19,7 @@ from khoj.database.models import (
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
|
||||
|
||||
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.OFFLINE):
|
||||
def get_chat_provider(default: ChatModel.ModelType | None = ChatModel.ModelType.GOOGLE):
|
||||
provider = os.getenv("KHOJ_TEST_CHAT_PROVIDER")
|
||||
if provider and provider in ChatModel.ModelType:
|
||||
return ChatModel.ModelType(provider)
|
||||
@@ -93,7 +93,7 @@ class ChatModelFactory(factory.django.DjangoModelFactory):
|
||||
|
||||
max_prompt_size = 20000
|
||||
tokenizer = None
|
||||
name = "bartowski/Meta-Llama-3.2-3B-Instruct-GGUF"
|
||||
name = "gemini-2.0-flash"
|
||||
model_type = get_chat_provider()
|
||||
ai_model_api = factory.LazyAttribute(lambda obj: AiModelApiFactory() if get_chat_api_key() else None)
|
||||
|
||||
|
||||
@@ -1,610 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from khoj.database.models import ChatModel
|
||||
from khoj.routers.helpers import aget_data_sources_and_output_format, extract_questions
|
||||
from khoj.utils.helpers import ConversationCommand
|
||||
from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider
|
||||
|
||||
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
|
||||
pytestmark = pytest.mark.skipif(
|
||||
SKIP_TESTS,
|
||||
reason="Disable in CI to avoid long test runs.",
|
||||
)
|
||||
|
||||
import freezegun
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.processor.conversation.offline.chat_model import converse_offline
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.utils.constants import default_offline_chat_models
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loaded_model():
|
||||
return download_model(default_offline_chat_models[0], max_tokens=5000)
|
||||
|
||||
|
||||
freezegun.configure(extend_ignore_list=["transformers"])
|
||||
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02", ignore=["transformers"])
|
||||
def test_extract_question_with_date_filter_from_relative_day(loaded_model, default_user2):
|
||||
# Act
|
||||
response = extract_questions("Where did I go for dinner yesterday?", default_user2, loaded_model=loaded_model)
|
||||
|
||||
assert len(response) >= 1
|
||||
|
||||
assert any(
|
||||
[
|
||||
"dt>='1984-04-01'" in response[0] and "dt<'1984-04-02'" in response[0],
|
||||
"dt>='1984-04-01'" in response[0] and "dt<='1984-04-01'" in response[0],
|
||||
'dt>="1984-04-01"' in response[0] and 'dt<"1984-04-02"' in response[0],
|
||||
'dt>="1984-04-01"' in response[0] and 'dt<="1984-04-01"' in response[0],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Search actor still isn't very date aware nor capable of formatting")
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02", ignore=["transformers"])
|
||||
def test_extract_question_with_date_filter_from_relative_month(loaded_model, default_user2):
|
||||
# Act
|
||||
response = extract_questions("Which countries did I visit last month?", default_user2, loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
assert len(response) >= 1
|
||||
# The user query should be the last question in the response
|
||||
assert response[-1] == ["Which countries did I visit last month?"]
|
||||
assert any(
|
||||
[
|
||||
"dt>='1984-03-01'" in response[0] and "dt<'1984-04-01'" in response[0],
|
||||
"dt>='1984-03-01'" in response[0] and "dt<='1984-03-31'" in response[0],
|
||||
'dt>="1984-03-01"' in response[0] and 'dt<"1984-04-01"' in response[0],
|
||||
'dt>="1984-03-01"' in response[0] and 'dt<="1984-03-31"' in response[0],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02", ignore=["transformers"])
|
||||
def test_extract_question_with_date_filter_from_relative_year(loaded_model, default_user2):
|
||||
# Act
|
||||
response = extract_questions("Which countries have I visited this year?", default_user2, loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
("dt>='1984-01-01'", ""),
|
||||
("dt>='1984-01-01'", "dt<'1985-01-01'"),
|
||||
("dt>='1984-01-01'", "dt<='1984-12-31'"),
|
||||
]
|
||||
assert len(response) == 1
|
||||
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
|
||||
"Expected date filter to limit to 1984 in response but got: " + response[0]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_explicit_questions_from_message(loaded_model, default_user2):
|
||||
# Act
|
||||
responses = extract_questions("What is the Sun? What is the Moon?", default_user2, loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
assert len(responses) >= 2
|
||||
assert ["the Sun" in response for response in responses]
|
||||
assert ["the Moon" in response for response in responses]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_implicit_questions_from_message(loaded_model, default_user2):
|
||||
# Act
|
||||
response = extract_questions("Is Carl taller than Ross?", default_user2, loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
expected_responses = ["height", "taller", "shorter", "heights", "who"]
|
||||
assert len(response) <= 3
|
||||
|
||||
for question in response:
|
||||
assert any([expected_response in question.lower() for expected_response in expected_responses]), (
|
||||
"Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_generate_search_query_using_question_from_chat_history(loaded_model, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
|
||||
]
|
||||
query = "Does he have any sons?"
|
||||
|
||||
# Act
|
||||
response = extract_questions(
|
||||
query,
|
||||
default_user2,
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
use_history=True,
|
||||
)
|
||||
|
||||
any_expected_with_barbara = [
|
||||
"sibling",
|
||||
"brother",
|
||||
]
|
||||
|
||||
any_expected_with_anderson = [
|
||||
"son",
|
||||
"sons",
|
||||
"children",
|
||||
"family",
|
||||
]
|
||||
|
||||
# Assert
|
||||
assert len(response) >= 1
|
||||
# Ensure the remaining generated search queries use proper nouns and chat history context
|
||||
for question in response:
|
||||
if "Barbara" in question:
|
||||
assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), (
|
||||
"Expected search queries using proper nouns and chat history for context, but got: " + question
|
||||
)
|
||||
elif "Anderson" in question:
|
||||
assert any([expected_response in question for expected_response in any_expected_with_anderson]), (
|
||||
"Expected search queries using proper nouns and chat history for context, but got: " + question
|
||||
)
|
||||
else:
|
||||
assert False, (
|
||||
"Expected search queries using proper nouns and chat history for context, but got: " + question
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_generate_search_query_using_answer_from_chat_history(loaded_model, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions(
|
||||
"Is she a Doctor?",
|
||||
default_user2,
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
use_history=True,
|
||||
)
|
||||
|
||||
expected_responses = [
|
||||
"Barbara",
|
||||
"Anderson",
|
||||
]
|
||||
|
||||
# Assert
|
||||
assert len(response) >= 1
|
||||
assert any([expected_response in response[0] for expected_response in expected_responses]), (
|
||||
"Expected chat actor to mention person's by name, but got: " + response[0]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Search actor unable to create date filter using chat history and notes as context")
|
||||
@pytest.mark.chatquality
|
||||
def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions(
|
||||
"What was the Pizza place we ate at over there?",
|
||||
default_user2,
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
("dt>='2000-04-01'", "dt<'2000-05-01'"),
|
||||
("dt>='2000-04-01'", "dt<='2000-04-30'"),
|
||||
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
|
||||
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
|
||||
]
|
||||
assert len(response) == 1
|
||||
assert "Masai Mara" in response[0]
|
||||
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
|
||||
"Expected date filter to limit to April 2000 in response but got: " + response[0]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.parametrize(
|
||||
"user_query, expected_conversation_commands",
|
||||
[
|
||||
(
|
||||
"Where did I learn to swim?",
|
||||
{"sources": [ConversationCommand.Notes], "output": ConversationCommand.Text},
|
||||
),
|
||||
(
|
||||
"Where is the nearest hospital?",
|
||||
{"sources": [ConversationCommand.Online], "output": ConversationCommand.Text},
|
||||
),
|
||||
(
|
||||
"Summarize the wikipedia page on the history of the internet",
|
||||
{"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text},
|
||||
),
|
||||
(
|
||||
"How many noble gases are there?",
|
||||
{"sources": [ConversationCommand.General], "output": ConversationCommand.Text},
|
||||
),
|
||||
(
|
||||
"Make a painting incorporating my past diving experiences",
|
||||
{"sources": [ConversationCommand.Notes], "output": ConversationCommand.Image},
|
||||
),
|
||||
(
|
||||
"Create a chart of the weather over the next 7 days in Timbuktu",
|
||||
{"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text},
|
||||
),
|
||||
(
|
||||
"What's the highest point in this country and have I been there?",
|
||||
{"sources": [ConversationCommand.Online, ConversationCommand.Notes], "output": ConversationCommand.Text},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_select_data_sources_actor_chooses_to_search_notes(
|
||||
client_offline_chat, user_query, expected_conversation_commands, default_user2
|
||||
):
|
||||
# Act
|
||||
selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2)
|
||||
|
||||
# Assert
|
||||
assert set(expected_conversation_commands["sources"]) == set(selected_conversation_commands["sources"])
|
||||
assert expected_conversation_commands["output"] == selected_conversation_commands["output"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
user_query = "What's the latest in the Israel/Palestine conflict?"
|
||||
chat_log = [
|
||||
(
|
||||
"Let's talk about the current events around the world.",
|
||||
"Sure, let's discuss the current events. What would you like to know?",
|
||||
[],
|
||||
),
|
||||
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []),
|
||||
]
|
||||
chat_history = ConversationFactory(user=default_user2, conversation_log=generate_chat_history(chat_log))
|
||||
|
||||
# Act
|
||||
tools = await aget_data_sources_and_output_format(user_query, chat_history, is_task=False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
assert tools == ["online"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Hello, my name is Testatron. Who are you?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Khoj", "khoj", "KHOJ"]
|
||||
assert len(response) > 0
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected assistants name, [K|k]hoj, in response but got: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
|
||||
"Chat actor needs to use context in previous notes and chat history to answer question"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
(
|
||||
"When was I born?",
|
||||
"You were born on 1st April 1984.",
|
||||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
# Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
|
||||
assert "Testville" in response
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
|
||||
"Chat actor needs to use context across currently retrieved notes and chat history to answer question"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=[
|
||||
{"compiled": "Testatron was born on 1st April 1984 in Testville."}
|
||||
], # Assume context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
assert "Testville" in response
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer")
|
||||
@pytest.mark.chatquality
|
||||
def test_refuse_answering_unanswerable_question(loaded_model):
|
||||
"Chat actor should not try make up answers to unanswerable questions."
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
"don't know",
|
||||
"do not know",
|
||||
"no information",
|
||||
"do not have",
|
||||
"don't have",
|
||||
"cannot answer",
|
||||
"I'm sorry",
|
||||
]
|
||||
assert len(response) > 0
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected chat actor to say they don't know in response, but got: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_requires_current_date_awareness(loaded_model):
|
||||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||
# Arrange
|
||||
context = [
|
||||
{
|
||||
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD"""
|
||||
},
|
||||
{
|
||||
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
|
||||
Expenses:Food:Dining 10.00 USD"""
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-04-01 "SuperMercado" "Bananas"
|
||||
Expenses:Food:Groceries 10.00 USD"""
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD"""
|
||||
},
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I have for Dinner today?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["tacos", "Tacos"]
|
||||
assert len(response) > 0
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected [T|t]acos in response, but got: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model):
|
||||
"Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
# Arrange
|
||||
context = [
|
||||
{
|
||||
"compiled": f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD"""
|
||||
},
|
||||
{
|
||||
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
|
||||
Expenses:Food:Dining 10.00 USD"""
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-04-01 "SuperMercado" "Bananas"
|
||||
Expenses:Food:Groceries 10.00 USD"""
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD"""
|
||||
},
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="How much did I spend on dining this year?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
assert "20" in response
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model):
|
||||
"Chat actor should be able to answer general questions not requiring looking at chat history or notes"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Write a haiku about unit testing in 3 lines",
|
||||
chat_history=generate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["test", "testing"]
|
||||
assert len(response.splitlines()) >= 3 # haikus are 3 lines long, but Falcon tends to add a lot of new lines.
|
||||
assert any([expected_response in response.lower() for expected_response in expected_responses]), (
|
||||
"Expected [T|t]est in response, but got: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
|
||||
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
|
||||
# Arrange
|
||||
context = [
|
||||
{
|
||||
"compiled": f"""# Ramya
|
||||
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani."""
|
||||
},
|
||||
{
|
||||
"compiled": f"""# Fang
|
||||
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li."""
|
||||
},
|
||||
{
|
||||
"compiled": f"""# Aiyla
|
||||
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."""
|
||||
},
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="How many kids does my older sister have?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister", "Which one"]
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected chat actor to ask for clarification in response, but got: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_agent_prompt_should_be_used(loaded_model, offline_agent):
|
||||
"Chat actor should ask be tuned to think like an accountant based on the agent definition"
|
||||
# Arrange
|
||||
context = [
|
||||
{"compiled": f"""I went to the store and bought some bananas for 2.20"""},
|
||||
{"compiled": f"""I went to the store and bought some apples for 1.30"""},
|
||||
{"compiled": f"""I went to the store and bought some oranges for 6.00"""},
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I buy?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert that the model without the agent prompt does not include the summary of purchases
|
||||
expected_responses = ["9.50", "9.5"]
|
||||
assert all([expected_response not in response for expected_response in expected_responses]), (
|
||||
"Expected chat actor to summarize values of purchases" + response
|
||||
)
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I buy?",
|
||||
loaded_model=loaded_model,
|
||||
agent=offline_agent,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert that the model with the agent prompt does include the summary of purchases
|
||||
expected_responses = ["9.50", "9.5"]
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected chat actor to summarize values of purchases" + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_chat_does_not_exceed_prompt_size(loaded_model):
|
||||
"Ensure chat context and response together do not exceed max prompt size for the model"
|
||||
# Arrange
|
||||
prompt_size_exceeded_error = "ERROR: The prompt size exceeds the context window size and cannot be processed"
|
||||
context = [{"compiled": " ".join([f"{number}" for number in range(2043)])}]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What numbers come after these?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert prompt_size_exceeded_error not in response, (
|
||||
"Expected chat response to be within prompt limits, but got exceeded error: " + response
|
||||
)
|
||||
@@ -1,726 +0,0 @@
|
||||
import urllib.parse
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.database.models import Agent, ChatModel, Entry, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from tests.helpers import ConversationFactory, get_chat_provider
|
||||
|
||||
SKIP_TESTS = get_chat_provider(default=None) != ChatModel.ModelType.OFFLINE
|
||||
pytestmark = pytest.mark.skipif(
|
||||
SKIP_TESTS,
|
||||
reason="Disable in CI to avoid long test runs.",
|
||||
)
|
||||
|
||||
fake = Faker()
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def generate_history(message_list):
|
||||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, gpt_message, context in message_list:
|
||||
message_to_log(
|
||||
user_message,
|
||||
gpt_message,
|
||||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
chat_history=conversation_log.get("chat", []),
|
||||
)
|
||||
return conversation_log
|
||||
|
||||
|
||||
def create_conversation(message_list, user, agent=None):
|
||||
# Generate conversation logs
|
||||
conversation_log = generate_history(message_list)
|
||||
# Update Conversation Metadata Logs in Database
|
||||
return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent)
|
||||
|
||||
|
||||
# Tests
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat):
|
||||
# Act
|
||||
query = "Hello, my name is Testatron. Who are you?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Khoj", "khoj"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected assistants name, [K|k]hoj, in response but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_chat_with_online_content(client_offline_chat):
|
||||
# Act
|
||||
q = "/online give me the link to paul graham's essay how to do great work"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": q})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
"https://paulgraham.com/greatwork.html",
|
||||
"https://www.paulgraham.com/greatwork.html",
|
||||
"http://www.paulgraham.com/greatwork.html",
|
||||
]
|
||||
assert response.status_code == 200
|
||||
assert any(
|
||||
[expected_response in response_message for expected_response in expected_responses]
|
||||
), f"Expected links: {expected_responses}. Actual response: {response_message}"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_chat_with_online_webpage_content(client_offline_chat):
|
||||
# Act
|
||||
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": q})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
expected_responses = ["185", "1871", "horse"]
|
||||
assert response.status_code == 200
|
||||
assert any(
|
||||
[expected_response in response_message for expected_response in expected_responses]
|
||||
), f"Expected response with {expected_responses}. But actual response had: {response_message}"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_chat_history(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
q = "What is my name?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Testatron", "testatron"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected [T|t]estatron in response but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
(
|
||||
"When was I born?",
|
||||
"You were born on 1st April 1984.",
|
||||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
q = "Where was Xi Li born?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": q})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert "Fujiang" in response_message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
(
|
||||
"When was I born?",
|
||||
"You were born on 1st April 1984.",
|
||||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
q = "Where was I born?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": q})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# 1. Infer who I am from chat history
|
||||
# 2. Infer I was born in Testville from previously retrieved notes
|
||||
assert "Testville" in response_message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
q = "Where was I born?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": q})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# Inference in a multi-turn conversation
|
||||
# 1. Infer who I am from chat history
|
||||
# 2. Search for notes about when <my_name_from_chat_history> was born
|
||||
# 3. Extract where I was born from currently retrieved notes
|
||||
assert "Fujiang" in response_message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
|
||||
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
q = "Where was I born?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": q, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected chat director to say they don't know in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_using_general_command(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = []
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
query = "/general Where was Xi Li born?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert "Fujiang" not in response_message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = []
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
query = "/notes Where was Xi Li born?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert "Fujiang" in response_message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_using_file_filter(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
|
||||
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
|
||||
message_list = []
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
no_answer_response = client_offline_chat.post(
|
||||
f"/api/chat", json={"q": no_answer_query, "stream": True}
|
||||
).content.decode("utf-8")
|
||||
answer_response = client_offline_chat.post(f"/api/chat", json={"q": answer_query, "stream": True}).content.decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert "Fujiang" not in no_answer_response
|
||||
assert "Fujiang" in answer_response
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_not_known_using_notes_command(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = []
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
query = urllib.parse.quote("/notes Where was Testatron born?")
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response_message == prompts.no_notes_found.format()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_one_file(client_offline_chat, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "Xi Li.markdown" file to the file filters
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
# pick the file that has "Xi Li.markdown" in the name
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "Birthday Gift for Xiu turning 4.markdown" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
|
||||
# Act
|
||||
query = "/summarize"
|
||||
response = client_offline_chat.post(
|
||||
f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
|
||||
)
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "Xi Li.markdown" file to the file filters
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
# pick the file that has "Xi Li.markdown" in the name
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "Birthday Gift for Xiu turning 4.markdown" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
|
||||
# Act
|
||||
query = "/summarize tell me about Xiu"
|
||||
response = client_offline_chat.post(
|
||||
f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
|
||||
)
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "Xi Li.markdown" file to the file filters
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters", json={"filename": file_list[0], "conversation_id": str(conversation.id)}
|
||||
)
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters", json={"filename": file_list[1], "conversation_id": str(conversation.id)}
|
||||
)
|
||||
|
||||
# Act
|
||||
query = "/summarize"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
assert response_message is not None
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_no_files(client_offline_chat, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# Act
|
||||
query = "/summarize"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
|
||||
response_message = response.json()["response"]
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_different_conversation(client_offline_chat, default_user2: KhojUser):
|
||||
message_list = []
|
||||
conversation1 = create_conversation(message_list, default_user2)
|
||||
conversation2 = create_conversation(message_list, default_user2)
|
||||
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user2, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "Birthday Gift for Xiu turning 4.markdown" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
|
||||
# add file filter to conversation 1.
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation1.id)},
|
||||
)
|
||||
|
||||
query = "/summarize"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation2.id})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
# now make sure that the file filter is still in conversation 1
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation1.id})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
assert response_message != ""
|
||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_nonexistant_file(client_offline_chat, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# post "imaginary.markdown" file to the file filters
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
|
||||
)
|
||||
# Act
|
||||
query = "/summarize"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
|
||||
response_message = response.json()["response"]
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_summarize_diff_user_file(
|
||||
client_offline_chat, default_user: KhojUser, pdf_configured_user1, default_user2: KhojUser
|
||||
):
|
||||
# Arrange
|
||||
message_list = []
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
# Get the pdf file called singlepage.pdf
|
||||
file_list = (
|
||||
Entry.objects.filter(user=default_user, file_source="computer")
|
||||
.distinct("file_path")
|
||||
.values_list("file_path", flat=True)
|
||||
)
|
||||
summarization_file = ""
|
||||
for file in file_list:
|
||||
if "singlepage.pdf" in file:
|
||||
summarization_file = file
|
||||
break
|
||||
assert summarization_file != ""
|
||||
# add singlepage.pdf to the file filters
|
||||
response = client_offline_chat.post(
|
||||
"api/chat/conversation/file-filters",
|
||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||
)
|
||||
|
||||
# Act
|
||||
query = "/summarize"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id})
|
||||
response_message = response.json()["response"]
|
||||
|
||||
# Assert
|
||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@freeze_time("2023-04-01", ignore=["transformers"])
|
||||
def test_answer_requires_current_date_awareness(client_offline_chat):
|
||||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||
# Act
|
||||
query = "Where did I have lunch today?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Arak", "Medellin"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected chat director to say Arak, Medellin, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@freeze_time("2023-04-01", ignore=["transformers"])
|
||||
def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
|
||||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
# Act
|
||||
query = "How much did I spend on dining this year?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert "26" in response_message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
query = "Write a haiku about unit testing. Do not say anything else."
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["test", "Test"]
|
||||
assert response.status_code == 200
|
||||
assert len(response_message.splitlines()) == 3 # haikus are 3 lines long
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected [T|t]est in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2):
|
||||
# Act
|
||||
query = "What is the name of Namitas older son"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
"which of them is the older",
|
||||
"which one is older",
|
||||
"which of them is older",
|
||||
"which one is the older",
|
||||
]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected chat director to ask for clarification in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
query = "What is my name?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Testatron", "testatron"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected [T|t]estatron in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_in_chat_history_by_conversation_id(client_offline_chat, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("What's my favorite color", "Your favorite color is green.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
message_list2 = [
|
||||
("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 14th August 1947.", []),
|
||||
("What's my favorite color", "Your favorite color is maroon.", []),
|
||||
("Where was I born?", "You were born in a potato farm.", []),
|
||||
]
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
create_conversation(message_list2, default_user2)
|
||||
|
||||
# Act
|
||||
query = "What is my favorite color?"
|
||||
response = client_offline_chat.post(
|
||||
f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
|
||||
)
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["green"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected green in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not great at adhering to agent instructions yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_in_chat_history_by_conversation_id_with_agent(
|
||||
client_offline_chat, default_user2: KhojUser, offline_agent: Agent
|
||||
):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("What's my favorite color", "Your favorite color is green.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []),
|
||||
]
|
||||
conversation = create_conversation(message_list, default_user2, offline_agent)
|
||||
|
||||
# Act
|
||||
query = "/general What did I eat for breakfast?"
|
||||
response = client_offline_chat.post(
|
||||
f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}
|
||||
)
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert that agent only responds with the summary of spending
|
||||
expected_responses = ["13.00", "13", "13.0", "thirteen"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected green in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_chat_history_very_long(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
query = "What is my name?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert len(response_message) > 0
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_requires_multiple_independent_searches(client_offline_chat):
|
||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||
# Act
|
||||
query = "Is Xi older than Namita?"
|
||||
response = client_offline_chat.post(f"/api/chat", json={"q": query, "stream": True})
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected Xi is older than Namita, but got: " + response_message
|
||||
)
|
||||
Reference in New Issue
Block a user