diff --git a/tests/helpers.py b/tests/helpers.py index fb738015..d3c94abc 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -52,7 +52,10 @@ def generate_chat_history(message_list): message_to_log( user_message, chat_response, - {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, + { + "context": context, + "intent": {"type": "memory", "query": user_message, "inferred-queries": [user_message]}, + }, chat_history=chat_history, ) return chat_history diff --git a/tests/test_offline_chat_actors.py b/tests/test_offline_chat_actors.py index 1278677a..979710b6 100644 --- a/tests/test_offline_chat_actors.py +++ b/tests/test_offline_chat_actors.py @@ -33,9 +33,9 @@ freezegun.configure(extend_ignore_list=["transformers"]) # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality @freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_day(loaded_model): +def test_extract_question_with_date_filter_from_relative_day(loaded_model, default_user2): # Act - response = extract_questions("Where did I go for dinner yesterday?", loaded_model=loaded_model) + response = extract_questions("Where did I go for dinner yesterday?", default_user2, loaded_model=loaded_model) assert len(response) >= 1 @@ -53,9 +53,9 @@ def test_extract_question_with_date_filter_from_relative_day(loaded_model): @pytest.mark.xfail(reason="Search actor still isn't very date aware nor capable of formatting") @pytest.mark.chatquality @freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_month(loaded_model): +def test_extract_question_with_date_filter_from_relative_month(loaded_model, default_user2): # Act - response = extract_questions("Which countries did I visit last month?", loaded_model=loaded_model) + response = extract_questions("Which countries did I visit last month?", default_user2, loaded_model=loaded_model) # Assert assert len(response) >= 1 @@ -75,9 +75,9 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model): @pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting") @pytest.mark.chatquality @freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_year(): +def test_extract_question_with_date_filter_from_relative_year(loaded_model, default_user2): # Act - response = extract_questions("Which countries have I visited this year?") + response = extract_questions("Which countries have I visited this year?", default_user2, loaded_model=loaded_model) # Assert expected_responses = [ @@ -93,9 +93,9 @@ def test_extract_question_with_date_filter_from_relative_year(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_extract_multiple_explicit_questions_from_message(loaded_model): +def test_extract_multiple_explicit_questions_from_message(loaded_model, default_user2): # Act - responses = extract_questions("What is the Sun? What is the Moon?", loaded_model=loaded_model) + responses = extract_questions("What is the Sun? What is the Moon?", default_user2, loaded_model=loaded_model) # Assert assert len(responses) >= 2 @@ -105,9 +105,9 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_extract_multiple_implicit_questions_from_message(loaded_model): +def test_extract_multiple_implicit_questions_from_message(loaded_model, default_user2): # Act - response = extract_questions("Is Carl taller than Ross?", loaded_model=loaded_model) + response = extract_questions("Is Carl taller than Ross?", default_user2, loaded_model=loaded_model) # Assert expected_responses = ["height", "taller", "shorter", "heights", "who"] @@ -121,7 +121,7 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_generate_search_query_using_question_from_chat_history(loaded_model): +def test_generate_search_query_using_question_from_chat_history(loaded_model, default_user2): # Arrange message_list = [ ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []), @@ -131,6 +131,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): # Act response = extract_questions( query, + default_user2, chat_history=generate_chat_history(message_list), loaded_model=loaded_model, use_history=True, @@ -168,7 +169,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_generate_search_query_using_answer_from_chat_history(loaded_model): +def test_generate_search_query_using_answer_from_chat_history(loaded_model, default_user2): # Arrange message_list = [ ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []), @@ -177,6 +178,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model): # Act response = extract_questions( "Is she a Doctor?", + default_user2, chat_history=generate_chat_history(message_list), loaded_model=loaded_model, use_history=True, @@ -197,7 +199,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model): # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(reason="Search actor unable to create date filter using chat history and notes as context") @pytest.mark.chatquality -def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model): +def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model, default_user2): # Arrange message_list = [ ("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []), @@ -206,6 +208,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo # Act response = extract_questions( "What was the Pizza place we ate at over there?", + default_user2, chat_history=generate_chat_history(message_list), loaded_model=loaded_model, ) diff --git a/tests/test_online_chat_actors.py b/tests/test_online_chat_actors.py index 815421ab..4fd4bd0a 100644 --- a/tests/test_online_chat_actors.py +++ b/tests/test_online_chat_actors.py @@ -4,6 +4,7 @@ import freezegun import pytest from freezegun import freeze_time +from khoj.database.models import ChatMessageModel from khoj.processor.conversation.openai.gpt import converse_openai from khoj.processor.conversation.utils import message_to_log from khoj.routers.helpers import ( @@ -31,10 +32,12 @@ freezegun.configure(extend_ignore_list=["transformers"]) # Test # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) @freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_day(): +async def test_extract_question_with_date_filter_from_relative_day(chat_client, default_user2): # Act - response = extract_questions("Where did I go for dinner yesterday?") + response = await extract_questions("Where did I go for dinner yesterday?", default_user2) # Assert expected_responses = [ @@ -49,10 +52,12 @@ def test_extract_question_with_date_filter_from_relative_day(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) @freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_month(): +async def test_extract_question_with_date_filter_from_relative_month(chat_client, default_user2): # Act - response = extract_questions("Which countries did I visit last month?") + response = await extract_questions("Which countries did I visit last month?", default_user2) # Assert expected_responses = [("dt>='1984-03-01'", "dt<'1984-04-01'"), ("dt>='1984-03-01'", "dt<='1984-03-31'")] @@ -64,10 +69,12 @@ def test_extract_question_with_date_filter_from_relative_month(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) @freeze_time("1984-04-02", ignore=["transformers"]) -def test_extract_question_with_date_filter_from_relative_year(): +async def test_extract_question_with_date_filter_from_relative_year(chat_client, default_user2): # Act - response = extract_questions("Which countries have I visited this year?") + response = await extract_questions("Which countries have I visited this year?", default_user2) # Assert expected_responses = [ @@ -83,9 +90,11 @@ def test_extract_question_with_date_filter_from_relative_year(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_extract_multiple_explicit_questions_from_message(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_extract_multiple_explicit_questions_from_message(chat_client, default_user2): # Act - responses = extract_questions("What is the Sun? What is the Moon?") + responses = await extract_questions("What is the Sun? What is the Moon?", default_user2) # Assert assert len(responses) >= 2 @@ -96,9 +105,11 @@ def test_extract_multiple_explicit_questions_from_message(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_extract_multiple_implicit_questions_from_message(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_extract_multiple_implicit_questions_from_message(chat_client, default_user2): # Act - response = extract_questions("Is Morpheus taller than Neo?") + response = await extract_questions("Is Morpheus taller than Neo?", default_user2) # Assert expected_responses = [ @@ -112,14 +123,18 @@ def test_extract_multiple_implicit_questions_from_message(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_generate_search_query_using_question_from_chat_history(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_generate_search_query_using_question_from_chat_history(chat_client, default_user2): # Arrange message_list = [ ("What is the name of Mr. Vader's daughter?", "Princess Leia", []), ] # Act - responses = extract_questions("Does he have any sons?", conversation_log=populate_chat_history(message_list)) + responses = await extract_questions( + "Does he have any sons?", default_user2, chat_history=populate_chat_history(message_list) + ) # Assert assert all(["Vader" in response for response in responses]) @@ -127,14 +142,18 @@ def test_generate_search_query_using_question_from_chat_history(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_generate_search_query_using_answer_from_chat_history(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_generate_search_query_using_answer_from_chat_history(chat_client, default_user2): # Arrange message_list = [ ("What is the name of Mr. Vader's daughter?", "Princess Leia", []), ] # Act - responses = extract_questions("Is she a Jedi?", conversation_log=populate_chat_history(message_list)) + responses = await extract_questions( + "Is she a Jedi?", default_user2, chat_history=populate_chat_history(message_list) + ) # Assert assert all(["Leia" in response for response in responses]) @@ -142,14 +161,18 @@ def test_generate_search_query_using_answer_from_chat_history(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_generate_search_query_using_question_and_answer_from_chat_history(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_generate_search_query_using_question_and_answer_from_chat_history(chat_client, default_user2): # Arrange message_list = [ ("Does Luke Skywalker have any Siblings?", "Yes, Princess Leia", []), ] # Act - response = extract_questions("Who is their father?", conversation_log=populate_chat_history(message_list)) + response = await extract_questions( + "Who is their father?", default_user2, chat_history=populate_chat_history(message_list) + ) # Assert assert any(["Leia" in response or "Luke" in response for response in response]) @@ -157,14 +180,16 @@ def test_generate_search_query_using_question_and_answer_from_chat_history(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_chat_with_no_chat_history_or_retrieved_content(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_chat_with_no_chat_history_or_retrieved_content(): # Act response_gen = converse_openai( references=[], # Assume no context retrieved from notes for the user_query user_query="Hello, my name is Testatron. Who are you?", api_key=api_key, ) - response = "".join([response_chunk for response_chunk in response_gen]) + response = "".join([response_chunk.response async for response_chunk in response_gen]) # Assert expected_responses = ["Khoj", "khoj"] @@ -176,7 +201,9 @@ def test_chat_with_no_chat_history_or_retrieved_content(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_from_chat_history_and_no_content(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_answer_from_chat_history_and_no_content(): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), @@ -187,10 +214,10 @@ def test_answer_from_chat_history_and_no_content(): response_gen = converse_openai( references=[], # Assume no context retrieved from notes for the user_query user_query="What is my name?", - conversation_log=populate_chat_history(message_list), + chat_history=populate_chat_history(message_list), api_key=api_key, ) - response = "".join([response_chunk for response_chunk in response_gen]) + response = "".join([response_chunk.response async for response_chunk in response_gen]) # Assert expected_responses = ["Testatron", "testatron"] @@ -202,7 +229,9 @@ def test_answer_from_chat_history_and_no_content(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_from_chat_history_and_previously_retrieved_content(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_answer_from_chat_history_and_previously_retrieved_content(): "Chat actor needs to use context in previous notes and chat history to answer question" # Arrange message_list = [ @@ -218,10 +247,10 @@ def test_answer_from_chat_history_and_previously_retrieved_content(): response_gen = converse_openai( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=populate_chat_history(message_list), + chat_history=populate_chat_history(message_list), api_key=api_key, ) - response = "".join([response_chunk for response_chunk in response_gen]) + response = "".join([response_chunk.response async for response_chunk in response_gen]) # Assert assert len(response) > 0 @@ -231,7 +260,9 @@ def test_answer_from_chat_history_and_previously_retrieved_content(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_from_chat_history_and_currently_retrieved_content(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_answer_from_chat_history_and_currently_retrieved_content(): "Chat actor needs to use context across currently retrieved notes and chat history to answer question" # Arrange message_list = [ @@ -245,10 +276,10 @@ def test_answer_from_chat_history_and_currently_retrieved_content(): {"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "background.md"} ], # Assume context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=populate_chat_history(message_list), + chat_history=populate_chat_history(message_list), api_key=api_key, ) - response = "".join([response_chunk for response_chunk in response_gen]) + response = "".join([response_chunk.response async for response_chunk in response_gen]) # Assert assert len(response) > 0 @@ -257,7 +288,9 @@ def test_answer_from_chat_history_and_currently_retrieved_content(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_refuse_answering_unanswerable_question(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_refuse_answering_unanswerable_question(): "Chat actor should not try make up answers to unanswerable questions." # Arrange message_list = [ @@ -269,10 +302,10 @@ def test_refuse_answering_unanswerable_question(): response_gen = converse_openai( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=populate_chat_history(message_list), + chat_history=populate_chat_history(message_list), api_key=api_key, ) - response = "".join([response_chunk for response_chunk in response_gen]) + response = "".join([response_chunk.response async for response_chunk in response_gen]) # Assert expected_responses = [ @@ -292,7 +325,9 @@ def test_refuse_answering_unanswerable_question(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_requires_current_date_awareness(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_answer_requires_current_date_awareness(): "Chat actor should be able to answer questions relative to current date using provided notes" # Arrange context = [ @@ -324,7 +359,7 @@ Expenses:Food:Dining 10.00 USD""", user_query="What did I have for Dinner today?", api_key=api_key, ) - response = "".join([response_chunk for response_chunk in response_gen]) + response = "".join([response_chunk.response async for response_chunk in response_gen]) # Assert expected_responses = ["tacos", "Tacos"] @@ -336,7 +371,9 @@ Expenses:Food:Dining 10.00 USD""", # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_requires_date_aware_aggregation_across_provided_notes(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_answer_requires_date_aware_aggregation_across_provided_notes(): "Chat actor should be able to answer questions that require date aware aggregation across multiple notes" # Arrange context = [ @@ -368,7 +405,7 @@ Expenses:Food:Dining 10.00 USD""", user_query="How much did I spend on dining this year?", api_key=api_key, ) - response = "".join([response_chunk for response_chunk in response_gen]) + response = "".join([response_chunk.response async for response_chunk in response_gen]) # Assert assert len(response) > 0 @@ -377,7 +414,9 @@ Expenses:Food:Dining 10.00 USD""", # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_general_question_not_in_chat_history_or_retrieved_content(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_answer_general_question_not_in_chat_history_or_retrieved_content(): "Chat actor should be able to answer general questions not requiring looking at chat history or notes" # Arrange message_list = [ @@ -390,10 +429,10 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(): response_gen = converse_openai( references=[], # Assume no context retrieved from notes for the user_query user_query="Write a haiku about unit testing in 3 lines. Do not say anything else", - conversation_log=populate_chat_history(message_list), + chat_history=populate_chat_history(message_list), api_key=api_key, ) - response = "".join([response_chunk for response_chunk in response_gen]) + response = "".join([response_chunk.response async for response_chunk in response_gen]) # Assert expected_responses = ["test", "bug", "code"] @@ -405,7 +444,9 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_ask_for_clarification_if_not_enough_context_in_question(): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_ask_for_clarification_if_not_enough_context_in_question(): "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context" # Arrange context = [ @@ -432,7 +473,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."" user_query="How many kids does my older sister have?", api_key=api_key, ) - response = "".join([response_chunk for response_chunk in response_gen]) + response = "".join([response_chunk.response async for response_chunk in response_gen]) # Assert expected_responses = [ @@ -449,7 +490,9 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."" # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_agent_prompt_should_be_used(openai_agent): +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_agent_prompt_should_be_used(openai_agent): "Chat actor should ask be tuned to think like an accountant based on the agent definition" # Arrange context = [ @@ -465,14 +508,14 @@ def test_agent_prompt_should_be_used(openai_agent): user_query="What did I buy?", api_key=api_key, ) - no_agent_response = "".join([response_chunk for response_chunk in response_gen]) + no_agent_response = "".join([response_chunk.response async for response_chunk in response_gen]) response_gen = converse_openai( references=context, # Assume context retrieved from notes for the user_query user_query="What did I buy?", api_key=api_key, agent=openai_agent, ) - agent_response = "".join([response_chunk for response_chunk in response_gen]) + agent_response = "".join([response_chunk.response async for response_chunk in response_gen]) # Assert that the model without the agent prompt does not include the summary of purchases assert all([expected_response not in no_agent_response for expected_response in expected_responses]), ( @@ -492,7 +535,7 @@ async def test_websearch_with_operators(chat_client, default_user2): user_query = "Share popular posts on r/worldnews this month" # Act - responses = await generate_online_subqueries(user_query, {}, None, default_user2) + responses = await generate_online_subqueries(user_query, [], None, default_user2) # Assert assert any( @@ -512,7 +555,7 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u user_query = "Do you support image search?" # Act - responses = await generate_online_subqueries(user_query, {}, None, default_user2) + responses = await generate_online_subqueries(user_query, [], None, default_user2) # Assert assert any( @@ -560,7 +603,7 @@ async def test_select_data_sources_actor_chooses_to_search_notes( chat_client, user_query, expected_conversation_commands, default_user2 ): # Act - selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2) + selected_conversation_commands = await aget_data_sources_and_output_format(user_query, [], False, default_user2) # Assert assert set(expected_conversation_commands["sources"]) == set(selected_conversation_commands["sources"]) @@ -599,7 +642,7 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client, defa user_query = "Summarize the wikipedia page on the history of the internet" # Act - urls = await infer_webpage_urls(user_query, {}, None, default_user2) + urls = await infer_webpage_urls(user_query, max_webpages=3, chat_history=[], location_data=None, user=default_user2) # Assert assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls @@ -641,7 +684,7 @@ def test_infer_task_scheduling_request( chat_client, user_query, expected_crontime, expected_qs, unexpected_qs, default_user2 ): # Act - crontime, inferred_query, _ = schedule_query(user_query, {}, default_user2) + crontime, inferred_query, _ = schedule_query(user_query, [], default_user2) inferred_query = inferred_query.lower() # Assert @@ -700,15 +743,15 @@ def test_decision_on_when_to_notify_scheduled_task_results( # ---------------------------------------------------------------------------------------------------- def populate_chat_history(message_list): # Generate conversation logs - conversation_log = {"chat": []} + chat_history: list[ChatMessageModel] = [] for user_message, gpt_message, context in message_list: - conversation_log["chat"] += message_to_log( + chat_history += message_to_log( user_message, gpt_message, khoj_message_metadata={ "context": context, - "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}, + "intent": {"query": user_message, "inferred-queries": [user_message]}, }, chat_history=[], ) - return conversation_log + return chat_history