diff --git a/src/main.py b/src/main.py index 6fa89d6e..75f793a8 100644 --- a/src/main.py +++ b/src/main.py @@ -293,14 +293,14 @@ def run(): global config_file config_file = args.config_file - # Store the verbose flag - global verbose - verbose = args.verbose - # Store the raw config data. global config config = args.config + # Store the verbose flag + global verbose + verbose = args.verbose + # Set device to GPU if available device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index 8d10a51f..f6187e54 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -1,5 +1,6 @@ # Standard Packages import argparse +import glob import pathlib import copy import shutil @@ -11,7 +12,7 @@ from tqdm import trange import torch # Internal Packages -from src.utils.helpers import resolve_absolute_path, load_model +from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model import src.utils.exiftool as exiftool from src.utils.config import ImageSearchModel from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig @@ -213,13 +214,19 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera encoder = initialize_model(search_config) # Extract Entries - image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories] - image_names = extract_entries(image_directories, verbose) + absolute_image_files, filtered_image_files = set(), set() + if config.input_directories: + image_directories = [resolve_absolute_path(directory, strict=True) for directory in config.input_directories] + absolute_image_files = set(extract_entries(image_directories, verbose)) + if config.input_filter: + filtered_image_files = set(glob.glob(get_absolute_path(config.input_filter))) + + all_image_files = sorted(list(absolute_image_files | filtered_image_files)) # Compute or Load Embeddings embeddings_file = resolve_absolute_path(config.embeddings_file) image_embeddings, image_metadata_embeddings = compute_embeddings( - image_names, + all_image_files, encoder, embeddings_file, batch_size=config.batch_size, @@ -227,7 +234,7 @@ def setup(config: ImageContentConfig, search_config: ImageSearchConfig, regenera use_xmp_metadata=config.use_xmp_metadata, verbose=verbose) - return ImageSearchModel(image_names, + return ImageSearchModel(all_image_files, image_embeddings, image_metadata_embeddings, encoder,