Merge pull request #18 from debanjum/deb/save-models-to-disk-on-first-run

Save Search Models to Disk on First Run

## Why
  - Improve application startup time
  - Start the application and perform semantic search even when the user is offline
  - Use search model config in YAML file for all search types (asymmetric, symmetric, image)

## Details
  - Load search models from disk when available
  - Use search model config specified in YAML file
  - Add search config for Symmetric Search used by Ledger/Beancount transaction search
This commit is contained in:
Debanjum
2022-01-14 17:30:46 -08:00
committed by GitHub
12 changed files with 183 additions and 75 deletions

View File

@@ -24,12 +24,19 @@ content-type:
embeddings-file: "tests/data/.song_embeddings.pt" embeddings-file: "tests/data/.song_embeddings.pt"
search-type: search-type:
symmetric:
encoder: "sentence-transformers/paraphrase-MiniLM-L6-v2"
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
model_directory: "tests/data/.symmetric"
asymmetric: asymmetric:
encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3" encoder: "sentence-transformers/msmarco-MiniLM-L-6-v3"
cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2" cross-encoder: "cross-encoder/ms-marco-MiniLM-L-6-v2"
model_directory: "tests/data/.asymmetric"
image: image:
encoder: "clip-ViT-B-32" encoder: "clip-ViT-B-32"
model_directory: "tests/data/.image_encoder"
processor: processor:
conversation: conversation:

View File

@@ -130,22 +130,22 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
# Initialize Org Notes Search # Initialize Org Notes Search
if (t == SearchType.Notes or t == None) and config.content_type.org: if (t == SearchType.Notes or t == None) and config.content_type.org:
# Extract Entries, Generate Notes Embeddings # Extract Entries, Generate Notes Embeddings
model.notes_search = asymmetric.setup(config.content_type.org, regenerate=regenerate, verbose=verbose) model.notes_search = asymmetric.setup(config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
# Initialize Org Music Search # Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music: if (t == SearchType.Music or t == None) and config.content_type.music:
# Extract Entries, Generate Music Embeddings # Extract Entries, Generate Music Embeddings
model.music_search = asymmetric.setup(config.content_type.music, regenerate=regenerate, verbose=verbose) model.music_search = asymmetric.setup(config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate, verbose=verbose)
# Initialize Ledger Search # Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger: if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings # Extract Entries, Generate Ledger Embeddings
model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, regenerate=regenerate, verbose=verbose) model.ledger_search = symmetric_ledger.setup(config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, verbose=verbose)
# Initialize Image Search # Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image: if (t == SearchType.Image or t == None) and config.content_type.image:
# Extract Entries, Generate Image Embeddings # Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup(config.content_type.image, regenerate=regenerate, verbose=verbose) model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate, verbose=verbose)
return model return model

View File

