mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Merge branch 'master' of github.com:khoj-ai/khoj into features/add-agents-ui
This commit is contained in:
@@ -7,7 +7,7 @@ import pytest
|
||||
from scipy.stats import linregress
|
||||
|
||||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
from khoj.processor.tools.online_search import search_with_olostep
|
||||
from khoj.processor.tools.online_search import read_webpage, read_webpage_with_olostep
|
||||
from khoj.utils import helpers
|
||||
|
||||
|
||||
@@ -84,13 +84,29 @@ def test_encode_docs_memory_leak():
|
||||
assert slope < 2, f"Memory leak suspected on {device}. Memory usage increased at ~{slope:.2f} MB per iteration"
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("OLOSTEP_API_KEY") is None, reason="OLOSTEP_API_KEY is not set")
|
||||
def test_olostep_api():
|
||||
@pytest.mark.asyncio
|
||||
async def test_reading_webpage():
|
||||
# Arrange
|
||||
website = "https://en.wikipedia.org/wiki/Great_Chicago_Fire"
|
||||
|
||||
# Act
|
||||
response = search_with_olostep(website)
|
||||
response = await read_webpage(website)
|
||||
|
||||
# Assert
|
||||
assert (
|
||||
"An alarm sent from the area near the fire also failed to register at the courthouse where the fire watchmen were"
|
||||
in response
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("OLOSTEP_API_KEY") is None, reason="OLOSTEP_API_KEY is not set")
|
||||
@pytest.mark.asyncio
|
||||
async def test_reading_webpage_with_olostep():
|
||||
# Arrange
|
||||
website = "https://en.wikipedia.org/wiki/Great_Chicago_Fire"
|
||||
|
||||
# Act
|
||||
response = await read_webpage_with_olostep(website)
|
||||
|
||||
# Assert
|
||||
assert (
|
||||
|
||||
@@ -7,7 +7,12 @@ from freezegun import freeze_time
|
||||
|
||||
from khoj.processor.conversation.openai.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_output_modes
|
||||
from khoj.routers.helpers import (
|
||||
aget_relevant_information_sources,
|
||||
aget_relevant_output_modes,
|
||||
generate_online_subqueries,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand
|
||||
|
||||
# Initialize variables for tests
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
@@ -154,33 +159,6 @@ def test_generate_search_query_using_question_and_answer_from_chat_history():
|
||||
assert "Leia" in response[0] and "Luke" in response[0]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_generate_search_query_with_date_and_context_from_chat_history():
|
||||
# Arrange
|
||||
message_list = [
|
||||
("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions(
|
||||
"What was the Pizza place we ate at over there?", conversation_log=populate_chat_history(message_list)
|
||||
)
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
("dt>='2000-04-01'", "dt<'2000-05-01'"),
|
||||
("dt>='2000-04-01'", "dt<='2000-04-30'"),
|
||||
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
|
||||
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
|
||||
]
|
||||
assert len(response) == 1
|
||||
assert "Masai Mara" in response[0]
|
||||
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
|
||||
"Expected date filter to limit to April 2000 in response but got: " + response[0]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_chat_with_no_chat_history_or_retrieved_content():
|
||||
@@ -391,7 +369,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
|
||||
# Act
|
||||
response_gen = converse(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Write a haiku about unit testing in 3 lines",
|
||||
user_query="Write a haiku about unit testing in 3 lines. Do not say anything else",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
@@ -471,6 +449,47 @@ def test_agent_prompt_should_be_used(openai_agent):
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@freeze_time("2024-04-04", ignore=["transformers"])
|
||||
async def test_websearch_with_operators(chat_client):
|
||||
# Arrange
|
||||
user_query = "Share popular posts on r/worldnews this month"
|
||||
|
||||
# Act
|
||||
responses = await generate_online_subqueries(user_query, {}, None)
|
||||
|
||||
# Assert
|
||||
assert any(
|
||||
["reddit.com/r/worldnews" in response for response in responses]
|
||||
), "Expected a search query to include site:reddit.com but got: " + str(responses)
|
||||
|
||||
assert any(
|
||||
["site:reddit.com" in response for response in responses]
|
||||
), "Expected a search query to include site:reddit.com but got: " + str(responses)
|
||||
|
||||
assert any(
|
||||
["after:2024/04/01" in response for response in responses]
|
||||
), "Expected a search query to include after:2024/04/01 but got: " + str(responses)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_websearch_khoj_website_for_info_about_khoj(chat_client):
|
||||
# Arrange
|
||||
user_query = "Do you support image search?"
|
||||
|
||||
# Act
|
||||
responses = await generate_online_subqueries(user_query, {}, None)
|
||||
|
||||
# Assert
|
||||
assert any(
|
||||
["site:khoj.dev" in response for response in responses]
|
||||
), "Expected search query to include site:khoj.dev but got: " + str(responses)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@@ -490,7 +509,7 @@ async def test_use_default_response_mode(chat_client):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_image_response_mode(chat_client):
|
||||
# Arrange
|
||||
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
|
||||
user_query = "Paint a scenery in Timbuktu in the winter"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
@@ -499,6 +518,34 @@ async def test_use_image_response_mode(chat_client):
|
||||
assert mode.value == "image"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_select_data_sources_actor_chooses_to_search_notes(chat_client):
|
||||
# Arrange
|
||||
user_query = "Where did I learn to swim?"
|
||||
|
||||
# Act
|
||||
conversation_commands = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert ConversationCommand.Notes in conversation_commands
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_select_data_sources_actor_chooses_to_search_online(chat_client):
|
||||
# Arrange
|
||||
user_query = "Where is the nearest hospital?"
|
||||
|
||||
# Act
|
||||
conversation_commands = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert ConversationCommand.Online in conversation_commands
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
|
||||
@@ -220,9 +220,17 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
|
||||
expected_responses = [
|
||||
"don't know",
|
||||
"do not know",
|
||||
"no information",
|
||||
"do not have",
|
||||
"don't have",
|
||||
"where were you born?",
|
||||
]
|
||||
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected chat director to say they don't know in response, but got: " + response_message
|
||||
)
|
||||
|
||||
@@ -328,10 +336,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(
|
||||
f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."&stream=true'
|
||||
)
|
||||
response_message = response.content.decode("utf-8")
|
||||
response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else."&stream=true')
|
||||
response_message = response.content.decode("utf-8").split("### compiled references")[0]
|
||||
|
||||
# Assert
|
||||
expected_responses = ["test", "Test"]
|
||||
@@ -348,8 +354,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
|
||||
# Act
|
||||
|
||||
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"&stream=true')
|
||||
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
@@ -359,9 +365,11 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_
|
||||
"the birth order",
|
||||
"provide more context",
|
||||
"provide me with more context",
|
||||
"don't have that",
|
||||
"haven't provided me",
|
||||
]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected chat director to ask for clarification in response, but got: " + response_message
|
||||
)
|
||||
|
||||
@@ -459,13 +467,18 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
|
||||
def test_answer_requires_multiple_independent_searches(chat_client):
|
||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"&stream=true')
|
||||
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
|
||||
|
||||
# Assert
|
||||
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
||||
only_full_name_check = "xi li" in response_message and "namita" not in response_message
|
||||
comparative_statement_check = any(
|
||||
[expected_response in response_message for expected_response in expected_responses]
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
assert only_full_name_check or comparative_statement_check, (
|
||||
"Expected Xi is older than Namita, but got: " + response_message
|
||||
)
|
||||
|
||||
@@ -475,15 +488,22 @@ def test_answer_requires_multiple_independent_searches(chat_client):
|
||||
def test_answer_using_file_filter(chat_client):
|
||||
"Chat should be able to use search filters in the query"
|
||||
# Act
|
||||
query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
|
||||
query = urllib.parse.quote(
|
||||
'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
|
||||
)
|
||||
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
|
||||
|
||||
# Assert
|
||||
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
||||
only_full_name_check = "xi li" in response_message and "namita" not in response_message
|
||||
comparative_statement_check = any(
|
||||
[expected_response in response_message for expected_response in expected_responses]
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
assert only_full_name_check or comparative_statement_check, (
|
||||
"Expected Xi is older than Namita, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user