Merge branch 'master' of github.com:khoj-ai/khoj into features/customize-chat-with-agents

This commit is contained in:
sabaimran
2024-03-15 12:13:28 +05:30
26 changed files with 379 additions and 283 deletions

View File

@@ -7,7 +7,12 @@ from freezegun import freeze_time
from khoj.processor.conversation.openai.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_output_modes
from khoj.routers.helpers import (
aget_relevant_information_sources,
aget_relevant_output_modes,
generate_online_subqueries,
)
from khoj.utils.helpers import ConversationCommand
# Initialize variables for tests
api_key = os.getenv("OPENAI_API_KEY")
@@ -154,33 +159,6 @@ def test_generate_search_query_using_question_and_answer_from_chat_history():
assert "Leia" in response[0] and "Luke" in response[0]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_with_date_and_context_from_chat_history():
# Arrange
message_list = [
("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
]
# Act
response = extract_questions(
"What was the Pizza place we ate at over there?", conversation_log=populate_chat_history(message_list)
)
# Assert
expected_responses = [
("dt>='2000-04-01'", "dt<'2000-05-01'"),
("dt>='2000-04-01'", "dt<='2000-04-30'"),
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
]
assert len(response) == 1
assert "Masai Mara" in response[0]
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
"Expected date filter to limit to April 2000 in response but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content():
@@ -391,7 +369,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
# Act
response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines",
user_query="Write a haiku about unit testing in 3 lines. Do not say anything else",
conversation_log=populate_chat_history(message_list),
api_key=api_key,
)
@@ -471,6 +449,47 @@ def test_agent_prompt_should_be_used(openai_agent):
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@freeze_time("2024-04-04", ignore=["transformers"])
async def test_websearch_with_operators(chat_client):
# Arrange
user_query = "Share popular posts on r/worldnews this month"
# Act
responses = await generate_online_subqueries(user_query, {}, None)
# Assert
assert any(
["reddit.com/r/worldnews" in response for response in responses]
), "Expected a search query to include site:reddit.com but got: " + str(responses)
assert any(
["site:reddit.com" in response for response in responses]
), "Expected a search query to include site:reddit.com but got: " + str(responses)
assert any(
["after:2024/04/01" in response for response in responses]
), "Expected a search query to include after:2024/04/01 but got: " + str(responses)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_websearch_khoj_website_for_info_about_khoj(chat_client):
# Arrange
user_query = "Do you support image search?"
# Act
responses = await generate_online_subqueries(user_query, {}, None)
# Assert
assert any(
["site:khoj.dev" in response for response in responses]
), "Expected search query to include site:khoj.dev but got: " + str(responses)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@@ -490,7 +509,7 @@ async def test_use_default_response_mode(chat_client):
@pytest.mark.django_db(transaction=True)
async def test_use_image_response_mode(chat_client):
# Arrange
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
user_query = "Paint a scenery in Timbuktu in the winter"
# Act
mode = await aget_relevant_output_modes(user_query, {})
@@ -499,6 +518,34 @@ async def test_use_image_response_mode(chat_client):
assert mode.value == "image"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_to_search_notes(chat_client):
# Arrange
user_query = "Where did I learn to swim?"
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {})
# Assert
assert ConversationCommand.Notes in conversation_commands
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_to_search_online(chat_client):
# Arrange
user_query = "Where is the nearest hospital?"
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {})
# Assert
assert ConversationCommand.Online in conversation_commands
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):