diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py
index fe936559..8c1eb852 100644
--- a/src/search_type/image_search.py
+++ b/src/search_type/image_search.py
@@ -6,6 +6,7 @@ import copy
 # External Packages
 from sentence_transformers import SentenceTransformer, util
 from PIL import Image
+from tqdm import trange
 import torch
 
 # Internal Packages
@@ -50,23 +51,20 @@ def compute_embeddings(image_names, model, embeddings_file, regenerate=False, ve
     if verbose > 0:
         print(f"Loading the {len(image_names)} images into memory")
 
+    batch_size = 50
     if image_embeddings is None:
-        image_embeddings = model.encode(
-            [Image.open(image_name).copy() for image_name in image_names],
-            batch_size=128, convert_to_tensor=True, show_progress_bar=verbose > 0)
-
+        image_embeddings = []
+        for index in trange(0, len(image_names), batch_size):
+            images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
+            image_embeddings += model.encode(images, convert_to_tensor=True, batch_size=batch_size)
         torch.save(image_embeddings, embeddings_file)
-
         if verbose > 0:
             print(f"Saved computed embeddings to {embeddings_file}")
 
     if image_metadata_embeddings is None:
-        image_metadata_embeddings = model.encode(
-            [extract_metadata(image_name, verbose) for image_name in image_names],
-            batch_size=128, convert_to_tensor=True, show_progress_bar=verbose > 0)
-
+        image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names]
+        image_metadata_embeddings = model.encode(image_metadata, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=verbose > 0)
         torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
-
         if verbose > 0:
             print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
 