From 6c9ffdba57021d548db700f593d722e529dbab0c Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Wed, 20 Jul 2022 02:54:03 +0400
Subject: [PATCH] Allow indexing multiple image directories for image search

---
Review notes (ignored by git-am):
- extract_entries accumulates matches from every directory with
  image_names.extend(); assigning to image_names inside the loop would keep
  only the images found in the last directory.
- The post-image blob id on the image_search.py "index" line was computed
  before the line above was amended; regenerate the patch to refresh it.

 config/sample_config.yml        |  2 +-
 src/search_type/image_search.py | 14 ++++++++------
 src/utils/rawconfig.py          |  2 +-
 tests/conftest.py               |  4 ++--
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/config/sample_config.yml b/config/sample_config.yml
index 5c8a9f71..7f5809c1 100644
--- a/config/sample_config.yml
+++ b/config/sample_config.yml
@@ -15,7 +15,7 @@ content-type:
     embeddings-file: /data/embeddings/transaction_embeddings.pt
 
   image:
-    input-directory: "/data/images/"
+    input-directories: ["/data/images/"]
     embeddings-file: "/data/embeddings/image_embeddings.pt"
     batch-size: 50
     use-xmp-metadata: true
diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py
index d97754c3..adb33a10 100644
--- a/src/search_type/image_search.py
+++ b/src/search_type/image_search.py
@@ -30,10 +30,12 @@ def initialize_model(search_config: ImageSearchConfig):
     return encoder
 
 
-def extract_entries(image_directory, verbose=0):
-    image_directory = resolve_absolute_path(image_directory, strict=True)
-    image_names = list(image_directory.glob('*.jpg'))
-    image_names.extend(list(image_directory.glob('*.jpeg')))
+def extract_entries(image_directories, verbose=0):
+    image_names = []
+    for image_directory in image_directories:
+        image_directory = resolve_absolute_path(image_directory, strict=True)
+        image_names.extend(list(image_directory.glob('*.jpg')))
+        image_names.extend(list(image_directory.glob('*.jpeg')))
 
     if verbose > 0:
         print(f'Found {len(image_names)} images in {image_directory}')
@@ -197,8 +199,8 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera
     encoder = initialize_model(search_config)
 
     # Extract Entries
-    image_directory = resolve_absolute_path(config.input_directory, strict=True)
-    image_names = extract_entries(image_directory, verbose)
+    image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories]
+    image_names = extract_entries(image_directories, verbose)
 
     # Compute or Load Embeddings
     embeddings_file = resolve_absolute_path(config.embeddings_file)
diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py
index 4a88eb4e..4a8749a8 100644
--- a/src/utils/rawconfig.py
+++ b/src/utils/rawconfig.py
@@ -22,7 +22,7 @@ class TextContentConfig(ConfigBase):
 class ImageContentConfig(ConfigBase):
     use_xmp_metadata: Optional[bool]
     batch_size: Optional[int]
-    input_directory: Optional[Path]
+    input_directories: Optional[List[Path]]
     input_filter: Optional[str]
     embeddings_file: Optional[Path]
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 34da236e..2471f8e0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -41,7 +41,7 @@ def model_dir(search_config):
     # Generate Image Embeddings from Test Images
     content_config = ContentConfig()
     content_config.image = ImageContentConfig(
-        input_directory = 'tests/data/images',
+        input_directories = ['tests/data/images'],
         embeddings_file = model_dir.joinpath('image_embeddings.pt'),
         batch_size = 10,
         use_xmp_metadata = False)
@@ -70,7 +70,7 @@ def content_config(model_dir):
         embeddings_file = model_dir.joinpath('note_embeddings.pt'))
 
     content_config.image = ImageContentConfig(
-        input_directory = 'tests/data/images',
+        input_directories = ['tests/data/images'],
         embeddings_file = model_dir.joinpath('image_embeddings.pt'),
         batch_size = 10,
         use_xmp_metadata = False)