Update Tests to setup both content_index, search_models before testing

This is required by the updated structure of Khoj setup

- Add content_config pytest fixture, pass bi_encoder from
  search_models.[text|image]_search
This commit is contained in:
Debanjum Singh Solanky
2023-07-14 01:19:38 -07:00
parent 86e2bec9a0
commit b9fb656657
4 changed files with 126 additions and 48 deletions

View File

@@ -10,6 +10,7 @@ from khoj.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils.config import ImageContent, SearchModels, TextContent
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import (
ContentConfig,
@@ -41,35 +42,49 @@ def search_config() -> SearchConfig:
encoder="sentence-transformers/all-MiniLM-L6-v2",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "symmetric/",
encoder_type=None,
)
search_config.asymmetric = TextSearchConfig(
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "asymmetric/",
encoder_type=None,
)
search_config.image = ImageSearchConfig(
encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/"
encoder="sentence-transformers/clip-ViT-B-32",
model_directory=model_dir / "image/",
encoder_type=None,
)
return search_config
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_config: SearchConfig):
def search_models(search_config: SearchConfig):
search_models = SearchModels()
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images
content_config = ContentConfig()
content_config.image = ImageContentConfig(
input_filter=None,
input_directories=["tests/data/images"],
embeddings_file=content_dir.joinpath("image_embeddings.pt"),
batch_size=1,
use_xmp_metadata=False,
)
image_search.setup(content_config.image, search_config.image, regenerate=False)
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
# Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig(
@@ -80,7 +95,9 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(OrgToJsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
text_search.setup(
OrgToJsonl, content_config.org, search_models.text_search.bi_encoder, regenerate=False, filters=filters
)
content_config.plugins = {
"plugin1": TextContentConfig(
@@ -106,7 +123,11 @@ def content_config(tmp_path_factory, search_config: SearchConfig):
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
JsonlToJsonl, content_config.plugins["plugin1"], search_config.asymmetric, regenerate=False, filters=filters
JsonlToJsonl,
content_config.plugins["plugin1"],
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
return content_config
@@ -157,8 +178,13 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
# Index Markdown Content for Search
filters = [DateFilter(), WordFilter(), FileFilter()]
state.model.markdown_search = text_search.setup(
MarkdownToJsonl, md_content_config.markdown, search_config.asymmetric, regenerate=False, filters=filters
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.content_index.markdown = text_search.setup(
MarkdownToJsonl,
md_content_config.markdown,
state.search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
# Initialize Processor from Config
@@ -175,8 +201,14 @@ def client(content_config: ContentConfig, search_config: SearchConfig, processor
state.SearchType = configure_search_types(state.config)
# These lines help us Mock the Search models for these search types
state.model.org_search = {}
state.model.image_search = {}
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
OrgToJsonl, content_config.org, state.search_models.text_search.bi_encoder, regenerate=False
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
configure_routes(app)
return TestClient(app)