mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Batch encode images to keep memory consumption manageable
- Issue:
The process would get killed while encoding images
because it consumed too much memory
- Fix:
- Encode images in batches and append to image_embeddings
- No need to use copy or deep_copy anymore with batch processing.
It would previously throw a "too many open files" error
Other Changes:
- Use tqdm to show progress even when encoding in batches
- Show the encoding progress bar independent of verbosity (for now)
This commit is contained in:
@@ -6,6 +6,7 @@ import copy
|
||||
# External Packages
|
||||
from sentence_transformers import SentenceTransformer, util
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
import torch
|
||||
|
||||
# Internal Packages
|
||||
@@ -50,23 +51,20 @@ def compute_embeddings(image_names, model, embeddings_file, regenerate=False, ve
|
||||
if verbose > 0:
|
||||
print(f"Loading the {len(image_names)} images into memory")
|
||||
|
||||
batch_size = 50
|
||||
if image_embeddings is None:
|
||||
image_embeddings = model.encode(
|
||||
[Image.open(image_name).copy() for image_name in image_names],
|
||||
batch_size=128, convert_to_tensor=True, show_progress_bar=verbose > 0)
|
||||
|
||||
image_embeddings = []
|
||||
for index in trange(0, len(image_names), batch_size):
|
||||
images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
|
||||
image_embeddings += model.encode(images, convert_to_tensor=True, batch_size=batch_size)
|
||||
torch.save(image_embeddings, embeddings_file)
|
||||
|
||||
if verbose > 0:
|
||||
print(f"Saved computed embeddings to {embeddings_file}")
|
||||
|
||||
if image_metadata_embeddings is None:
|
||||
image_metadata_embeddings = model.encode(
|
||||
[extract_metadata(image_name, verbose) for image_name in image_names],
|
||||
batch_size=128, convert_to_tensor=True, show_progress_bar=verbose > 0)
|
||||
|
||||
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names],
|
||||
image_metadata_embeddings = model.encode(image_metadata, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
|
||||
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
|
||||
|
||||
if verbose > 0:
|
||||
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user