mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 05:39:12 +00:00
Resolve edge case errors in encoding image metadata
- Handle case where current image batch smaller than batch_size - Handle case where no XMP metadata for current image - return empty strings in such a scenario instead of ". "
This commit is contained in:
@@ -65,7 +65,10 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
|
|||||||
image_embeddings = []
|
image_embeddings = []
|
||||||
for index in trange(0, len(image_names), batch_size):
|
for index in trange(0, len(image_names), batch_size):
|
||||||
images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
|
images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
|
||||||
image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=batch_size)
|
image_embeddings += encoder.encode(
|
||||||
|
images,
|
||||||
|
convert_to_tensor=True,
|
||||||
|
batch_size=min(len(images), batch_size))
|
||||||
torch.save(image_embeddings, embeddings_file)
|
torch.save(image_embeddings, embeddings_file)
|
||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
print(f"Saved computed embeddings to {embeddings_file}")
|
print(f"Saved computed embeddings to {embeddings_file}")
|
||||||
@@ -88,7 +91,10 @@ def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_siz
|
|||||||
for index in trange(0, len(image_names), batch_size):
|
for index in trange(0, len(image_names), batch_size):
|
||||||
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]]
|
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]]
|
||||||
try:
|
try:
|
||||||
image_metadata_embeddings += encoder.encode(image_metadata, convert_to_tensor=True, batch_size=batch_size)
|
image_metadata_embeddings += encoder.encode(
|
||||||
|
image_metadata,
|
||||||
|
convert_to_tensor=True,
|
||||||
|
batch_size=min(len(image_metadata), batch_size))
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
print(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
|
print(f"Error encoding metadata for images starting from\n\tindex: {index},\n\timages: {image_names[index:index+batch_size]}\nException: {e}")
|
||||||
continue
|
continue
|
||||||
@@ -103,9 +109,14 @@ def extract_metadata(image_name, verbose=0):
|
|||||||
with exiftool.ExifTool() as et:
|
with exiftool.ExifTool() as et:
|
||||||
image_metadata = et.get_tags(["XMP:Subject", "XMP:Description"], str(image_name))
|
image_metadata = et.get_tags(["XMP:Subject", "XMP:Description"], str(image_name))
|
||||||
image_metadata_subjects = set([subject.split(":")[1] for subject in image_metadata.get("XMP:Subject", "") if ":" in subject])
|
image_metadata_subjects = set([subject.split(":")[1] for subject in image_metadata.get("XMP:Subject", "") if ":" in subject])
|
||||||
image_processed_metadata = image_metadata.get("XMP:Description", "") + ". " + ", ".join(image_metadata_subjects)
|
|
||||||
|
image_processed_metadata = image_metadata.get("XMP:Description", "")
|
||||||
|
if len(image_metadata_subjects) > 0:
|
||||||
|
image_processed_metadata += ". " + ", ".join(image_metadata_subjects)
|
||||||
|
|
||||||
if verbose > 2:
|
if verbose > 2:
|
||||||
print(f"{image_name}:\t{image_processed_metadata}")
|
print(f"{image_name}:\t{image_processed_metadata}")
|
||||||
|
|
||||||
return image_processed_metadata
|
return image_processed_metadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user