@@ -12,18 +12,31 @@ import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages # Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.processor.org_mode.org_to_jsonl import org_to_jsonl from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.utils.config import TextSearchModel from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfig from src.utils.rawconfig import AsymmetricConfig, TextSearchConfig
def initialize_model(): def initialize_model(search_config: AsymmetricConfig):
"Initialize model for assymetric semantic search. That is, where query smaller than results" "Initialize model for assymetric semantic search. That is, where query smaller than results"
torch.set_num_threads(4) torch.set_num_threads(4)
bi_encoder = SentenceTransformer('sentence-transformers/msmarco-MiniLM-L-6-v3') # The bi-encoder encodes all entries to use for semantic search
top_k = 30 # Number of entries we want to retrieve with the bi-encoder # Number of entries we want to retrieve with the bi-encoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality top_k = 30
# The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = SentenceTransformer)
# The cross-encoder re-ranks the results to improve quality
cross_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.cross_encoder,
model_type = CrossEncoder)
return bi_encoder, cross_encoder, top_k return bi_encoder, cross_encoder, top_k
@@ -149,9 +162,9 @@ def collate_results(hits, entries, count=5):
in hits[0:count]] in hits[0:count]]
def setup(config: TextSearchConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel: def setup(config: TextSearchConfig, search_config: AsymmetricConfig, regenerate: bool, verbose: bool=False) -> TextSearchModel:
# Initialize Model # Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model() bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in Org-Mode files to (compressed) JSONL formatted file # Map notes in Org-Mode files to (compressed) JSONL formatted file
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:

View File

@@ -10,16 +10,22 @@ from tqdm import trange
import torch import torch
# Internal Packages # Internal Packages
from src.utils.helpers import resolve_absolute_path from src.utils.helpers import resolve_absolute_path, load_model
import src.utils.exiftool as exiftool import src.utils.exiftool as exiftool
from src.utils.config import ImageSearchModel from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageSearchConfig from src.utils.rawconfig import ImageSearchConfig, ImageSearchTypeConfig
def initialize_model(): def initialize_model(search_config: ImageSearchTypeConfig):
# Initialize Model # Initialize Model
torch.set_num_threads(4) torch.set_num_threads(4)
encoder = SentenceTransformer('sentence-transformers/clip-ViT-B-32') #Load the CLIP model
# Load the CLIP model
encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = SentenceTransformer)
return encoder return encoder
@@ -154,9 +160,9 @@ def collate_results(hits, image_names, image_directory, count=5):
in hits[0:count]] in hits[0:count]]
def setup(config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel: def setup(config: ImageSearchConfig, search_config: ImageSearchTypeConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel:
# Initialize Model # Initialize Model
encoder = initialize_model() encoder = initialize_model(search_config)
# Extract Entries # Extract Entries
image_directory = resolve_absolute_path(config.input_directory, strict=True) image_directory = resolve_absolute_path(config.input_directory, strict=True)

View File

@@ -10,18 +10,31 @@ import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages # Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl from src.processor.ledger.beancount_to_jsonl import beancount_to_jsonl
from src.utils.config import TextSearchModel from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfig from src.utils.rawconfig import SymmetricConfig, TextSearchConfig
def initialize_model(): def initialize_model(search_config: SymmetricConfig):
"Initialize model for symmetric semantic search. That is, where query of similar size to results" "Initialize model for symmetric semantic search. That is, where query of similar size to results"
torch.set_num_threads(4) torch.set_num_threads(4)
bi_encoder = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2') # The encoder encodes all entries to use for semantic search
top_k = 30 # Number of entries we want to retrieve with the bi-encoder # Number of entries we want to retrieve with the bi-encoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # The cross-encoder re-ranks the results to improve quality top_k = 30
# The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = SentenceTransformer)
# The cross-encoder re-ranks the results to improve quality
cross_encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.cross_encoder,
model_type = CrossEncoder)
return bi_encoder, cross_encoder, top_k return bi_encoder, cross_encoder, top_k
@@ -141,9 +154,9 @@ def collate_results(hits, entries, count=5):
in hits[0:count]] in hits[0:count]]
def setup(config: TextSearchConfig, regenerate: bool, verbose: bool) -> TextSearchModel: def setup(config: TextSearchConfig, search_config: SymmetricConfig, regenerate: bool, verbose: bool) -> TextSearchModel:
# Initialize Model # Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model() bi_encoder, cross_encoder, top_k = initialize_model(search_config)
# Map notes in Org-Mode files to (compressed) JSONL formatted file # Map notes in Org-Mode files to (compressed) JSONL formatted file
if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:

View File

@@ -77,14 +77,22 @@ default_config = {
}, },
'search-type': 'search-type':
{ {
'symmetric':
{
'encoder': "sentence-transformers/paraphrase-MiniLM-L6-v2",
'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2",
'model_directory': None
},
'asymmetric': 'asymmetric':
{ {
'encoder': "sentence-transformers/msmarco-MiniLM-L-6-v3", 'encoder': "sentence-transformers/msmarco-MiniLM-L-6-v3",
'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2" 'cross-encoder': "cross-encoder/ms-marco-MiniLM-L-6-v2",
'model_directory': None
}, },
'image': 'image':
{ {
'encoder': "clip-ViT-B-32" 'encoder': "clip-ViT-B-32",
'model_directory': None
}, },
}, },
'processor': 'processor':

