Fix online chat actor tests, improve offline chat actor tests

The chat actor (and director) tests haven't been looked into in a long
while. They'd gone stale in how they were calling thee functions. And
what was required to run them. Now the online chat actor tests work
again.
This commit is contained in:
Debanjum
2025-06-05 03:13:08 -07:00
parent 2f4160e24b
commit d2c7e5516f
3 changed files with 114 additions and 65 deletions

View File

@@ -52,7 +52,10 @@ def generate_chat_history(message_list):
message_to_log(
user_message,
chat_response,
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
{
"context": context,
"intent": {"type": "memory", "query": user_message, "inferred-queries": [user_message]},
},
chat_history=chat_history,
)
return chat_history

View File

@@ -33,9 +33,9 @@ freezegun.configure(extend_ignore_list=["transformers"])
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
def test_extract_question_with_date_filter_from_relative_day(loaded_model, default_user2):
# Act
response = extract_questions("Where did I go for dinner yesterday?", loaded_model=loaded_model)
response = extract_questions("Where did I go for dinner yesterday?", default_user2, loaded_model=loaded_model)
assert len(response) >= 1
@@ -53,9 +53,9 @@ def test_extract_question_with_date_filter_from_relative_day(loaded_model):
@pytest.mark.xfail(reason="Search actor still isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
def test_extract_question_with_date_filter_from_relative_month(loaded_model, default_user2):
# Act
response = extract_questions("Which countries did I visit last month?", loaded_model=loaded_model)
response = extract_questions("Which countries did I visit last month?", default_user2, loaded_model=loaded_model)
# Assert
assert len(response) >= 1
@@ -75,9 +75,9 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_year():
def test_extract_question_with_date_filter_from_relative_year(loaded_model, default_user2):
# Act
response = extract_questions("Which countries have I visited this year?")
response = extract_questions("Which countries have I visited this year?", default_user2, loaded_model=loaded_model)
# Assert
expected_responses = [
@@ -93,9 +93,9 @@ def test_extract_question_with_date_filter_from_relative_year():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message(loaded_model):
def test_extract_multiple_explicit_questions_from_message(loaded_model, default_user2):
# Act
responses = extract_questions("What is the Sun? What is the Moon?", loaded_model=loaded_model)
responses = extract_questions("What is the Sun? What is the Moon?", default_user2, loaded_model=loaded_model)
# Assert
assert len(responses) >= 2
@@ -105,9 +105,9 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model):
def test_extract_multiple_implicit_questions_from_message(loaded_model, default_user2):
# Act
response = extract_questions("Is Carl taller than Ross?", loaded_model=loaded_model)
response = extract_questions("Is Carl taller than Ross?", default_user2, loaded_model=loaded_model)
# Assert
expected_responses = ["height", "taller", "shorter", "heights", "who"]
@@ -121,7 +121,7 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_from_chat_history(loaded_model):
def test_generate_search_query_using_question_from_chat_history(loaded_model, default_user2):
# Arrange
message_list = [
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
@@ -131,6 +131,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
# Act
response = extract_questions(
query,
default_user2,
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
@@ -168,7 +169,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
def test_generate_search_query_using_answer_from_chat_history(loaded_model, default_user2):
# Arrange
message_list = [
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
@@ -177,6 +178,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
# Act
response = extract_questions(
"Is she a Doctor?",
default_user2,
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
@@ -197,7 +199,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Search actor unable to create date filter using chat history and notes as context")
@pytest.mark.chatquality
def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model):
def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model, default_user2):
# Arrange
message_list = [
("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
@@ -206,6 +208,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
# Act
response = extract_questions(
"What was the Pizza place we ate at over there?",
default_user2,
chat_history=generate_chat_history(message_list),
loaded_model=loaded_model,
)

View File

@@ -4,6 +4,7 @@ import freezegun
import pytest
from freezegun import freeze_time
from khoj.database.models import ChatMessageModel
from khoj.processor.conversation.openai.gpt import converse_openai
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import (
@@ -31,10 +32,12 @@ freezegun.configure(extend_ignore_list=["transformers"])
# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day():
async def test_extract_question_with_date_filter_from_relative_day(chat_client, default_user2):
# Act
response = extract_questions("Where did I go for dinner yesterday?")
response = await extract_questions("Where did I go for dinner yesterday?", default_user2)
# Assert
expected_responses = [
@@ -49,10 +52,12 @@ def test_extract_question_with_date_filter_from_relative_day():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_month():
async def test_extract_question_with_date_filter_from_relative_month(chat_client, default_user2):
# Act
response = extract_questions("Which countries did I visit last month?")
response = await extract_questions("Which countries did I visit last month?", default_user2)
# Assert
expected_responses = [("dt>='1984-03-01'", "dt<'1984-04-01'"), ("dt>='1984-03-01'", "dt<='1984-03-31'")]
@@ -64,10 +69,12 @@ def test_extract_question_with_date_filter_from_relative_month():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_year():
async def test_extract_question_with_date_filter_from_relative_year(chat_client, default_user2):
# Act
response = extract_questions("Which countries have I visited this year?")
response = await extract_questions("Which countries have I visited this year?", default_user2)
# Assert
expected_responses = [
@@ -83,9 +90,11 @@ def test_extract_question_with_date_filter_from_relative_year():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_extract_multiple_explicit_questions_from_message(chat_client, default_user2):
# Act
responses = extract_questions("What is the Sun? What is the Moon?")
responses = await extract_questions("What is the Sun? What is the Moon?", default_user2)
# Assert
assert len(responses) >= 2
@@ -96,9 +105,11 @@ def test_extract_multiple_explicit_questions_from_message():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_extract_multiple_implicit_questions_from_message(chat_client, default_user2):
# Act
response = extract_questions("Is Morpheus taller than Neo?")
response = await extract_questions("Is Morpheus taller than Neo?", default_user2)
# Assert
expected_responses = [
@@ -112,14 +123,18 @@ def test_extract_multiple_implicit_questions_from_message():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_from_chat_history():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_generate_search_query_using_question_from_chat_history(chat_client, default_user2):
# Arrange
message_list = [
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
]
# Act
responses = extract_questions("Does he have any sons?", conversation_log=populate_chat_history(message_list))
responses = await extract_questions(
"Does he have any sons?", default_user2, chat_history=populate_chat_history(message_list)
)
# Assert
assert all(["Vader" in response for response in responses])
@@ -127,14 +142,18 @@ def test_generate_search_query_using_question_from_chat_history():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_generate_search_query_using_answer_from_chat_history(chat_client, default_user2):
# Arrange
message_list = [
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
]
# Act
responses = extract_questions("Is she a Jedi?", conversation_log=populate_chat_history(message_list))
responses = await extract_questions(
"Is she a Jedi?", default_user2, chat_history=populate_chat_history(message_list)
)
# Assert
assert all(["Leia" in response for response in responses])
@@ -142,14 +161,18 @@ def test_generate_search_query_using_answer_from_chat_history():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_and_answer_from_chat_history():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_generate_search_query_using_question_and_answer_from_chat_history(chat_client, default_user2):
# Arrange
message_list = [
("Does Luke Skywalker have any Siblings?", "Yes, Princess Leia", []),
]
# Act
response = extract_questions("Who is their father?", conversation_log=populate_chat_history(message_list))
response = await extract_questions(
"Who is their father?", default_user2, chat_history=populate_chat_history(message_list)
)
# Assert
assert any(["Leia" in response or "Luke" in response for response in response])
@@ -157,14 +180,16 @@ def test_generate_search_query_using_question_and_answer_from_chat_history():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_chat_with_no_chat_history_or_retrieved_content():
# Act
response_gen = converse_openai(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Hello, my name is Testatron. Who are you?",
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
response = "".join([response_chunk.response async for response_chunk in response_gen])
# Assert
expected_responses = ["Khoj", "khoj"]
@@ -176,7 +201,9 @@ def test_chat_with_no_chat_history_or_retrieved_content():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_no_content():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_answer_from_chat_history_and_no_content():
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@@ -187,10 +214,10 @@ def test_answer_from_chat_history_and_no_content():
response_gen = converse_openai(
references=[], # Assume no context retrieved from notes for the user_query
user_query="What is my name?",
conversation_log=populate_chat_history(message_list),
chat_history=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
response = "".join([response_chunk.response async for response_chunk in response_gen])
# Assert
expected_responses = ["Testatron", "testatron"]
@@ -202,7 +229,9 @@ def test_answer_from_chat_history_and_no_content():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_answer_from_chat_history_and_previously_retrieved_content():
"Chat actor needs to use context in previous notes and chat history to answer question"
# Arrange
message_list = [
@@ -218,10 +247,10 @@ def test_answer_from_chat_history_and_previously_retrieved_content():
response_gen = converse_openai(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
chat_history=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
response = "".join([response_chunk.response async for response_chunk in response_gen])
# Assert
assert len(response) > 0
@@ -231,7 +260,9 @@ def test_answer_from_chat_history_and_previously_retrieved_content():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_answer_from_chat_history_and_currently_retrieved_content():
"Chat actor needs to use context across currently retrieved notes and chat history to answer question"
# Arrange
message_list = [
@@ -245,10 +276,10 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "background.md"}
], # Assume context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
chat_history=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
response = "".join([response_chunk.response async for response_chunk in response_gen])
# Assert
assert len(response) > 0
@@ -257,7 +288,9 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_refuse_answering_unanswerable_question():
"Chat actor should not try make up answers to unanswerable questions."
# Arrange
message_list = [
@@ -269,10 +302,10 @@ def test_refuse_answering_unanswerable_question():
response_gen = converse_openai(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
chat_history=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
response = "".join([response_chunk.response async for response_chunk in response_gen])
# Assert
expected_responses = [
@@ -292,7 +325,9 @@ def test_refuse_answering_unanswerable_question():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_current_date_awareness():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_answer_requires_current_date_awareness():
"Chat actor should be able to answer questions relative to current date using provided notes"
# Arrange
context = [
@@ -324,7 +359,7 @@ Expenses:Food:Dining 10.00 USD""",
user_query="What did I have for Dinner today?",
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
response = "".join([response_chunk.response async for response_chunk in response_gen])
# Assert
expected_responses = ["tacos", "Tacos"]
@@ -336,7 +371,9 @@ Expenses:Food:Dining 10.00 USD""",
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_date_aware_aggregation_across_provided_notes():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_answer_requires_date_aware_aggregation_across_provided_notes():
"Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
# Arrange
context = [
@@ -368,7 +405,7 @@ Expenses:Food:Dining 10.00 USD""",
user_query="How much did I spend on dining this year?",
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
response = "".join([response_chunk.response async for response_chunk in response_gen])
# Assert
assert len(response) > 0
@@ -377,7 +414,9 @@ Expenses:Food:Dining 10.00 USD""",
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_answer_general_question_not_in_chat_history_or_retrieved_content():
"Chat actor should be able to answer general questions not requiring looking at chat history or notes"
# Arrange
message_list = [
@@ -390,10 +429,10 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
response_gen = converse_openai(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines. Do not say anything else",
conversation_log=populate_chat_history(message_list),
chat_history=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
response = "".join([response_chunk.response async for response_chunk in response_gen])
# Assert
expected_responses = ["test", "bug", "code"]
@@ -405,7 +444,9 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question():
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_ask_for_clarification_if_not_enough_context_in_question():
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
# Arrange
context = [
@@ -432,7 +473,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
user_query="How many kids does my older sister have?",
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
response = "".join([response_chunk.response async for response_chunk in response_gen])
# Assert
expected_responses = [
@@ -449,7 +490,9 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_agent_prompt_should_be_used(openai_agent):
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_agent_prompt_should_be_used(openai_agent):
"Chat actor should ask be tuned to think like an accountant based on the agent definition"
# Arrange
context = [
@@ -465,14 +508,14 @@ def test_agent_prompt_should_be_used(openai_agent):
user_query="What did I buy?",
api_key=api_key,
)
no_agent_response = "".join([response_chunk for response_chunk in response_gen])
no_agent_response = "".join([response_chunk.response async for response_chunk in response_gen])
response_gen = converse_openai(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I buy?",
api_key=api_key,
agent=openai_agent,
)
agent_response = "".join([response_chunk for response_chunk in response_gen])
agent_response = "".join([response_chunk.response async for response_chunk in response_gen])
# Assert that the model without the agent prompt does not include the summary of purchases
assert all([expected_response not in no_agent_response for expected_response in expected_responses]), (
@@ -492,7 +535,7 @@ async def test_websearch_with_operators(chat_client, default_user2):
user_query = "Share popular posts on r/worldnews this month"
# Act
responses = await generate_online_subqueries(user_query, {}, None, default_user2)
responses = await generate_online_subqueries(user_query, [], None, default_user2)
# Assert
assert any(
@@ -512,7 +555,7 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u
user_query = "Do you support image search?"
# Act
responses = await generate_online_subqueries(user_query, {}, None, default_user2)
responses = await generate_online_subqueries(user_query, [], None, default_user2)
# Assert
assert any(
@@ -560,7 +603,7 @@ async def test_select_data_sources_actor_chooses_to_search_notes(
chat_client, user_query, expected_conversation_commands, default_user2
):
# Act
selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2)
selected_conversation_commands = await aget_data_sources_and_output_format(user_query, [], False, default_user2)
# Assert
assert set(expected_conversation_commands["sources"]) == set(selected_conversation_commands["sources"])
@@ -599,7 +642,7 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client, defa
user_query = "Summarize the wikipedia page on the history of the internet"
# Act
urls = await infer_webpage_urls(user_query, {}, None, default_user2)
urls = await infer_webpage_urls(user_query, max_webpages=3, chat_history=[], location_data=None, user=default_user2)
# Assert
assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls
@@ -641,7 +684,7 @@ def test_infer_task_scheduling_request(
chat_client, user_query, expected_crontime, expected_qs, unexpected_qs, default_user2
):
# Act
crontime, inferred_query, _ = schedule_query(user_query, {}, default_user2)
crontime, inferred_query, _ = schedule_query(user_query, [], default_user2)
inferred_query = inferred_query.lower()
# Assert
@@ -700,15 +743,15 @@ def test_decision_on_when_to_notify_scheduled_task_results(
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
# Generate conversation logs
conversation_log = {"chat": []}
chat_history: list[ChatMessageModel] = []
for user_message, gpt_message, context in message_list:
conversation_log["chat"] += message_to_log(
chat_history += message_to_log(
user_message,
gpt_message,
khoj_message_metadata={
"context": context,
"intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'},
"intent": {"query": user_message, "inferred-queries": [user_message]},
},
chat_history=[],
)
return conversation_log
return chat_history