Make Using XMP Metadata to Enhance Image Search Optional, Configurable

- Break the compute embeddings method into separate methods:
  compute_image_embeddings and compute_metadata_embeddings

- If image_metadata_embeddings isn't defined, do not use it to enhance
  search results. Given image_metadata_embeddings wouldn't be defined
  if use_xmp_metadata is False, we can avoid unnecessary addition of
  args to query method
This commit is contained in:
Debanjum Singh Solanky
2021-09-16 12:01:05 -07:00
parent a4a23d7a72
commit 169ddcc8c6
4 changed files with 39 additions and 23 deletions

View File

@@ -12,6 +12,7 @@ content-type:
image: image:
embeddings-file: '.image_embeddings.pt' embeddings-file: '.image_embeddings.pt'
batch-size: 50 batch-size: 50
use-xmp-metadata: 'no'
search-type: search-type:
asymmetric: asymmetric:

View File

@@ -189,6 +189,7 @@ if __name__ == '__main__':
pathlib.Path(image_config['embeddings-file']), pathlib.Path(image_config['embeddings-file']),
batch_size=image_config['batch-size'], batch_size=image_config['batch-size'],
regenerate=args.regenerate, regenerate=args.regenerate,
use_xmp_metadata={'yes': True, 'no': False}[image_config['use-xmp-metadata']],
verbose=args.verbose) verbose=args.verbose)
# Start Application Server # Start Application Server

View File

@@ -30,10 +30,17 @@ def extract_entries(image_directory, verbose=0):
return image_names return image_names
def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, verbose=0): def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings" "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
image_embeddings = compute_image_embeddings(image_names, model, embeddings_file, batch_size, regenerate, verbose)
image_metadata_embeddings = compute_metadata_embeddings(image_names, model, embeddings_file, batch_size, use_xmp_metadata, regenerate, verbose)
return image_embeddings, image_metadata_embeddings
def compute_image_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, verbose=0):
image_embeddings = None image_embeddings = None
image_metadata_embeddings = None
# Load pre-computed image embeddings from file if exists # Load pre-computed image embeddings from file if exists
if resolve_absolute_path(embeddings_file).exists() and not regenerate: if resolve_absolute_path(embeddings_file).exists() and not regenerate:
@@ -41,16 +48,7 @@ def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regen
if verbose > 0: if verbose > 0:
print(f"Loaded pre-computed embeddings from {embeddings_file}") print(f"Loaded pre-computed embeddings from {embeddings_file}")
# load pre-computed image metadata embedding file if exists # Else compute the image embeddings from scratch, which can take a while
if resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
if verbose > 0:
print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
if image_embeddings is None or image_metadata_embeddings is None: # Else compute the image_embeddings from scratch, which can take a while
if verbose > 0:
print(f"Loading the {len(image_names)} images into memory")
if image_embeddings is None: if image_embeddings is None:
image_embeddings = [] image_embeddings = []
for index in trange(0, len(image_names), batch_size): for index in trange(0, len(image_names), batch_size):
@@ -60,7 +58,20 @@ def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regen
if verbose > 0: if verbose > 0:
print(f"Saved computed embeddings to {embeddings_file}") print(f"Saved computed embeddings to {embeddings_file}")
if image_metadata_embeddings is None: return image_embeddings
def compute_metadata_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0):
image_metadata_embeddings = None
# Load pre-computed image metadata embedding file if exists
if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
if verbose > 0:
print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
# Else compute the image metadata embeddings from scratch, which can take a while
if use_xmp_metadata and image_metadata_embeddings is None:
image_metadata_embeddings = [] image_metadata_embeddings = []
for index in trange(0, len(image_names), batch_size): for index in trange(0, len(image_names), batch_size):
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]] image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]]
@@ -69,7 +80,7 @@ def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regen
if verbose > 0: if verbose > 0:
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata") print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
return image_embeddings, image_metadata_embeddings return image_metadata_embeddings
def extract_metadata(image_name, verbose=0): def extract_metadata(image_name, verbose=0):
@@ -102,6 +113,7 @@ def query_images(query, image_embeddings, image_metadata_embeddings, model, coun
in util.semantic_search(query_embedding, image_embeddings, top_k=count)[0]} in util.semantic_search(query_embedding, image_embeddings, top_k=count)[0]}
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if image_metadata_embeddings:
metadata_hits = {result['corpus_id']: result['score'] metadata_hits = {result['corpus_id']: result['score']
for result for result
in util.semantic_search(query_embedding, image_metadata_embeddings, top_k=count)[0]} in util.semantic_search(query_embedding, image_metadata_embeddings, top_k=count)[0]}
@@ -138,7 +150,7 @@ def collate_results(hits, image_names, image_directory, count=5):
in hits[0:count]] in hits[0:count]]
def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, verbose=0): def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0):
# Initialize Model # Initialize Model
model = initialize_model() model = initialize_model()
@@ -148,7 +160,8 @@ def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, ver
# Compute or Load Embeddings # Compute or Load Embeddings
embeddings_file = resolve_absolute_path(embeddings_file) embeddings_file = resolve_absolute_path(embeddings_file)
image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, batch_size=batch_size, regenerate=regenerate, verbose=verbose) image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file,
batch_size=batch_size, regenerate=regenerate, use_xmp_metadata=use_xmp_metadata, verbose=verbose)
return image_names, image_embeddings, image_metadata_embeddings, model return image_names, image_embeddings, image_metadata_embeddings, model

View File

@@ -57,7 +57,8 @@ default_config = {
'image': 'image':
{ {
'embeddings-file': '.image_embeddings.pt', 'embeddings-file': '.image_embeddings.pt',
'batch-size': 50 'batch-size': 50,
'use-xmp-metadata': 'no'
}, },
'music': 'music':
{ {