View File

@@ -1,4 +1,6 @@
# Standard Packages
import pathlib import pathlib
from os.path import join
def is_none_or_empty(item): def is_none_or_empty(item):
@@ -32,3 +34,20 @@ def merge_dicts(priority_dict, default_dict):
if k not in priority_dict: if k not in priority_dict:
merged_dict[k] = default_dict[k] merged_dict[k] = default_dict[k]
return merged_dict return merged_dict
def load_model(model_name, model_dir, model_type):
    """Load a model from disk when a saved copy exists, else load it by name.

    Args:
        model_name: Identifier passed to ``model_type`` when no saved copy is
            found (e.g. a huggingface model name). Every "/" in it is replaced
            with "_" to form the folder name under ``model_dir``.
        model_dir: Directory that holds saved models, or None to skip both the
            disk lookup and the save.
        model_type: Callable that builds the model from a path or a name. When
            ``model_dir`` is set, the object it returns must provide
            ``save(path)``.

    Returns:
        The model built by ``model_type``.
    """
    # No model directory configured: build from the name, nothing to save
    if model_dir is None:
        return model_type(model_name)

    # Construct model path
    model_path = join(model_dir, model_name.replace("/", "_"))

    # Load model from model_path if it exists there
    if resolve_absolute_path(model_path).exists():
        return model_type(get_absolute_path(model_path))

    # Else load the model from the model_name and save it for the next run.
    # Save to the same normalized path the load above reads from, so the
    # saved copy is always found by the next lookup.
    model = model_type(model_name)
    model.save(get_absolute_path(model_path))
    return model

View File

@@ -37,15 +37,23 @@ class ContentTypeConfig(ConfigBase):
image: Optional[ImageSearchConfig] image: Optional[ImageSearchConfig]
music: Optional[TextSearchConfig] music: Optional[TextSearchConfig]
class SymmetricConfig(ConfigBase):
encoder: Optional[str]
cross_encoder: Optional[str]
model_directory: Optional[Path]
class AsymmetricConfig(ConfigBase): class AsymmetricConfig(ConfigBase):
encoder: Optional[str] encoder: Optional[str]
cross_encoder: Optional[str] cross_encoder: Optional[str]
model_directory: Optional[Path]
class ImageSearchTypeConfig(ConfigBase): class ImageSearchTypeConfig(ConfigBase):
encoder: Optional[str] encoder: Optional[str]
model_directory: Optional[Path]
class SearchTypeConfig(ConfigBase): class SearchTypeConfig(ConfigBase):
asymmetric: Optional[AsymmetricConfig] asymmetric: Optional[AsymmetricConfig]
symmetric: Optional[SymmetricConfig]
image: Optional[ImageSearchTypeConfig] image: Optional[ImageSearchTypeConfig]
class ConversationProcessorConfig(ConfigBase): class ConversationProcessorConfig(ConfigBase):

View File

@@ -1,51 +1,78 @@
# Standard Packages # Standard Packages
import pytest import pytest
from pathlib import Path from pathlib import Path
from src import search_type
# Internal Packages # Internal Packages
from src.search_type import asymmetric, image_search from src.search_type import asymmetric, image_search
from src.utils.rawconfig import ContentTypeConfig, ImageSearchConfig, TextSearchConfig from src.utils.rawconfig import AsymmetricConfig, ContentTypeConfig, ImageSearchConfig, ImageSearchTypeConfig, SearchTypeConfig, SymmetricConfig, TextSearchConfig
@pytest.fixture(scope='session')
def search_config(tmp_path_factory):
    """Build the search-type config shared by the test session.

    All three search types point their model_directory at one temporary
    directory created through tmp_path_factory.
    """
    model_dir = tmp_path_factory.mktemp('data')

    search_config = SearchTypeConfig()

    # Assign to `.symmetric`: assigning this to `.asymmetric` would be
    # overwritten by the asymmetric config below and leave `.symmetric` unset
    search_config.symmetric = SymmetricConfig(
        encoder = "sentence-transformers/paraphrase-MiniLM-L6-v2",
        cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        model_directory = model_dir
    )

    search_config.asymmetric = AsymmetricConfig(
        encoder = "sentence-transformers/msmarco-MiniLM-L-6-v3",
        cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        model_directory = model_dir
    )

    search_config.image = ImageSearchTypeConfig(
        encoder = "clip-ViT-B-32",
        model_directory = model_dir
    )

    return search_config
