mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Make Using XMP Metadata to Enhance Image Search Optional, Configurable
- Break the compute_embeddings method into two separate methods: compute_image_embeddings and compute_metadata_embeddings
- If image_metadata_embeddings isn't defined, do not use it to enhance search results. Because image_metadata_embeddings is left undefined whenever use_xmp_metadata is False, the query method can check for it directly and does not need an additional argument
This commit is contained in:
@@ -12,6 +12,7 @@ content-type:
|
|||||||
image:
|
image:
|
||||||
embeddings-file: '.image_embeddings.pt'
|
embeddings-file: '.image_embeddings.pt'
|
||||||
batch-size: 50
|
batch-size: 50
|
||||||
|
use-xmp-metadata: 'no'
|
||||||
|
|
||||||
search-type:
|
search-type:
|
||||||
asymmetric:
|
asymmetric:
|
||||||
|
|||||||
@@ -189,6 +189,7 @@ if __name__ == '__main__':
|
|||||||
pathlib.Path(image_config['embeddings-file']),
|
pathlib.Path(image_config['embeddings-file']),
|
||||||
batch_size=image_config['batch-size'],
|
batch_size=image_config['batch-size'],
|
||||||
regenerate=args.regenerate,
|
regenerate=args.regenerate,
|
||||||
|
use_xmp_metadata={'yes': True, 'no': False}[image_config['use-xmp-metadata']],
|
||||||
verbose=args.verbose)
|
verbose=args.verbose)
|
||||||
|
|
||||||
# Start Application Server
|
# Start Application Server
|
||||||
|
|||||||
@@ -30,10 +30,17 @@ def extract_entries(image_directory, verbose=0):
|
|||||||
return image_names
|
return image_names
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, verbose=0):
|
def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0):
|
||||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||||
|
|
||||||
|
image_embeddings = compute_image_embeddings(image_names, model, embeddings_file, batch_size, regenerate, verbose)
|
||||||
|
image_metadata_embeddings = compute_metadata_embeddings(image_names, model, embeddings_file, batch_size, use_xmp_metadata, regenerate, verbose)
|
||||||
|
|
||||||
|
return image_embeddings, image_metadata_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def compute_image_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, verbose=0):
|
||||||
image_embeddings = None
|
image_embeddings = None
|
||||||
image_metadata_embeddings = None
|
|
||||||
|
|
||||||
# Load pre-computed image embeddings from file if exists
|
# Load pre-computed image embeddings from file if exists
|
||||||
if resolve_absolute_path(embeddings_file).exists() and not regenerate:
|
if resolve_absolute_path(embeddings_file).exists() and not regenerate:
|
||||||
@@ -41,16 +48,7 @@ def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regen
|
|||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
print(f"Loaded pre-computed embeddings from {embeddings_file}")
|
print(f"Loaded pre-computed embeddings from {embeddings_file}")
|
||||||
|
|
||||||
# load pre-computed image metadata embedding file if exists
|
# Else compute the image embeddings from scratch, which can take a while
|
||||||
if resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
|
|
||||||
image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
|
|
||||||
if verbose > 0:
|
|
||||||
print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
|
|
||||||
|
|
||||||
if image_embeddings is None or image_metadata_embeddings is None: # Else compute the image_embeddings from scratch, which can take a while
|
|
||||||
if verbose > 0:
|
|
||||||
print(f"Loading the {len(image_names)} images into memory")
|
|
||||||
|
|
||||||
if image_embeddings is None:
|
if image_embeddings is None:
|
||||||
image_embeddings = []
|
image_embeddings = []
|
||||||
for index in trange(0, len(image_names), batch_size):
|
for index in trange(0, len(image_names), batch_size):
|
||||||
@@ -60,7 +58,20 @@ def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regen
|
|||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
print(f"Saved computed embeddings to {embeddings_file}")
|
print(f"Saved computed embeddings to {embeddings_file}")
|
||||||
|
|
||||||
if image_metadata_embeddings is None:
|
return image_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def compute_metadata_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0):
|
||||||
|
image_metadata_embeddings = None
|
||||||
|
|
||||||
|
# Load pre-computed image metadata embedding file if exists
|
||||||
|
if use_xmp_metadata and resolve_absolute_path(f"{embeddings_file}_metadata").exists() and not regenerate:
|
||||||
|
image_metadata_embeddings = torch.load(f"{embeddings_file}_metadata")
|
||||||
|
if verbose > 0:
|
||||||
|
print(f"Loaded pre-computed embeddings from {embeddings_file}_metadata")
|
||||||
|
|
||||||
|
# Else compute the image metadata embeddings from scratch, which can take a while
|
||||||
|
if use_xmp_metadata and image_metadata_embeddings is None:
|
||||||
image_metadata_embeddings = []
|
image_metadata_embeddings = []
|
||||||
for index in trange(0, len(image_names), batch_size):
|
for index in trange(0, len(image_names), batch_size):
|
||||||
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]]
|
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]]
|
||||||
@@ -69,7 +80,7 @@ def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regen
|
|||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
||||||
|
|
||||||
return image_embeddings, image_metadata_embeddings
|
return image_metadata_embeddings
|
||||||
|
|
||||||
|
|
||||||
def extract_metadata(image_name, verbose=0):
|
def extract_metadata(image_name, verbose=0):
|
||||||
@@ -102,13 +113,14 @@ def query_images(query, image_embeddings, image_metadata_embeddings, model, coun
|
|||||||
in util.semantic_search(query_embedding, image_embeddings, top_k=count)[0]}
|
in util.semantic_search(query_embedding, image_embeddings, top_k=count)[0]}
|
||||||
|
|
||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||||
metadata_hits = {result['corpus_id']: result['score']
|
if image_metadata_embeddings:
|
||||||
for result
|
metadata_hits = {result['corpus_id']: result['score']
|
||||||
in util.semantic_search(query_embedding, image_metadata_embeddings, top_k=count)[0]}
|
for result
|
||||||
|
in util.semantic_search(query_embedding, image_metadata_embeddings, top_k=count)[0]}
|
||||||
|
|
||||||
# Sum metadata, image scores of the highest ranked images
|
# Sum metadata, image scores of the highest ranked images
|
||||||
for corpus_id, score in metadata_hits.items():
|
for corpus_id, score in metadata_hits.items():
|
||||||
image_hits[corpus_id] = image_hits.get(corpus_id, 0) + score
|
image_hits[corpus_id] = image_hits.get(corpus_id, 0) + score
|
||||||
|
|
||||||
# Reformat results in original form from sentence transformer semantic_search()
|
# Reformat results in original form from sentence transformer semantic_search()
|
||||||
hits = [{'corpus_id': corpus_id, 'score': score} for corpus_id, score in image_hits.items()]
|
hits = [{'corpus_id': corpus_id, 'score': score} for corpus_id, score in image_hits.items()]
|
||||||
@@ -138,7 +150,7 @@ def collate_results(hits, image_names, image_directory, count=5):
|
|||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
||||||
|
|
||||||
def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, verbose=0):
|
def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0):
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
model = initialize_model()
|
model = initialize_model()
|
||||||
|
|
||||||
@@ -148,7 +160,8 @@ def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, ver
|
|||||||
|
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
embeddings_file = resolve_absolute_path(embeddings_file)
|
embeddings_file = resolve_absolute_path(embeddings_file)
|
||||||
image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, batch_size=batch_size, regenerate=regenerate, verbose=verbose)
|
image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file,
|
||||||
|
batch_size=batch_size, regenerate=regenerate, use_xmp_metadata=use_xmp_metadata, verbose=verbose)
|
||||||
|
|
||||||
return image_names, image_embeddings, image_metadata_embeddings, model
|
return image_names, image_embeddings, image_metadata_embeddings, model
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,8 @@ default_config = {
|
|||||||
'image':
|
'image':
|
||||||
{
|
{
|
||||||
'embeddings-file': '.image_embeddings.pt',
|
'embeddings-file': '.image_embeddings.pt',
|
||||||
'batch-size': 50
|
'batch-size': 50,
|
||||||
|
'use-xmp-metadata': 'no'
|
||||||
},
|
},
|
||||||
'music':
|
'music':
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user