diff --git a/tests/conftest.py b/tests/conftest.py index 1be1b03d..246f5a44 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,3 +33,23 @@ def model_dir(tmp_path_factory): asymmetric.setup(search_config.notes, regenerate=False) return model_dir + + +@pytest.fixture(scope='session') +def search_config(model_dir): + search_config = SearchConfig() + search_config.notes = TextSearchConfig( + input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')], + input_filter = None, + compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'), + embeddings_file = model_dir.joinpath('.note_embeddings.pt'), + verbose = 2) + + search_config.image = ImageSearchConfig( + input_directory = Path('tests/data'), + embeddings_file = Path('tests/data/.image_embeddings.pt'), + batch_size = 10, + use_xmp_metadata = False, + verbose = 2) + + return search_config diff --git a/tests/test_asymmetric_search.py b/tests/test_asymmetric_search.py new file mode 100644 index 00000000..2d84e336 --- /dev/null +++ b/tests/test_asymmetric_search.py @@ -0,0 +1,37 @@ +# Internal Packages +from src.main import model +from src.search_type import asymmetric + + +# Test +# ---------------------------------------------------------------------------------------------------- +def test_asymmetric_setup(search_config): + # Act + # Regenerate notes embeddings during asymmetric setup + notes_model = asymmetric.setup(search_config.notes, regenerate=True) + + # Assert + assert len(notes_model.entries) == 10 + assert len(notes_model.corpus_embeddings) == 10 + + +# ---------------------------------------------------------------------------------------------------- +def test_asymmetric_search(search_config): + # Arrange + model.notes_search = asymmetric.setup(search_config.notes, regenerate=False) + query = "How to git install application?" + + # Act + hits = asymmetric.query( + query, + model = model.notes_search) + + results = asymmetric.collate_results( + hits, + model.notes_search.entries, + count=1) + + # Assert + # Actual_data should contain "Semantic Search via Emacs" entry + search_result = results[0]["Entry"] + assert "git clone" in search_result diff --git a/tests/test_image_search.py b/tests/test_image_search.py new file mode 100644 index 00000000..5d1155e5 --- /dev/null +++ b/tests/test_image_search.py @@ -0,0 +1,44 @@ +# Internal Packages +from src.main import model +from src.search_type import image_search +from src.utils.helpers import resolve_absolute_path + + +# Test +# ---------------------------------------------------------------------------------------------------- +def test_image_search_setup(search_config): + # Act + # Regenerate image search embeddings during image setup + image_search_model = image_search.setup(search_config.image, regenerate=True) + + # Assert + assert len(image_search_model.image_names) == 3 + assert len(image_search_model.image_embeddings) == 3 + + +# ---------------------------------------------------------------------------------------------------- +def test_image_search(search_config): + # Arrange + model.image_search = image_search.setup(search_config.image, regenerate=False) + query_expected_image_pairs = [("brown kitten next to plant", "kitten_park.jpg"), + ("horse and dog in a farm", "horse_dog.jpg"), + ("A guinea pig eating grass", "guineapig_grass.jpg")] + + # Act + for query, expected_image_name in query_expected_image_pairs: + hits = image_search.query( + query, + count = 1, + model = model.image_search) + + results = image_search.collate_results( + hits, + model.image_search.image_names, + search_config.image.input_directory, + count=1) + + actual_image = results[0]["Entry"] + expected_image = resolve_absolute_path(search_config.image.input_directory.joinpath(expected_image_name)) + + # Assert + assert expected_image == actual_image diff --git a/tests/test_main.py b/tests/test_main.py index 1d81e926..cbf0e0c3 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2,13 +2,11 @@ from pathlib import Path # External Packages -import pytest from fastapi.testclient import TestClient # Internal Packages -from src.main import app, search_config, model +from src.main import app, model, search_config as main_search_config from src.search_type import asymmetric, image_search -from src.utils.config import SearchConfig, TextSearchConfig, ImageSearchConfig from src.utils.helpers import resolve_absolute_path @@ -16,25 +14,6 @@ from src.utils.helpers import resolve_absolute_path # ---------------------------------------------------------------------------------------------------- client = TestClient(app) -@pytest.fixture -def search_config(model_dir): - search_config = SearchConfig() - search_config.notes = TextSearchConfig( - input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')], - input_filter = None, - compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'), - embeddings_file = model_dir.joinpath('.note_embeddings.pt'), - verbose = 0) - - search_config.image = ImageSearchConfig( - input_directory = Path('tests/data'), - embeddings_file = Path('tests/data/.image_embeddings.pt'), - batch_size = 10, - use_xmp_metadata = False, - verbose = 2) - - return search_config - # Test # ---------------------------------------------------------------------------------------------------- @@ -50,8 +29,9 @@ def test_search_with_invalid_search_type(): # ---------------------------------------------------------------------------------------------------- -def test_search_with_valid_search_type(): +def test_search_with_valid_search_type(search_config): # Arrange + main_search_config.image = search_config.image for search_type in ["notes", "ledger", "music", "image"]: # Act response = client.get(f"/search?q=random&t={search_type}") @@ -69,8 +49,9 @@ def test_regenerate_with_invalid_search_type(): # ---------------------------------------------------------------------------------------------------- -def test_regenerate_with_valid_search_type(): +def test_regenerate_with_valid_search_type(search_config): # Arrange + main_search_config.image = search_config.image for search_type in ["notes", "ledger", "music", "image"]: # Act response = client.get(f"/regenerate?t={search_type}") @@ -78,6 +59,28 @@ def test_regenerate_with_valid_search_type(): assert response.status_code == 200 +# ---------------------------------------------------------------------------------------------------- +def test_image_search(search_config): + # Arrange + main_search_config.image = search_config.image + model.image_search = image_search.setup(search_config.image, regenerate=False) + query_expected_image_pairs = [("brown kitten next to fallen plant", "kitten_park.jpg"), + ("a horse and dog on a leash", "horse_dog.jpg"), + ("A guinea pig eating grass", "guineapig_grass.jpg")] + + # Act + for query, expected_image_name in query_expected_image_pairs: + response = client.get(f"/search?q={query}&n=1&t=image") + + # Assert + assert response.status_code == 200 + actual_image = Path(response.json()[0]["Entry"]) + expected_image = resolve_absolute_path(search_config.image.input_directory.joinpath(expected_image_name)) + + # Assert + assert expected_image == actual_image + + # ---------------------------------------------------------------------------------------------------- def test_notes_search(search_config): # Arrange @@ -89,7 +92,7 @@ def test_notes_search(search_config): # Assert assert response.status_code == 200 - # assert actual_data contains "Semantic Search via Emacs" + # assert actual_data contains "Semantic Search via Emacs" entry search_result = response.json()[0]["Entry"] assert "git clone" in search_result @@ -124,53 +127,3 @@ def test_notes_search_with_exclude_filter(search_config): # assert actual_data does not contains explicitly excluded word "Emacs" search_result = response.json()[0]["Entry"] assert "clone" not in search_result - - -# ---------------------------------------------------------------------------------------------------- -def test_image_search(search_config): - # Arrange - model.image_search = image_search.setup(search_config.image, regenerate=False) - query_expected_image_pairs = [("kitten in a park", "kitten_park.jpg"), - ("horse and dog in a farm", "horse_dog.jpg"), - ("A guinea pig eating grass", "guineapig_grass.jpg")] - - # Act - for query, expected_image_name in query_expected_image_pairs: - hits = image_search.query( - query, - count = 1, - model = model.image_search) - - results = image_search.collate_results( - hits, - model.image_search.image_names, - search_config.image.input_directory, - count=1) - - actual_image = results[0]["Entry"] - expected_image = resolve_absolute_path(search_config.image.input_directory.joinpath(expected_image_name)) - - # Assert - assert expected_image == actual_image - - -# ---------------------------------------------------------------------------------------------------- -def test_asymmetric_setup(search_config): - # Act - # Regenerate notes embeddings during asymmetric setup - notes_model = asymmetric.setup(search_config.notes, regenerate=True) - - # Assert - assert len(notes_model.entries) == 10 - assert len(notes_model.corpus_embeddings) == 10 - - -# ---------------------------------------------------------------------------------------------------- -def test_image_search_setup(search_config): - # Act - # Regenerate image search embeddings during image setup - image_search_model = image_search.setup(search_config.image, regenerate=True) - - # Assert - assert len(image_search_model.image_names) == 3 - assert len(image_search_model.image_embeddings) == 3