@pytest.fixture(scope='session')
def model_dir(search_config):
model_dir = search_config.asymmetric.model_directory
# Generate Image Embeddings from Test Images # Generate Image Embeddings from Test Images
search_config = ContentTypeConfig() content_config = ContentTypeConfig()
search_config.image = ImageSearchConfig( content_config.image = ImageSearchConfig(
input_directory = 'tests/data', input_directory = 'tests/data',
embeddings_file = model_dir.joinpath('.image_embeddings.pt'), embeddings_file = model_dir.joinpath('.image_embeddings.pt'),
batch_size = 10, batch_size = 10,
use_xmp_metadata = False) use_xmp_metadata = False)
image_search.setup(search_config.image, regenerate=False, verbose=True) image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True)
# Generate Notes Embeddings from Test Notes # Generate Notes Embeddings from Test Notes
search_config.org = TextSearchConfig( content_config.org = TextSearchConfig(
input_files = ['tests/data/main_readme.org', 'tests/data/interface_emacs_readme.org'], input_files = ['tests/data/main_readme.org', 'tests/data/interface_emacs_readme.org'],
input_filter = None, input_filter = None,
compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'), compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'),
embeddings_file = model_dir.joinpath('.note_embeddings.pt')) embeddings_file = model_dir.joinpath('.note_embeddings.pt'))
asymmetric.setup(search_config.org, regenerate=False, verbose=True) asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False, verbose=True)
return model_dir return model_dir
@pytest.fixture(scope='session') @pytest.fixture(scope='session')
def search_config(model_dir): def content_config(model_dir):
search_config = ContentTypeConfig() content_config = ContentTypeConfig()
search_config.org = TextSearchConfig( content_config.org = TextSearchConfig(
input_files = ['tests/data/main_readme.org', 'tests/data/interface_emacs_readme.org'], input_files = ['tests/data/main_readme.org', 'tests/data/interface_emacs_readme.org'],
input_filter = None, input_filter = None,
compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'), compressed_jsonl = model_dir.joinpath('.notes.jsonl.gz'),
embeddings_file = model_dir.joinpath('.note_embeddings.pt')) embeddings_file = model_dir.joinpath('.note_embeddings.pt'))
search_config.image = ImageSearchConfig( content_config.image = ImageSearchConfig(
input_directory = 'tests/data', input_directory = 'tests/data',
embeddings_file = 'tests/data/.image_embeddings.pt', embeddings_file = model_dir.joinpath('.image_embeddings.pt'),
batch_size = 10, batch_size = 10,
use_xmp_metadata = False) use_xmp_metadata = False)
return search_config return content_config

View File

@@ -1,14 +1,15 @@
# Internal Packages # Internal Packages
from src.main import model from src.main import model
from src.search_type import asymmetric from src.search_type import asymmetric
from src.utils.rawconfig import ContentTypeConfig, SearchTypeConfig
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_setup(search_config): def test_asymmetric_setup(content_config: ContentTypeConfig, search_config: SearchTypeConfig):
# Act # Act
# Regenerate notes embeddings during asymmetric setup # Regenerate notes embeddings during asymmetric setup
notes_model = asymmetric.setup(search_config.org, regenerate=True) notes_model = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=True)
# Assert # Assert
assert len(notes_model.entries) == 10 assert len(notes_model.entries) == 10
@@ -16,9 +17,9 @@ def test_asymmetric_setup(search_config):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_asymmetric_search(search_config): def test_asymmetric_search(content_config: ContentTypeConfig, search_config: SearchTypeConfig):
# Arrange # Arrange
model.notes_search = asymmetric.setup(search_config.org, regenerate=False) model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
query = "How to git install application?" query = "How to git install application?"
# Act # Act

