From d5597442f40ce1dae5d03b333320e1a598d25f7f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 30 Sep 2021 02:04:04 -0700 Subject: [PATCH] Modularize Code. Wrap Search, Model Config in Classes. Add Tests Details - Rename method query_* to query in search_types for standardization - Wrapping Config code in classes simplified mocking test config - Reduce args being passed to a function by passing them as a single argument wrapped in a class - Minimize setup in main.py:__main__. Put most of it into functions. These functions can be mocked if required in tests later too Setup Flow: CLI_Args|Config_YAML -> (Text|Image)SearchConfig -> (Text|Image)SearchModel --- src/main.py | 128 ++++++++++------------------ src/search_type/asymmetric.py | 18 ++-- src/search_type/image_search.py | 47 ++++++---- src/search_type/symmetric_ledger.py | 29 ++++--- src/tests/test_main.py | 44 ++++++++-- src/utils/config.py | 89 ++++++++++++++----- 6 files changed, 201 insertions(+), 154 deletions(-) diff --git a/src/main.py b/src/main.py index c55462a1..c516c073 100644 --- a/src/main.py +++ b/src/main.py @@ -11,12 +11,12 @@ from fastapi import FastAPI from search_type import asymmetric, symmetric_ledger, image_search from utils.helpers import get_from_dict from utils.cli import cli -from utils.config import SearchType, SearchSettings, SearchModels +from utils.config import SearchType, SearchModels, TextSearchConfig, ImageSearchConfig, SearchConfig # Application Global State model = SearchModels() -search_settings = SearchSettings() +search_config = SearchConfig() app = FastAPI() @@ -29,36 +29,36 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): user_query = q results_count = n - if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled: + if (t == SearchType.Notes or t == None) and model.notes_search: # query notes - hits = asymmetric.query_notes(user_query, model.notes_search) + hits = asymmetric.query(user_query, 
model.notes_search) # collate and return results return asymmetric.collate_results(hits, model.notes_search.entries, results_count) - if (t == SearchType.Music or t == None) and search_settings.music_search_enabled: + if (t == SearchType.Music or t == None) and model.music_search: # query music library - hits = asymmetric.query_notes(user_query, model.music_search) + hits = asymmetric.query(user_query, model.music_search) # collate and return results return asymmetric.collate_results(hits, model.music_search.entries, results_count) - if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled: + if (t == SearchType.Ledger or t == None) and model.ledger_search: # query transactions - hits = symmetric_ledger.query_transactions(user_query, model.ledger_search) + hits = symmetric_ledger.query(user_query, model.ledger_search) # collate and return results return symmetric_ledger.collate_results(hits, model.ledger_search.entries, results_count) - if (t == SearchType.Image or t == None) and search_settings.image_search_enabled: + if (t == SearchType.Image or t == None) and model.image_search: # query transactions - hits = image_search.query_images(user_query, model.image_search, args.verbose) + hits = image_search.query(user_query, results_count, model.image_search) # collate and return results return image_search.collate_results( hits, model.image_search.image_names, - image_config['input-directory'], + search_config.image.input_directory, results_count) else: @@ -67,98 +67,58 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): @app.get('/regenerate') def regenerate(t: Optional[SearchType] = None): - if (t == SearchType.Notes or t == None) and search_settings.notes_search_enabled: + if (t == SearchType.Notes or t == None) and search_config.notes: # Extract Entries, Generate Embeddings - models.notes_search = asymmetric.setup( - org_config['input-files'], - org_config['input-filter'], - 
pathlib.Path(org_config['compressed-jsonl']), - pathlib.Path(org_config['embeddings-file']), - regenerate=True, - verbose=args.verbose) + model.notes_search = asymmetric.setup(search_config.notes, regenerate=True) - if (t == SearchType.Music or t == None) and search_settings.music_search_enabled: + if (t == SearchType.Music or t == None) and search_config.music: # Extract Entries, Generate Song Embeddings - model.music_search = asymmetric.setup( - song_config['input-files'], - song_config['input-filter'], - pathlib.Path(song_config['compressed-jsonl']), - pathlib.Path(song_config['embeddings-file']), - regenerate=True, - verbose=args.verbose) + model.music_search = asymmetric.setup(search_config.music, regenerate=True) - if (t == SearchType.Ledger or t == None) and search_settings.ledger_search_enabled: + if (t == SearchType.Ledger or t == None) and search_config.ledger: # Extract Entries, Generate Embeddings - model.ledger_search = symmetric_ledger.setup( - ledger_config['input-files'], - ledger_config['input-filter'], - pathlib.Path(ledger_config['compressed-jsonl']), - pathlib.Path(ledger_config['embeddings-file']), - regenerate=True, - verbose=args.verbose) + model.ledger_search = symmetric_ledger.setup(search_config.ledger, regenerate=True) - if (t == SearchType.Image or t == None) and search_settings.image_search_enabled: + if (t == SearchType.Image or t == None) and search_config.image: # Extract Images, Generate Embeddings - model.image_search = image_search.setup( - pathlib.Path(image_config['input-directory']), - pathlib.Path(image_config['embeddings-file']), - regenerate=True, - verbose=args.verbose) + model.image_search = image_search.setup(search_config.image, regenerate=True) return {'status': 'ok', 'message': 'regeneration completed'} -if __name__ == '__main__': - args = cli(sys.argv[1:]) +def initialize_search(config, regenerate, verbose): + model = SearchModels() + search_config = SearchConfig() # Initialize Org Notes Search - org_config = 
get_from_dict(args.config, 'content-type', 'org') - if org_config and ('input-files' in org_config or 'input-filter' in org_config): - search_settings.notes_search_enabled = True - model.notes_search = asymmetric.setup( - org_config['input-files'], - org_config['input-filter'], - pathlib.Path(org_config['compressed-jsonl']), - pathlib.Path(org_config['embeddings-file']), - args.regenerate, - args.verbose) + search_config.notes = TextSearchConfig.create_from_dictionary(config, ('content-type', 'org'), verbose) + if search_config.notes: + model.notes_search = asymmetric.setup(search_config.notes, regenerate=regenerate) # Initialize Org Music Search - song_config = get_from_dict(args.config, 'content-type', 'music') - music_search_enabled = False - if song_config and ('input-files' in song_config or 'input-filter' in song_config): - search_settings.music_search_enabled = True - model.music_search = asymmetric.setup( - song_config['input-files'], - song_config['input-filter'], - pathlib.Path(song_config['compressed-jsonl']), - pathlib.Path(song_config['embeddings-file']), - args.regenerate, - args.verbose) + search_config.music = TextSearchConfig.create_from_dictionary(config, ('content-type', 'music'), verbose) + if search_config.music: + model.music_search = asymmetric.setup(search_config.music, regenerate=regenerate) # Initialize Ledger Search - ledger_config = get_from_dict(args.config, 'content-type', 'ledger') - if ledger_config and ('input-files' in ledger_config or 'input-filter' in ledger_config): - search_settings.ledger_search_enabled = True - model.ledger_search = symmetric_ledger.setup( - ledger_config['input-files'], - ledger_config['input-filter'], - pathlib.Path(ledger_config['compressed-jsonl']), - pathlib.Path(ledger_config['embeddings-file']), - args.regenerate, - args.verbose) + search_config.ledger = TextSearchConfig.create_from_dictionary(config, ('content-type', 'ledger'), verbose) + if search_config.ledger: + model.ledger_search = 
symmetric_ledger.setup(search_config.ledger, regenerate=regenerate) # Initialize Image Search - image_config = get_from_dict(args.config, 'content-type', 'image') - if image_config and 'input-directory' in image_config: - search_settings.image_search_enabled = True - model.image_search = image_search.setup( - pathlib.Path(image_config['input-directory']), - pathlib.Path(image_config['embeddings-file']), - batch_size=image_config['batch-size'], - regenerate=args.regenerate, - use_xmp_metadata={'yes': True, 'no': False}[image_config['use-xmp-metadata']], - verbose=args.verbose) + search_config.image = ImageSearchConfig.create_from_dictionary(config, ('content-type', 'image'), verbose) + if search_config.image: + model.image_search = image_search.setup(search_config.image, regenerate=regenerate) + + return model, search_config + + +if __name__ == '__main__': + # Load config from CLI + args = cli(sys.argv[1:]) + + # Initialize Search from Config + model, search_config = initialize_search(args.config, args.regenerate, args.verbose) # Start Application Server uvicorn.run(app) diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py index fd5c688e..3f7785a9 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/asymmetric.py @@ -17,7 +17,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages from utils.helpers import get_absolute_path, resolve_absolute_path from processor.org_mode.org_to_jsonl import org_to_jsonl -from utils.config import AsymmetricSearchModel +from utils.config import TextSearchModel, TextSearchConfig def initialize_model(): @@ -66,7 +66,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v return corpus_embeddings -def query_notes(raw_query: str, model: AsymmetricSearchModel): +def query(raw_query: str, model: TextSearchModel): "Search all notes for entries that answer the query" # Separate natural query from explicit required, blocked words filters 
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) @@ -151,21 +151,21 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=False, verbose=False): +def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model() # Map notes in Org-Mode files to (compressed) JSONL formatted file - if not resolve_absolute_path(compressed_jsonl).exists() or regenerate: - org_to_jsonl(input_files, input_filter, compressed_jsonl, verbose) + if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: + org_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, config.verbose) # Extract Entries - entries = extract_entries(compressed_jsonl, verbose) + entries = extract_entries(config.compressed_jsonl, config.verbose) # Compute or Load Embeddings - corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose) + corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose) - return AsymmetricSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k) + return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose) if __name__ == '__main__': @@ -191,7 +191,7 @@ if __name__ == '__main__': exit(0) # query notes - hits = query_notes(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k) + hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k) # render results render_results(hits, entries, count=args.results_count) diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index 82a448bb..2091aa30 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ 
-12,6 +12,8 @@ import torch # Internal Packages from utils.helpers import get_absolute_path, resolve_absolute_path import utils.exiftool as exiftool +from utils.config import ImageSearchModel, ImageSearchConfig + def initialize_model(): # Initialize Model @@ -93,30 +95,31 @@ def extract_metadata(image_name, verbose=0): return image_processed_metadata -def query_images(query, image_embeddings, image_metadata_embeddings, model, count=3, verbose=0): +def query(raw_query, count, model: ImageSearchModel): # Set query to image content if query is a filepath - if pathlib.Path(query).is_file(): - query_imagepath = resolve_absolute_path(pathlib.Path(query), strict=True) + if pathlib.Path(raw_query).is_file(): + query_imagepath = resolve_absolute_path(pathlib.Path(raw_query), strict=True) query = copy.deepcopy(Image.open(query_imagepath)) - if verbose > 0: + if model.verbose > 0: print(f"Find Images similar to Image at {query_imagepath}") else: - if verbose > 0: + query = raw_query + if model.verbose > 0: print(f"Find Images by Text: {query}") # Now we encode the query (which can either be an image or a text string) - query_embedding = model.encode([query], convert_to_tensor=True, show_progress_bar=False) + query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. image_hits = {result['corpus_id']: result['score'] for result - in util.semantic_search(query_embedding, image_embeddings, top_k=count)[0]} + in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]} # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. 
- if image_metadata_embeddings: + if model.image_metadata_embeddings: metadata_hits = {result['corpus_id']: result['score'] for result - in util.semantic_search(query_embedding, image_metadata_embeddings, top_k=count)[0]} + in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]} # Sum metadata, image scores of the highest ranked images for corpus_id, score in metadata_hits.items(): @@ -150,20 +153,30 @@ def collate_results(hits, image_names, image_directory, count=5): in hits[0:count]] -def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0): +def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel: # Initialize Model model = initialize_model() # Extract Entries - image_directory = resolve_absolute_path(image_directory, strict=True) - image_names = extract_entries(image_directory, verbose) + image_directory = resolve_absolute_path(config.input_directory, strict=True) + image_names = extract_entries(config.input_directory, config.verbose) # Compute or Load Embeddings - embeddings_file = resolve_absolute_path(embeddings_file) - image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, - batch_size=batch_size, regenerate=regenerate, use_xmp_metadata=use_xmp_metadata, verbose=verbose) + embeddings_file = resolve_absolute_path(config.embeddings_file) + image_embeddings, image_metadata_embeddings = compute_embeddings( + image_names, + model, + embeddings_file, + batch_size=config.batch_size, + regenerate=regenerate, + use_xmp_metadata=config.use_xmp_metadata, + verbose=config.verbose) - return image_names, image_embeddings, image_metadata_embeddings, model + return ImageSearchModel(image_names, + image_embeddings, + image_metadata_embeddings, + model, + config.verbose) if __name__ == '__main__': @@ -187,7 +200,7 @@ if __name__ == '__main__': exit(0) # query images - hits = query_images(user_query, image_embeddings, 
image_metadata_embeddings, model, args.results_count, args.verbose) + hits = query(user_query, image_embeddings, image_metadata_embeddings, model, args.results_count, args.verbose) # render results render_results(hits, image_names, args.image_directory, count=args.results_count) diff --git a/src/search_type/symmetric_ledger.py b/src/search_type/symmetric_ledger.py index a5e0c04c..41fe040d 100644 --- a/src/search_type/symmetric_ledger.py +++ b/src/search_type/symmetric_ledger.py @@ -15,6 +15,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages from utils.helpers import get_absolute_path, resolve_absolute_path from processor.ledger.beancount_to_jsonl import beancount_to_jsonl +from utils.config import TextSearchModel, TextSearchConfig def initialize_model(): @@ -59,7 +60,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, v return corpus_embeddings -def query_transactions(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100): +def query(raw_query, model: TextSearchModel): "Search all notes for entries that answer the query" # Separate natural query from explicit required, blocked words filters query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) @@ -67,20 +68,20 @@ def query_transactions(raw_query, corpus_embeddings, entries, bi_encoder, cross_ blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) # Encode the query using the bi-encoder - question_embedding = bi_encoder.encode(query, convert_to_tensor=True) + question_embedding = model.bi_encoder.encode(query, convert_to_tensor=True) # Find relevant entries for the query - hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k) + hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k) hits = hits[0] # Get the hits for the first query # Filter results using 
explicit filters - hits = explicit_filter(hits, entries, required_words, blocked_words) + hits = explicit_filter(hits, model.entries, required_words, blocked_words) if hits is None or len(hits) == 0: return hits # Score all retrieved entries using the cross-encoder - cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits] - cross_scores = cross_encoder.predict(cross_inp) + cross_inp = [[query, model.entries[hit['corpus_id']]] for hit in hits] + cross_scores = model.cross_encoder.predict(cross_inp) # Store cross-encoder scores in results dictionary for ranking for idx in range(len(cross_scores)): @@ -142,21 +143,21 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=False, verbose=False): +def setup(config: TextSearchConfig, regenerate: bool) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model() # Map notes in Org-Mode files to (compressed) JSONL formatted file - if not resolve_absolute_path(compressed_jsonl).exists() or regenerate: - beancount_to_jsonl(input_files, input_filter, compressed_jsonl, verbose) + if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate: + beancount_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, config.verbose) # Extract Entries - entries = extract_entries(compressed_jsonl, verbose) + entries = extract_entries(config.compressed_jsonl, config.verbose) # Compute or Load Embeddings - corpus_embeddings = compute_embeddings(entries, bi_encoder, embeddings, regenerate=regenerate, verbose=verbose) + corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, verbose=config.verbose) - return entries, corpus_embeddings, bi_encoder, cross_encoder, top_k + return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=config.verbose) if __name__ == '__main__': @@ -181,8 +182,8 @@ if 
__name__ == '__main__': if user_query == "exit": exit(0) - # query notes - hits = query_transactions(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k) + # query + hits = query(user_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k) # render results render_results(hits, entries, count=args.results_count) diff --git a/src/tests/test_main.py b/src/tests/test_main.py index 0d2a8bee..dcda9f56 100644 --- a/src/tests/test_main.py +++ b/src/tests/test_main.py @@ -6,8 +6,9 @@ import pytest from fastapi.testclient import TestClient # Internal Packages -from main import app, search_settings, model +from main import app, search_config, model from search_type import asymmetric +from utils.config import SearchConfig, TextSearchConfig # Arrange @@ -60,14 +61,17 @@ def test_regenerate_with_valid_search_type(): # ---------------------------------------------------------------------------------------------------- def test_notes_search(): # Arrange - input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')] - input_filter = None - compressed_jsonl = Path('tests/data/.test.jsonl.gz') - embeddings = Path('tests/data/.test_embeddings.pt') + search_config = SearchConfig() + search_config.notes = TextSearchConfig( + input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')], + input_filter = None, + compressed_jsonl = Path('tests/data/.test.jsonl.gz'), + embeddings_file = Path('tests/data/.test_embeddings.pt'), + verbose = 0) # Act # Regenerate embeddings during asymmetric setup - notes_model = asymmetric.setup(input_files, input_filter, compressed_jsonl, embeddings, regenerate=True, verbose=0) + notes_model = asymmetric.setup(search_config.notes, regenerate=True) # Assert assert len(notes_model.entries) == 10 @@ -75,7 +79,6 @@ def test_notes_search(): # Arrange model.notes_search = notes_model - search_settings.notes_search_enabled = True user_query = "How to call 
semantic search from Emacs?" # Act @@ -88,3 +91,30 @@ def test_notes_search(): assert "Semantic Search via Emacs" in search_result +# ---------------------------------------------------------------------------------------------------- +def test_notes_regenerate(): + # Arrange + search_config = SearchConfig() + search_config.notes = TextSearchConfig( + input_files = [Path('tests/data/main_readme.org'), Path('tests/data/interface_emacs_readme.org')], + input_filter = None, + compressed_jsonl = Path('tests/data/.test.jsonl.gz'), + embeddings_file = Path('tests/data/.test_embeddings.pt'), + verbose = 0) + + # Act + # Regenerate embeddings during asymmetric setup + notes_model = asymmetric.setup(search_config.notes, regenerate=True) + + # Assert + assert len(notes_model.entries) == 10 + assert len(notes_model.corpus_embeddings) == 10 + + # Arrange + model.notes_search = notes_model + + # Act + response = client.get(f"/regenerate?t=notes") + + # Assert + assert response.status_code == 200 diff --git a/src/utils/config.py b/src/utils/config.py index 0f0bf960..a0dc9244 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -1,6 +1,10 @@ # System Packages from enum import Enum from dataclasses import dataclass +from pathlib import Path + +# Internal Packages +from utils.helpers import get_from_dict class SearchType(str, Enum): @@ -10,43 +14,82 @@ class SearchType(str, Enum): Image = "image" -@dataclass -class SearchSettings(): - notes_search_enabled: bool = False - ledger_search_enabled: bool = False - music_search_enabled: bool = False - image_search_enabled: bool = False - - -class AsymmetricSearchModel(): - def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k): +class TextSearchModel(): + def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose): self.entries = entries self.corpus_embeddings = corpus_embeddings self.bi_encoder = bi_encoder self.cross_encoder = cross_encoder self.top_k = top_k - - -class 
LedgerSearchModel(): - def __init__(self, transactions, transaction_embeddings, symmetric_encoder, symmetric_cross_encoder, top_k): - self.transactions = transactions - self.transaction_embeddings = transaction_embeddings - self.symmetric_encoder = symmetric_encoder - self.symmetric_cross_encoder = symmetric_cross_encoder - self.top_k = top_k + self.verbose = verbose class ImageSearchModel(): - def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder): + def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder, verbose): + self.image_encoder = image_encoder self.image_names = image_names self.image_embeddings = image_embeddings self.image_metadata_embeddings = image_metadata_embeddings self.image_encoder = image_encoder + self.verbose = verbose @dataclass class SearchModels(): - notes_search: AsymmetricSearchModel = None - ledger_search: LedgerSearchModel = None - music_search: AsymmetricSearchModel = None + notes_search: TextSearchModel = None + ledger_search: TextSearchModel = None + music_search: TextSearchModel = None image_search: ImageSearchModel = None + + +class TextSearchConfig(): + def __init__(self, input_files, input_filter, compressed_jsonl, embeddings_file, verbose): + self.input_files = input_files + self.input_filter = input_filter + self.compressed_jsonl = Path(compressed_jsonl) + self.embeddings_file = Path(embeddings_file) + self.verbose = verbose + + + def create_from_dictionary(config, key_tree, verbose): + text_config = get_from_dict(config, *key_tree) + search_enabled = text_config and ('input-files' in text_config or 'input-filter' in text_config) + if not search_enabled: + return None + + return TextSearchConfig( + input_files = text_config['input-files'], + input_filter = text_config['input-filter'], + compressed_jsonl = Path(text_config['compressed-jsonl']), + embeddings_file = Path(text_config['embeddings-file']), + verbose = verbose) + + +class ImageSearchConfig(): + 
def __init__(self, input_directory, embeddings_file, batch_size, use_xmp_metadata, verbose): + self.input_directory = input_directory + self.embeddings_file = Path(embeddings_file) + self.batch_size = batch_size + self.use_xmp_metadata = use_xmp_metadata + self.verbose = verbose + + def create_from_dictionary(config, key_tree, verbose): + image_config = get_from_dict(config, *key_tree) + search_enabled = image_config and 'input-directory' in image_config + if not search_enabled: + return None + + return ImageSearchConfig( + input_directory = Path(image_config['input-directory']), + embeddings_file = Path(image_config['embeddings-file']), + batch_size = image_config['batch-size'], + use_xmp_metadata = {'yes': True, 'no': False}[image_config['use-xmp-metadata']], + verbose = verbose) + + +@dataclass +class SearchConfig(): + notes: TextSearchConfig = None + ledger: TextSearchConfig = None + music: TextSearchConfig = None + image: ImageSearchConfig = None