mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Add basic chat actor test to infer scheduled queries
This commit is contained in:
@@ -12,8 +12,10 @@ from khoj.routers.helpers import (
|
|||||||
aget_relevant_output_modes,
|
aget_relevant_output_modes,
|
||||||
generate_online_subqueries,
|
generate_online_subqueries,
|
||||||
infer_webpage_urls,
|
infer_webpage_urls,
|
||||||
|
schedule_query,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import ConversationCommand
|
from khoj.utils.helpers import ConversationCommand
|
||||||
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
# Initialize variables for tests
|
# Initialize variables for tests
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
@@ -490,71 +492,42 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client):
|
|||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
async def test_use_default_response_mode(chat_client):
|
@pytest.mark.parametrize(
|
||||||
# Arrange
|
"user_query, expected_mode",
|
||||||
user_query = "What's the latest in the Israel/Palestine conflict?"
|
[
|
||||||
|
("What's the latest in the Israel/Palestine conflict?", "default"),
|
||||||
|
("Summarize the latest tech news every Monday evening", "reminder"),
|
||||||
|
("Paint a scenery in Timbuktu in the winter", "image"),
|
||||||
|
("Remind me, when did I last visit the Serengeti?", "default"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_use_default_response_mode(chat_client, user_query, expected_mode):
|
||||||
# Act
|
# Act
|
||||||
mode = await aget_relevant_output_modes(user_query, {})
|
mode = await aget_relevant_output_modes(user_query, {})
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert mode.value == "default"
|
assert mode.value == expected_mode
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
async def test_use_image_response_mode(chat_client):
|
@pytest.mark.parametrize(
|
||||||
# Arrange
|
"user_query, expected_conversation_commands",
|
||||||
user_query = "Paint a scenery in Timbuktu in the winter"
|
[
|
||||||
|
("Where did I learn to swim?", [ConversationCommand.Notes]),
|
||||||
# Act
|
("Where is the nearest hospital?", [ConversationCommand.Online]),
|
||||||
mode = await aget_relevant_output_modes(user_query, {})
|
("Summarize the wikipedia page on the history of the internet", [ConversationCommand.Webpage]),
|
||||||
|
],
|
||||||
# Assert
|
)
|
||||||
assert mode.value == "image"
|
async def test_select_data_sources_actor_chooses_to_search_notes(
|
||||||
|
chat_client, user_query, expected_conversation_commands
|
||||||
|
):
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.django_db(transaction=True)
|
|
||||||
async def test_select_data_sources_actor_chooses_to_search_notes(chat_client):
|
|
||||||
# Arrange
|
|
||||||
user_query = "Where did I learn to swim?"
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
conversation_commands = await aget_relevant_information_sources(user_query, {})
|
conversation_commands = await aget_relevant_information_sources(user_query, {})
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert ConversationCommand.Notes in conversation_commands
|
assert expected_conversation_commands in conversation_commands
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.django_db(transaction=True)
|
|
||||||
async def test_select_data_sources_actor_chooses_to_search_online(chat_client):
|
|
||||||
# Arrange
|
|
||||||
user_query = "Where is the nearest hospital?"
|
|
||||||
|
|
||||||
# Act
|
|
||||||
conversation_commands = await aget_relevant_information_sources(user_query, {})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert ConversationCommand.Online in conversation_commands
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
@pytest.mark.anyio
|
|
||||||
@pytest.mark.django_db(transaction=True)
|
|
||||||
async def test_select_data_sources_actor_chooses_to_read_webpage(chat_client):
|
|
||||||
# Arrange
|
|
||||||
user_query = "Summarize the wikipedia page on the history of the internet"
|
|
||||||
|
|
||||||
# Act
|
|
||||||
conversation_commands = await aget_relevant_information_sources(user_query, {})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert ConversationCommand.Webpage in conversation_commands
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@@ -571,6 +544,33 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client):
|
|||||||
assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls
|
assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"user_query, location, expected_crontime, expected_queries",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"Share the weather forecast for the next day at 19:30",
|
||||||
|
("Boston", "MA", "USA"),
|
||||||
|
"30 23 * * *",
|
||||||
|
["weather forecast", "boston"],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_infer_task_scheduling_request(chat_client, user_query, location, expected_crontime, expected_queries):
|
||||||
|
# Arrange
|
||||||
|
location_data = LocationData(city=location[0], region=location[1], country=location[2])
|
||||||
|
|
||||||
|
# Act
|
||||||
|
crontime, inferred_query = await schedule_query(user_query, location_data, {})
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert expected_crontime in crontime
|
||||||
|
for query in expected_queries:
|
||||||
|
assert query in inferred_query.lower()
|
||||||
|
|
||||||
|
|
||||||
# Helpers
|
# Helpers
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def populate_chat_history(message_list):
|
def populate_chat_history(message_list):
|
||||||
|
|||||||
Reference in New Issue
Block a user