View File

@@ -9,7 +9,7 @@ import pytest
from src.main import app, model, config from src.main import app, model, config
from src.search_type import asymmetric, image_search from src.search_type import asymmetric, image_search
from src.utils.helpers import resolve_absolute_path from src.utils.helpers import resolve_absolute_path
from src.utils.rawconfig import ContentTypeConfig from src.utils.rawconfig import ContentTypeConfig, SearchTypeConfig
# Arrange # Arrange
@@ -18,55 +18,60 @@ client = TestClient(app)
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_search_with_invalid_search_type(): def test_search_with_invalid_content_type():
# Arrange # Arrange
user_query = "How to call semantic search from Emacs?" user_query = "How to call semantic search from Emacs?"
# Act # Act
response = client.get(f"/search?q={user_query}&t=invalid_search_type") response = client.get(f"/search?q={user_query}&t=invalid_content_type")
# Assert # Assert
assert response.status_code == 422 assert response.status_code == 422
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_search_with_valid_search_type(search_config: ContentTypeConfig): def test_search_with_valid_content_type(content_config: ContentTypeConfig, search_config: SearchTypeConfig):
# Arrange # Arrange
config.content_type = search_config config.content_type = content_config
config.search_type = search_config
# config.content_type.image = search_config.image # config.content_type.image = search_config.image
for search_type in ["notes", "ledger", "music", "image"]: for content_type in ["notes", "ledger", "music", "image"]:
# Act # Act
response = client.get(f"/search?q=random&t={search_type}") response = client.get(f"/search?q=random&t={content_type}")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_regenerate_with_invalid_search_type(): def test_regenerate_with_invalid_content_type():
# Act # Act
response = client.get(f"/regenerate?t=invalid_search_type") response = client.get(f"/regenerate?t=invalid_content_type")
# Assert # Assert
assert response.status_code == 422 assert response.status_code == 422
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_regenerate_with_valid_search_type(search_config: ContentTypeConfig): def test_regenerate_with_valid_content_type(content_config: ContentTypeConfig, search_config: SearchTypeConfig):
# Arrange # Arrange
config.content_type = search_config config.content_type = content_config
for search_type in ["notes", "ledger", "music", "image"]: config.search_type = search_config
for content_type in ["notes", "ledger", "music", "image"]:
# Act # Act
response = client.get(f"/regenerate?t={search_type}") response = client.get(f"/regenerate?t={content_type}")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skip(reason="Flaky test. Search doesn't always return expected image path.") @pytest.mark.skip(reason="Flaky test. Search doesn't always return expected image path.")
def test_image_search(search_config: ContentTypeConfig): def test_image_search(content_config: ContentTypeConfig, search_config: SearchTypeConfig):
# Arrange # Arrange
config.content_type = search_config config.content_type = content_config
model.image_search = image_search.setup(search_config.image, regenerate=False) config.search_type = search_config
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
query_expected_image_pairs = [("brown kitten next to fallen plant", "kitten_park.jpg"), query_expected_image_pairs = [("brown kitten next to fallen plant", "kitten_park.jpg"),
("a horse and dog on a leash", "horse_dog.jpg"), ("a horse and dog on a leash", "horse_dog.jpg"),
("A guinea pig eating grass", "guineapig_grass.jpg")] ("A guinea pig eating grass", "guineapig_grass.jpg")]
@@ -78,16 +83,16 @@ def test_image_search(search_config: ContentTypeConfig):
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
actual_image = Path(response.json()[0]["Entry"]) actual_image = Path(response.json()[0]["Entry"])
expected_image = resolve_absolute_path(search_config.image.input_directory.joinpath(expected_image_name)) expected_image = resolve_absolute_path(content_config.image.input_directory.joinpath(expected_image_name))
# Assert # Assert
assert expected_image == actual_image assert expected_image == actual_image
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_notes_search(search_config: ContentTypeConfig): def test_notes_search(content_config: ContentTypeConfig, search_config: SearchTypeConfig):
# Arrange # Arrange
model.notes_search = asymmetric.setup(search_config.org, regenerate=False) model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
user_query = "How to git install application?" user_query = "How to git install application?"
# Act # Act
@@ -101,9 +106,9 @@ def test_notes_search(search_config: ContentTypeConfig):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter(search_config: ContentTypeConfig): def test_notes_search_with_include_filter(content_config: ContentTypeConfig, search_config: SearchTypeConfig):
# Arrange # Arrange
model.notes_search = asymmetric.setup(search_config.org, regenerate=False) model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
user_query = "How to git install application? +Emacs" user_query = "How to git install application? +Emacs"
# Act # Act
@@ -117,9 +122,9 @@ def test_notes_search_with_include_filter(search_config: ContentTypeConfig):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_notes_search_with_exclude_filter(search_config: ContentTypeConfig): def test_notes_search_with_exclude_filter(content_config: ContentTypeConfig, search_config: SearchTypeConfig):
# Arrange # Arrange
model.notes_search = asymmetric.setup(search_config.org, regenerate=False) model.notes_search = asymmetric.setup(content_config.org, search_config.asymmetric, regenerate=False)
user_query = "How to git install application? -clone" user_query = "How to git install application? -clone"
# Act # Act

