diff --git a/tests/conftest.py b/tests/conftest.py index dfb27b8b..a92d33ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ from khoj.main import app from khoj.configure import configure_processor, configure_routes, configure_search_types from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.search_type import image_search, text_search +from khoj.utils.config import ImageContent, SearchModels, TextContent from khoj.utils.helpers import resolve_absolute_path from khoj.utils.rawconfig import ( ContentConfig, @@ -41,35 +42,49 @@ def search_config() -> SearchConfig: encoder="sentence-transformers/all-MiniLM-L6-v2", cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", model_directory=model_dir / "symmetric/", + encoder_type=None, ) search_config.asymmetric = TextSearchConfig( encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1", cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", model_directory=model_dir / "asymmetric/", + encoder_type=None, ) search_config.image = ImageSearchConfig( - encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/" + encoder="sentence-transformers/clip-ViT-B-32", + model_directory=model_dir / "image/", + encoder_type=None, ) return search_config @pytest.fixture(scope="session") -def content_config(tmp_path_factory, search_config: SearchConfig): +def search_models(search_config: SearchConfig): + search_models = SearchModels() + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + search_models.image_search = image_search.initialize_model(search_config.image) + + return search_models + + +@pytest.fixture(scope="session") +def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig): content_dir = tmp_path_factory.mktemp("content") # Generate Image Embeddings from Test Images content_config = ContentConfig() content_config.image = ImageContentConfig( + input_filter=None, 
input_directories=["tests/data/images"], embeddings_file=content_dir.joinpath("image_embeddings.pt"), batch_size=1, use_xmp_metadata=False, ) - image_search.setup(content_config.image, search_config.image, regenerate=False) + image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False) # Generate Notes Embeddings from Test Notes content_config.org = TextContentConfig( @@ -80,7 +95,9 @@ def content_config(tmp_path_factory, search_config: SearchConfig): ) filters = [DateFilter(), WordFilter(), FileFilter()] - text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) + text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters + ) content_config.plugins = { "plugin1": TextContentConfig( @@ -106,7 +123,11 @@ def content_config(tmp_path_factory, search_config: SearchConfig): filters = [DateFilter(), WordFilter(), FileFilter()] text_search.setup( - JsonlToJsonl, content_config.plugins["plugin1"], search_config.asymmetric, regenerate=False, filters=filters + JsonlToJsonl, + content_config.plugins["plugin1"], + search_models.text_search.bi_encoder, + regenerate=False, + filters=filters, ) return content_config @@ -157,8 +178,13 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p # Index Markdown Content for Search filters = [DateFilter(), WordFilter(), FileFilter()] - state.model.markdown_search = text_search.setup( - MarkdownToJsonl, md_content_config.markdown, search_config.asymmetric, regenerate=False, filters=filters + state.search_models.text_search = text_search.initialize_model(search_config.asymmetric) + state.content_index.markdown = text_search.setup( + MarkdownToJsonl, + md_content_config.markdown, + state.search_models.text_search.bi_encoder, + regenerate=False, + filters=filters, ) # Initialize Processor from Config @@ -175,8 +201,14 @@ def client(content_config: 
ContentConfig, search_config: SearchConfig, processor state.SearchType = configure_search_types(state.config) # These lines help us Mock the Search models for these search types - state.model.org_search = {} - state.model.image_search = {} + state.search_models.text_search = text_search.initialize_model(search_config.asymmetric) + state.search_models.image_search = image_search.initialize_model(search_config.image) + state.content_index.org = text_search.setup( + OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False + ) + state.content_index.image = image_search.setup( + content_config.image, state.search_models.image_search.image_encoder, regenerate=False + ) configure_routes(app) return TestClient(app) diff --git a/tests/test_client.py b/tests/test_client.py index 81955f39..d86bdd90 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,7 +11,8 @@ from fastapi.testclient import TestClient from khoj.main import app from khoj.configure import configure_routes, configure_search_types from khoj.utils import state -from khoj.utils.state import model, config +from khoj.utils.config import SearchModels +from khoj.utils.state import search_models, content_index, config from khoj.search_type import text_search, image_search from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl @@ -143,7 +144,10 @@ def test_get_configured_types_with_no_content_config(): # ---------------------------------------------------------------------------------------------------- def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) + search_models.image_search = image_search.initialize_model(search_config.image) + content_index.image = image_search.setup( + content_config.image, search_models.image_search.image_encoder, regenerate=False + ) 
query_expected_image_pairs = [ ("kitten", "kitten_park.jpg"), ("a horse and dog on a leash", "horse_dog.jpg"), @@ -166,7 +170,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear # ---------------------------------------------------------------------------------------------------- def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.org_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + content_index.org = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False + ) user_query = quote("How to git install application?") # Act @@ -183,8 +190,9 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter(), FileFilter()] - model.org_search = text_search.setup( - OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + content_index.org = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters ) user_query = quote('+"Emacs" file:"*.org"') @@ -202,8 +210,9 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter()] - model.org_search = text_search.setup( - OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + content_index.org = text_search.setup( + 
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters ) user_query = quote('How to git install application? +"Emacs"') @@ -221,8 +230,9 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig, def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig): # Arrange filters = [WordFilter()] - model.org_search = text_search.setup( - OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + content_index.org = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters ) user_query = quote('How to git install application? -"clone"') diff --git a/tests/test_image_search.py b/tests/test_image_search.py index e4f08d35..82617ab3 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -5,9 +5,10 @@ from PIL import Image # External Packages import pytest +from khoj.utils.config import SearchModels # Internal Packages -from khoj.utils.state import model +from khoj.utils.state import content_index, search_models from khoj.utils.constants import web_directory from khoj.search_type import image_search from khoj.utils.helpers import resolve_absolute_path @@ -16,10 +17,12 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig # Test # ---------------------------------------------------------------------------------------------------- -def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig): +def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels): # Act # Regenerate image search embeddings during image setup - image_search_model = image_search.setup(content_config.image, search_config.image, regenerate=True) + image_search_model = image_search.setup( + content_config.image, 
search_models.image_search.image_encoder, regenerate=True + ) # Assert assert len(image_search_model.image_names) == 3 @@ -54,8 +57,11 @@ def test_image_metadata(content_config: ContentConfig): @pytest.mark.anyio async def test_image_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange + search_models.image_search = image_search.initialize_model(search_config.image) + content_index.image = image_search.setup( + content_config.image, search_models.image_search.image_encoder, regenerate=False + ) output_directory = resolve_absolute_path(web_directory) - model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) query_expected_image_pairs = [ ("kitten", "kitten_park.jpg"), ("horse and dog in a farm", "horse_dog.jpg"), @@ -64,11 +70,13 @@ async def test_image_search(content_config: ContentConfig, search_config: Search # Act for query, expected_image_name in query_expected_image_pairs: - hits = await image_search.query(query, count=1, model=model.image_search) + hits = await image_search.query( + query, count=1, search_model=search_models.image_search, content=content_index.image + ) results = image_search.collate_results( hits, - model.image_search.image_names, + content_index.image.image_names, output_directory=output_directory, image_files_url="/static/images", count=1, @@ -90,7 +98,10 @@ async def test_image_search(content_config: ContentConfig, search_config: Search @pytest.mark.anyio async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog): # Arrange - model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) + search_models.image_search = image_search.initialize_model(search_config.image) + content_index.image = image_search.setup( + content_config.image, search_models.image_search.image_encoder, regenerate=False + ) max_words_supported = 10 query = " ".join(["hello"] * 100) truncated_query = " 
".join(["hello"] * max_words_supported) @@ -98,7 +109,9 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc # Act try: with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): - await image_search.query(query, count=1, model=model.image_search) + await image_search.query( + query, count=1, search_model=search_models.image_search, content=content_index.image + ) # Assert except RuntimeError as e: if "The size of tensor a (102) must match the size of tensor b (77)" in str(e): @@ -110,8 +123,11 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc @pytest.mark.anyio async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog): # Arrange + search_models.image_search = image_search.initialize_model(search_config.image) + content_index.image = image_search.setup( + content_config.image, search_models.image_search.image_encoder, regenerate=False + ) output_directory = resolve_absolute_path(web_directory) - model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) image_directory = content_config.image.input_directories[0] query = f"file:{image_directory.joinpath('kitten_park.jpg')}" @@ -119,11 +135,13 @@ async def test_image_search_by_filepath(content_config: ContentConfig, search_co # Act with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): - hits = await image_search.query(query, count=1, model=model.image_search) + hits = await image_search.query( + query, count=1, search_model=search_models.image_search, content=content_index.image + ) results = image_search.collate_results( hits, - model.image_search.image_names, + content_index.image.image_names, output_directory=output_directory, image_files_url="/static/images", count=1, diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 69f58645..c18a4c42 100644 --- a/tests/test_text_search.py +++ 
b/tests/test_text_search.py @@ -5,9 +5,10 @@ import os # External Packages import pytest +from khoj.utils.config import SearchModels # Internal Packages -from khoj.utils.state import model +from khoj.utils.state import content_index, search_models from khoj.search_type import text_search from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl @@ -41,10 +42,12 @@ def test_asymmetric_setup_with_empty_file_raises_error( # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): +def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels): # Act # Regenerate notes embeddings during asymmetric setup - notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True + ) # Assert assert len(notes_model.entries) == 10 @@ -52,18 +55,18 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo # ---------------------------------------------------------------------------------------------------- -def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_config: SearchConfig, caplog): +def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog): # Arrange caplog.set_level(logging.INFO, logger="khoj") # Act # Generate initial notes embeddings during asymmetric setup - text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True) initial_logs = caplog.text caplog.clear() # Clear logs # Run asymmetric setup again with no changes 
to data source. Ensure index is not updated - text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) + text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False) final_logs = caplog.text # Assert @@ -75,11 +78,16 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi @pytest.mark.anyio async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + search_models.text_search = text_search.initialize_model(search_config.asymmetric) + content_index.org = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True + ) query = "How to git install application?" # Act - hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True) + hits, entries = await text_search.query( + query, search_model=search_models.text_search, content=content_index.org, rank_results=True + ) results = text_search.collate_results(hits, entries, count=1) @@ -90,7 +98,7 @@ async def test_asymmetric_search(content_config: ContentConfig, search_config: S # ---------------------------------------------------------------------------------------------------- -def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig): +def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels): # Arrange # Insert org-mode entry with size exceeding max token limit to new org file max_tokens = 256 @@ -103,7 +111,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent # Act # reload embeddings, entries, notes model after adding new org-mode file initial_notes_model = text_search.setup( - OrgToJsonl, org_config_with_only_new_file, 
search_config.asymmetric, regenerate=False + OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False ) # Assert @@ -113,9 +121,11 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent # ---------------------------------------------------------------------------------------------------- -def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path): +def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path): # Arrange - initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + initial_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True + ) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -127,12 +137,14 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC # regenerate notes jsonl, model embeddings and model to include entry from new file regenerated_notes_model = text_search.setup( - OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True ) # Act # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files - initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False + ) # Assert assert len(regenerated_notes_model.entries) == 11 @@ -149,9 +161,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC # ---------------------------------------------------------------------------------------------------- -def 
test_incremental_update(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path): +def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path): # Arrange - initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) + initial_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True + ) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -163,7 +177,9 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search # Act # update embeddings, entries with the newly added note content_config.org.input_files = [f"{new_org_file}"] - initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model = text_search.setup( + OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False + ) # Assert # verify new entry added in updated embeddings, entries @@ -177,10 +193,12 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search # ---------------------------------------------------------------------------------------------------- @pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set") -def test_asymmetric_setup_github(content_config: ContentConfig, search_config: SearchConfig): +def test_asymmetric_setup_github(content_config: ContentConfig, search_models: SearchModels): # Act # Regenerate github embeddings to test asymmetric setup without caching - github_model = text_search.setup(GithubToJsonl, content_config.github, search_config.asymmetric, regenerate=True) + github_model = text_search.setup( + GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True + ) # Assert assert len(github_model.entries) > 1