diff --git a/sample_config.yml b/sample_config.yml index 68024f45..41c4bdf5 100644 --- a/sample_config.yml +++ b/sample_config.yml @@ -12,6 +12,7 @@ content-type: image: embeddings-file: '.image_embeddings.pt' batch-size: 50 + use-xmp-metadata: 'no' search-type: asymmetric: diff --git a/src/main.py b/src/main.py index e13a0cc7..d0d78bd0 100644 --- a/src/main.py +++ b/src/main.py @@ -189,6 +189,7 @@ if __name__ == '__main__': pathlib.Path(image_config['embeddings-file']), batch_size=image_config['batch-size'], regenerate=args.regenerate, + use_xmp_metadata={'yes': True, 'no': False}[image_config['use-xmp-metadata']], verbose=args.verbose) # Start Application Server diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index e7d049ce..82a448bb 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -30,10 +30,17 @@ def extract_entries(image_directory, verbose=0): return image_names -def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, verbose=0): +def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" + + image_embeddings = compute_image_embeddings(image_names, model, embeddings_file, batch_size, regenerate, verbose) + image_metadata_embeddings = compute_metadata_embeddings(image_names, model, embeddings_file, batch_size, regenerate, use_xmp_metadata, verbose) + + return image_embeddings, image_metadata_embeddings + + +def compute_image_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, verbose=0): image_embeddings = None - image_metadata_embeddings = None # Load pre-computed image embeddings from file if exists if resolve_absolute_path(embeddings_file).exists() and not regenerate: @@ -41,16 +48,7 @@ def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regen if verbose > 0:
print(f"Loaded pre-computed embeddings from {embeddings_file}") - # load pre-computed image metadata embedding file if exists - if resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate: - image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata") - if verbose > 0: - print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata") - - if image_embeddings is None or image_metadata_embeddings is None: # Else compute the image_embeddings from scratch, which can take a while - if verbose > 0: - print(f"Loading the {len(image_names)} images into memory") - + # Else compute the image embeddings from scratch, which can take a while if image_embeddings is None: image_embeddings = [] for index in trange(0, len(image_names), batch_size): @@ -60,7 +58,20 @@ def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regen if verbose > 0: print(f"Saved computed embeddings to {embeddings_file}") - if image_metadata_embeddings is None: + return image_embeddings + + +def compute_metadata_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0): + image_metadata_embeddings = None + + # Load pre-computed image metadata embedding file if exists + if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate: + image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata") + if verbose > 0: + print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata") + + # Else compute the image metadata embeddings from scratch, which can take a while + if use_xmp_metadata and image_metadata_embeddings is None: image_metadata_embeddings = [] for index in trange(0, len(image_names), batch_size): image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]] @@ -69,7 +80,7 @@ def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regen if verbose > 0: 
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata") - return image_embeddings, image_metadata_embeddings + return image_metadata_embeddings def extract_metadata(image_name, verbose=0): @@ -102,13 +113,14 @@ def query_images(query, image_embeddings, image_metadata_embeddings, model, coun in util.semantic_search(query_embedding, image_embeddings, top_k=count)[0]} # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. - metadata_hits = {result['corpus_id']: result['score'] - for result - in util.semantic_search(query_embedding, image_metadata_embeddings, top_k=count)[0]} + if image_metadata_embeddings: + metadata_hits = {result['corpus_id']: result['score'] + for result + in util.semantic_search(query_embedding, image_metadata_embeddings, top_k=count)[0]} - # Sum metadata, image scores of the highest ranked images - for corpus_id, score in metadata_hits.items(): - image_hits[corpus_id] = image_hits.get(corpus_id, 0) + score + # Sum metadata, image scores of the highest ranked images + for corpus_id, score in metadata_hits.items(): + image_hits[corpus_id] = image_hits.get(corpus_id, 0) + score # Reformat results in original form from sentence transformer semantic_search() hits = [{'corpus_id': corpus_id, 'score': score} for corpus_id, score in image_hits.items()] @@ -138,7 +150,7 @@ def collate_results(hits, image_names, image_directory, count=5): in hits[0:count]] -def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, verbose=0): +def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0): # Initialize Model model = initialize_model() @@ -148,7 +160,8 @@ def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, ver # Compute or Load Embeddings embeddings_file = resolve_absolute_path(embeddings_file) - image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, 
batch_size=batch_size, regenerate=regenerate, verbose=verbose) + image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, + batch_size=batch_size, regenerate=regenerate, use_xmp_metadata=use_xmp_metadata, verbose=verbose) return image_names, image_embeddings, image_metadata_embeddings, model diff --git a/src/utils/cli.py b/src/utils/cli.py index 8d46bd02..25e2578a 100644 --- a/src/utils/cli.py +++ b/src/utils/cli.py @@ -57,7 +57,8 @@ default_config = { 'image': { 'embeddings-file': '.image_embeddings.pt', - 'batch-size': 50 + 'batch-size': 50, + 'use-xmp-metadata': 'no' }, 'music': {