diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index 181be13f..e7d049ce 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -61,8 +61,10 @@ def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regen print(f"Saved computed embeddings to {embeddings_file}") if image_metadata_embeddings is None: - image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names], - image_metadata_embeddings = model.encode(image_metadata, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True) + image_metadata_embeddings = [] + for index in trange(0, len(image_names), batch_size): + image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]] + image_metadata_embeddings += model.encode(image_metadata, convert_to_tensor=True, batch_size=batch_size) torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata") if verbose > 0: print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")