Update Tests to setup both content_index, search_models before testing

This is required by the updated structure of Khoj setup

- Add content_config pytest fixture, pass bi_encoder from
  search_models.[text|image]_search
This commit is contained in:
Debanjum Singh Solanky
2023-07-14 01:19:38 -07:00
parent 86e2bec9a0
commit b9fb656657
4 changed files with 126 additions and 48 deletions

View File

@@ -10,6 +10,7 @@ from khoj.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types from khoj.configure import configure_processor, configure_routes, configure_search_types
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.search_type import image_search, text_search from khoj.search_type import image_search, text_search
from khoj.utils.config import ImageContent, SearchModels, TextContent
from khoj.utils.helpers import resolve_absolute_path from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import ( from khoj.utils.rawconfig import (
ContentConfig, ContentConfig,
@@ -41,35 +42,49 @@ def search_config() -> SearchConfig:
encoder="sentence-transformers/all-MiniLM-L6-v2", encoder="sentence-transformers/all-MiniLM-L6-v2",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "symmetric/", model_directory=model_dir / "symmetric/",
encoder_type=None,
) )
search_config.asymmetric = TextSearchConfig( search_config.asymmetric = TextSearchConfig(
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1", encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "asymmetric/", model_directory=model_dir / "asymmetric/",
encoder_type=None,
) )
search_config.image = ImageSearchConfig( search_config.image = ImageSearchConfig(
encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/" encoder="sentence-transformers/clip-ViT-B-32",
model_directory=model_dir / "image/",
encoder_type=None,
) )
return search_config return search_config
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_config: SearchConfig): def search_models(search_config: SearchConfig):
search_models = SearchModels()
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
content_dir = tmp_path_factory.mktemp("content") content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images # Generate Image Embeddings from Test Images
content_config = ContentConfig() content_config = ContentConfig()
content_config.image = ImageContentConfig( content_config.image = ImageContentConfig(
input_filter=None,
input_directories=["tests/data/images"], input_directories=["tests/data/images"],
embeddings_file=content_dir.joinpath("image_embeddings.pt"), embeddings_file=content_dir.joinpath("image_embeddings.pt"),
batch_size=1, batch_size=1,
use_xmp_metadata=False, use_xmp_metadata=False,
) )
image_search.setup(content_config.image, search_config.image, regenerate=False) image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
# Generate Notes Embeddings from Test Notes # Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig( content_config.org = TextContentConfig(
@@ -80,7 +95,9 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
) )
filters = [DateFilter(), WordFilter(), FileFilter()] filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
)
content_config.plugins = { content_config.plugins = {
"plugin1": TextContentConfig( "plugin1": TextContentConfig(
@@ -106,7 +123,11 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
filters = [DateFilter(), WordFilter(), FileFilter()] filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup( text_search.setup(
JsonlToJsonl, content_config.plugins["plugin1"], search_config.asymmetric, regenerate=False, filters=filters JsonlToJsonl,
content_config.plugins["plugin1"],
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
) )
return content_config return content_config
@@ -157,8 +178,13 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
# Index Markdown Content for Search # Index Markdown Content for Search
filters = [DateFilter(), WordFilter(), FileFilter()] filters = [DateFilter(), WordFilter(), FileFilter()]
state.model.markdown_search = text_search.setup( state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
MarkdownToJsonl, md_content_config.markdown, search_config.asymmetric, regenerate=False, filters=filters state.content_index.markdown = text_search.setup(
MarkdownToJsonl,
md_content_config.markdown,
state.search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
) )
# Initialize Processor from Config # Initialize Processor from Config
@@ -175,8 +201,14 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
state.SearchType = configure_search_types(state.config) state.SearchType = configure_search_types(state.config)
# These lines help us Mock the Search models for these search types # These lines help us Mock the Search models for these search types
state.model.org_search = {} state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.model.image_search = {} state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
configure_routes(app) configure_routes(app)
return TestClient(app) return TestClient(app)

View File

@@ -11,7 +11,8 @@ from fastapi.testclient import TestClient
from khoj.main import app from khoj.main import app
from khoj.configure import configure_routes, configure_search_types from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state from khoj.utils import state
from khoj.utils.state import model, config from khoj.utils.config import SearchModels
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
@@ -143,7 +144,10 @@ def test_get_configured_types_with_no_content_config():
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig): def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
query_expected_image_pairs = [ query_expected_image_pairs = [
("kitten", "kitten_park.jpg"), ("kitten", "kitten_park.jpg"),
("a horse and dog on a leash", "horse_dog.jpg"), ("a horse and dog on a leash", "horse_dog.jpg"),
@@ -166,7 +170,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig): def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
model.org_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
user_query = quote("How to git install application?") user_query = quote("How to git install application?")
# Act # Act
@@ -183,8 +190,9 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig): def test_notes_search_with_only_filters(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
filters = [WordFilter(), FileFilter()] filters = [WordFilter(), FileFilter()]
model.org_search = text_search.setup( search_models.text_search = text_search.initialize_model(search_config.asymmetric)
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
) )
user_query = quote('+"Emacs" file:"*.org"') user_query = quote('+"Emacs" file:"*.org"')
@@ -202,8 +210,9 @@ def test_notes_search_with_only_filters(client, content_config: ContentConfig, s
def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig): def test_notes_search_with_include_filter(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
filters = [WordFilter()] filters = [WordFilter()]
model.org_search = text_search.setup( search_models.text_search = text_search.initialize_model(search_config.asymmetric)
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search, regenerate=False, filters=filters
) )
user_query = quote('How to git install application? +"Emacs"') user_query = quote('How to git install application? +"Emacs"')
@@ -221,8 +230,9 @@ def test_notes_search_with_include_filter(client, content_config: ContentConfig,
def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig): def test_notes_search_with_exclude_filter(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
filters = [WordFilter()] filters = [WordFilter()]
model.org_search = text_search.setup( search_models.text_search = text_search.initialize_model(search_config.asymmetric)
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
) )
user_query = quote('How to git install application? -"clone"') user_query = quote('How to git install application? -"clone"')

View File

@@ -5,9 +5,10 @@ from PIL import Image
# External Packages # External Packages
import pytest import pytest
from khoj.utils.config import SearchModels
# Internal Packages # Internal Packages
from khoj.utils.state import model from khoj.utils.state import content_index, search_models
from khoj.utils.constants import web_directory from khoj.utils.constants import web_directory
from khoj.search_type import image_search from khoj.search_type import image_search
from khoj.utils.helpers import resolve_absolute_path from khoj.utils.helpers import resolve_absolute_path
@@ -16,10 +17,12 @@ from khoj.utils.rawconfig import ContentConfig, SearchConfig
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_image_search_setup(content_config: ContentConfig, search_config: SearchConfig): def test_image_search_setup(content_config: ContentConfig, search_models: SearchModels):
# Act # Act
# Regenerate image search embeddings during image setup # Regenerate image search embeddings during image setup
image_search_model = image_search.setup(content_config.image, search_config.image, regenerate=True) image_search_model = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=True
)
# Assert # Assert
assert len(image_search_model.image_names) == 3 assert len(image_search_model.image_names) == 3
@@ -54,8 +57,11 @@ def test_image_metadata(content_config: ContentConfig):
@pytest.mark.anyio @pytest.mark.anyio
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig): async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
output_directory = resolve_absolute_path(web_directory) output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
query_expected_image_pairs = [ query_expected_image_pairs = [
("kitten", "kitten_park.jpg"), ("kitten", "kitten_park.jpg"),
("horse and dog in a farm", "horse_dog.jpg"), ("horse and dog in a farm", "horse_dog.jpg"),
@@ -64,11 +70,13 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
# Act # Act
for query, expected_image_name in query_expected_image_pairs: for query, expected_image_name in query_expected_image_pairs:
hits = await image_search.query(query, count=1, model=model.image_search) hits = await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,
model.image_search.image_names, content_index.image.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url="/static/images", image_files_url="/static/images",
count=1, count=1,
@@ -90,7 +98,10 @@ async def test_image_search(content_config: ContentConfig, search_config: Search
@pytest.mark.anyio @pytest.mark.anyio
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog): async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange # Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
max_words_supported = 10 max_words_supported = 10
query = " ".join(["hello"] * 100) query = " ".join(["hello"] * 100)
truncated_query = " ".join(["hello"] * max_words_supported) truncated_query = " ".join(["hello"] * max_words_supported)
@@ -98,7 +109,9 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
# Act # Act
try: try:
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
await image_search.query(query, count=1, model=model.image_search) await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
# Assert # Assert
except RuntimeError as e: except RuntimeError as e:
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e): if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
@@ -110,8 +123,11 @@ async def test_image_search_query_truncated(content_config: ContentConfig, searc
@pytest.mark.anyio @pytest.mark.anyio
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog): async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange # Arrange
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
)
output_directory = resolve_absolute_path(web_directory) output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
image_directory = content_config.image.input_directories[0] image_directory = content_config.image.input_directories[0]
query = f"file:{image_directory.joinpath('kitten_park.jpg')}" query = f"file:{image_directory.joinpath('kitten_park.jpg')}"
@@ -119,11 +135,13 @@ async def test_image_search_by_filepath(content_config: ContentConfig, search_co
# Act # Act
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"): with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
hits = await image_search.query(query, count=1, model=model.image_search) hits = await image_search.query(
query, count=1, search_model=search_models.image_search, content=content_index.image
)
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,
model.image_search.image_names, content_index.image.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url="/static/images", image_files_url="/static/images",
count=1, count=1,

View File

@@ -5,9 +5,10 @@ import os
# External Packages # External Packages
import pytest import pytest
from khoj.utils.config import SearchModels
# Internal Packages # Internal Packages
from khoj.utils.state import model from khoj.utils.state import content_index, search_models
from khoj.search_type import text_search from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
@@ -41,10 +42,12 @@ def test_asymmetric_setup_with_empty_file_raises_error(
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): def test_asymmetric_setup(content_config: ContentConfig, search_models: SearchModels):
# Act # Act
# Regenerate notes embeddings during asymmetric setup # Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
# Assert # Assert
assert len(notes_model.entries) == 10 assert len(notes_model.entries) == 10
@@ -52,18 +55,18 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_config: SearchConfig, caplog): def test_text_content_index_only_updates_on_changes(content_config: ContentConfig, search_models: SearchModels, caplog):
# Arrange # Arrange
caplog.set_level(logging.INFO, logger="khoj") caplog.set_level(logging.INFO, logger="khoj")
# Act # Act
# Generate initial notes embeddings during asymmetric setup # Generate initial notes embeddings during asymmetric setup
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
initial_logs = caplog.text initial_logs = caplog.text
caplog.clear() # Clear logs caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated # Run asymmetric setup again with no changes to data source. Ensure index is not updated
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) text_search.setup(OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
final_logs = caplog.text final_logs = caplog.text
# Assert # Assert
@@ -75,11 +78,16 @@ def test_text_content_index_only_updates_on_changes(content_config: ContentConfi
@pytest.mark.anyio @pytest.mark.anyio
async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): async def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange # Arrange
model.notes_search = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
query = "How to git install application?" query = "How to git install application?"
# Act # Act
hits, entries = await text_search.query(query, model=model.notes_search, rank_results=True) hits, entries = await text_search.query(
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
)
results = text_search.collate_results(hits, entries, count=1) results = text_search.collate_results(hits, entries, count=1)
@@ -90,7 +98,7 @@ async def test_asymmetric_search(content_config: ContentConfig, search_config: S
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig): def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
# Arrange # Arrange
# Insert org-mode entry with size exceeding max token limit to new org file # Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256 max_tokens = 256
@@ -103,7 +111,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# Act # Act
# reload embeddings, entries, notes model after adding new org-mode file # reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup( initial_notes_model = text_search.setup(
OrgToJsonl, org_config_with_only_new_file, search_config.asymmetric, regenerate=False OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
) )
# Assert # Assert
@@ -113,9 +121,11 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path): def test_asymmetric_reload(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange # Arrange
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10 assert len(initial_notes_model.corpus_embeddings) == 10
@@ -127,12 +137,14 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# regenerate notes jsonl, model embeddings and model to include entry from new file # regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup( regenerated_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
) )
# Act # Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert # Assert
assert len(regenerated_notes_model.entries) == 11 assert len(regenerated_notes_model.entries) == 11
@@ -149,9 +161,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_incremental_update(content_config: ContentConfig, search_config: SearchConfig, new_org_file: Path): def test_incremental_update(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
# Arrange # Arrange
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=True) initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10 assert len(initial_notes_model.corpus_embeddings) == 10
@@ -163,7 +177,9 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# Act # Act
# update embeddings, entries with the newly added note # update embeddings, entries with the newly added note
content_config.org.input_files = [f"{new_org_file}"] content_config.org.input_files = [f"{new_org_file}"]
initial_notes_model = text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False) initial_notes_model = text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
# Assert # Assert
# verify new entry added in updated embeddings, entries # verify new entry added in updated embeddings, entries
@@ -177,10 +193,12 @@ def test_incremental_update(content_config: ContentConfig, search_config: Search
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set") @pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_asymmetric_setup_github(content_config: ContentConfig, search_config: SearchConfig): def test_asymmetric_setup_github(content_config: ContentConfig, search_models: SearchModels):
# Act # Act
# Regenerate github embeddings to test asymmetric setup without caching # Regenerate github embeddings to test asymmetric setup without caching
github_model = text_search.setup(GithubToJsonl, content_config.github, search_config.asymmetric, regenerate=True) github_model = text_search.setup(
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
)
# Assert # Assert
assert len(github_model.entries) > 1 assert len(github_model.entries) > 1