From bdc1b9f2bb2df615dfa178c86de6a39cba7380d3 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 20 Jul 2022 02:58:43 +0400 Subject: [PATCH] Resolve edge case errors in encoding image metadata - Handle case where current image batch smaller than batch_size - Handle case where no XMP metadata for current image - return empty strings in such a scenario instead of ". " --- src/search_type/image_search.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index 896b9118..6efcea7e 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -65,7 +65,10 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5 image_embeddings = [] for index in trange(0, len(image_names), batch_size): images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]] - image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=batch_size) + image_embeddings += encoder.encode( + images, + convert_to_tensor=True, + batch_size=min(len(images), batch_size)) torch.save(image_embeddings, embeddings_file) if verbose > 0: print(f"Saved computed embeddings to {embeddings_file}") @@ -88,7 +91,10 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz for index in trange(0, len(image_names), batch_size): image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]] try: - image_metadata_embeddings += encoder.encode(image_metadata, convert_to_tensor=True, batch_size=batch_size) + image_metadata_embeddings += encoder.encode( + image_metadata, + convert_to_tensor=True, + batch_size=min(len(image_metadata), batch_size)) except RuntimeError as e: print(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}") continue @@ -103,9 +109,14 @@ def extract_metadata(image_name, verbose=0): with exiftool.ExifTool() as et: image_metadata = et.get_tags(["XMP:Subject", "XMP:Description"], str(image_name)) image_metadata_subjects = set([subject.split(":")[1] for subject in image_metadata.get("XMP:Subject", "") if ":" in subject]) - image_processed_metadata = image_metadata.get("XMP:Description", "") + ". " + ", ".join(image_metadata_subjects) + + image_processed_metadata = image_metadata.get("XMP:Description", "") + if len(image_metadata_subjects) > 0: + image_processed_metadata += ". " + ", ".join(image_metadata_subjects) + if verbose > 2: print(f"{image_name}:\t{image_processed_metadata}") + return image_processed_metadata