Update Tests to setup both content_index, search_models before testing

This is required by the updated structure of Khoj setup

- Add content_config pytest fixture, pass bi_encoder from
  search_models.[text|image]_search
This commit is contained in:
Debanjum Singh Solanky
2023-07-14 01:19:38 -07:00
parent 86e2bec9a0
commit b9fb656657
4 changed files with 126 additions and 48 deletions

View File

@@ -5,9 +5,10 @@ import os
# External Packages
import pytest
from khoj.utils.config import SearchModels
# Internal Packages
from khoj.utils.state import model
from khoj.utils.state import content_index, search_models
from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
@@ -41,10 +42,12 @@ def test_asymmetric_setup_with_empty_file_raises_error(
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels):
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Assert
assert len(notes_model.entries) == 10
@@ -52,18 +55,18 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
# ----------------------------------------------------------------------------------------------------
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_config: SearchConfig, caplog):
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog):
# Arrange
caplog.set_level(logging.INFO, logger="khoj")
# Act
# Generate initial notes embeddings during asymmetric setup
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
initial_logs = caplog.text
caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
final_logs = caplog.text
# Assert
@@ -75,11 +78,16 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
@pytest.mark.anyio
async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
query = "How to git install application?"
# Act
hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True)
hits, entries = await text_search.query(
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
)
results = text_search.collate_results(hits, entries, count=1)
@@ -90,7 +98,7 @@ async def test_asymmetric_search(content_config: ContentConfig, search_config: S
# ----------------------------------------------------------------------------------------------------
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig):
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256
@@ -103,7 +111,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False
OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
@@ -113,9 +121,11 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@@ -127,12 +137,14 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
assert len(regenerated_notes_model.entries) == 11
@@ -149,9 +161,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# ----------------------------------------------------------------------------------------------------
def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path):
def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@@ -163,7 +177,9 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# Act
# update embeddings, entries with the newly added note
content_config.org.input_files = [f"{new_org_file}"]
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False)
initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert
# verify new entry added in updated embeddings, entries
@@ -177,10 +193,12 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_asymmetric_setup_github(content_config: ContentConfig, search_config: SearchConfig):
def test_asymmetric_setup_github(content_config: ContentConfig, search_models: SearchModels):
# Act
# Regenerate github embeddings to test asymmetric setup without caching
github_model = text_search.setup(GithubToJsonl, content_config.github, search_config.asymmetric, regenerate=True)
github_model = text_search.setup(
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
)
# Assert
assert len(github_model.entries) > 1