View File

@@ -5,14 +5,15 @@ import pytest
from src.main import model from src.main import model
from src.search_type import image_search from src.search_type import image_search
from src.utils.helpers import resolve_absolute_path from src.utils.helpers import resolve_absolute_path
from src.utils.rawconfig import ContentTypeConfig, SearchTypeConfig
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
def test_image_search_setup(search_config): def test_image_search_setup(content_config: ContentTypeConfig, search_config: SearchTypeConfig):
# Act # Act
# Regenerate image search embeddings during image setup # Regenerate image search embeddings during image setup
image_search_model = image_search.setup(search_config.image, regenerate=True) image_search_model = image_search.setup(content_config.image, search_config.image, regenerate=True)
# Assert # Assert
assert len(image_search_model.image_names) == 3 assert len(image_search_model.image_names) == 3
@@ -21,9 +22,9 @@ def test_image_search_setup(search_config):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skip(reason="results inconsistent currently") @pytest.mark.skip(reason="results inconsistent currently")
def test_image_search(search_config): def test_image_search(content_config: ContentTypeConfig, search_config: SearchTypeConfig):
# Arrange # Arrange
model.image_search = image_search.setup(search_config.image, regenerate=False) model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
query_expected_image_pairs = [("brown kitten next to plant", "kitten_park.jpg"), query_expected_image_pairs = [("brown kitten next to plant", "kitten_park.jpg"),
("horse and dog in a farm", "horse_dog.jpg"), ("horse and dog in a farm", "horse_dog.jpg"),
("A guinea pig eating grass", "guineapig_grass.jpg")] ("A guinea pig eating grass", "guineapig_grass.jpg")]
@@ -38,11 +39,11 @@ def test_image_search(search_config):
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,
model.image_search.image_names, model.image_search.image_names,
search_config.image.input_directory, content_config.image.input_directory,
count=1) count=1)
actual_image = results[0]["Entry"] actual_image = results[0]["Entry"]
expected_image = resolve_absolute_path(search_config.image.input_directory.joinpath(expected_image_name)) expected_image = resolve_absolute_path(content_config.image.input_directory.joinpath(expected_image_name))
# Assert # Assert
assert expected_image == actual_image assert expected_image == actual_image