From 70b04d16c0ac2d56dfcfa89b5a292feed0c1e69b Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 13 Mar 2024 16:49:13 +0530 Subject: [PATCH] Test data source, output mode selector, web search query chat actors --- tests/test_openai_chat_actors.py | 90 +++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/tests/test_openai_chat_actors.py b/tests/test_openai_chat_actors.py index 183ffb24..01ae85b9 100644 --- a/tests/test_openai_chat_actors.py +++ b/tests/test_openai_chat_actors.py @@ -7,7 +7,12 @@ from freezegun import freeze_time from khoj.processor.conversation.openai.gpt import converse, extract_questions from khoj.processor.conversation.utils import message_to_log -from khoj.routers.helpers import aget_relevant_output_modes +from khoj.routers.helpers import ( + aget_relevant_information_sources, + aget_relevant_output_modes, + generate_online_subqueries, +) +from khoj.utils.helpers import ConversationCommand # Initialize variables for tests api_key = os.getenv("OPENAI_API_KEY") @@ -435,6 +440,47 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."" ) +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +@freeze_time("2024-04-04", ignore=["transformers"]) +async def test_websearch_with_operators(chat_client): + # Arrange + user_query = "Share popular posts on r/worldnews this month" + + # Act + responses = await generate_online_subqueries(user_query, {}, None) + + # Assert + assert any( + ["reddit.com/r/worldnews" in response for response in responses] + ), "Expected a search query to include site:reddit.com but got: " + str(responses) + + assert any( + ["site:reddit.com" in response for response in responses] + ), "Expected a search query to include site:reddit.com but got: " + str(responses) + + assert any( + ["after:2024/04/01" in response for response in responses] + ), "Expected a search query to include after:2024/04/01 but got: " + str(responses) + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_websearch_khoj_website_for_info_about_khoj(chat_client): + # Arrange + user_query = "Do you support image search?" + + # Act + responses = await generate_online_subqueries(user_query, {}, None) + + # Assert + assert any( + ["site:khoj.dev" in response for response in responses] + ), "Expected search query to include site:khoj.dev but got: " + str(responses) + + # ---------------------------------------------------------------------------------------------------- @pytest.mark.anyio @pytest.mark.django_db(transaction=True) @@ -463,6 +509,48 @@ async def test_use_image_response_mode(chat_client): assert mode.value == "image" +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_select_data_sources_actor_chooses_default(chat_client): + # Arrange + user_query = "How can I improve my swimming compared to my last lesson?" + + # Act + conversation_commands = await aget_relevant_information_sources(user_query, {}) + + # Assert + assert ConversationCommand.Default in conversation_commands + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_select_data_sources_actor_chooses_to_search_notes(chat_client): + # Arrange + user_query = "Where did I learn to swim?" + + # Act + conversation_commands = await aget_relevant_information_sources(user_query, {}) + + # Assert + assert ConversationCommand.Notes in conversation_commands + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_select_data_sources_actor_chooses_to_search_online(chat_client): + # Arrange + user_query = "Where is the nearest hospital?" + + # Act + conversation_commands = await aget_relevant_information_sources(user_query, {}) + + # Assert + assert ConversationCommand.Online in conversation_commands + + # Helpers # ---------------------------------------------------------------------------------------------------- def populate_chat_history(message_list):