mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Batch encode images to keep memory consumption manageable
- Issue:
Process would get killed while encoding images
for consuming too much memory
- Fix:
- Encode images in batches and append to image_embeddings
- No need to use copy or deep_copy anymore with batch processing.
It would earlier throw a "too many open files" error
Other Changes:
- Use tqdm to see progress even when encoding in batches
- Show the encoding progress bar independent of verbosity (for now)
This commit is contained in:
@@ -6,6 +6,7 @@ import copy
|
|||||||
# External Packages
|
# External Packages
|
||||||
from sentence_transformers import SentenceTransformer, util
|
from sentence_transformers import SentenceTransformer, util
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from tqdm import trange
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
@@ -50,23 +51,20 @@ def compute_embeddings(image_names, model, embeddings_file, regenerate=False, ve
|
|||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
print(f"Loading the {len(image_names)} images into memory")
|
print(f"Loading the {len(image_names)} images into memory")
|
||||||
|
|
||||||
|
batch_size = 50
|
||||||
if image_embeddings is None:
|
if image_embeddings is None:
|
||||||
image_embeddings = model.encode(
|
image_embeddings = []
|
||||||
[Image.open(image_name).copy() for image_name in image_names],
|
for index in trange(0, len(image_names), batch_size):
|
||||||
batch_size=128, convert_to_tensor=True, show_progress_bar=verbose > 0)
|
images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
|
||||||
|
image_embeddings += model.encode(images, convert_to_tensor=True, batch_size=batch_size)
|
||||||
torch.save(image_embeddings, embeddings_file)
|
torch.save(image_embeddings, embeddings_file)
|
||||||
|
|
||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
print(f"Saved computed embeddings to {embeddings_file}")
|
print(f"Saved computed embeddings to {embeddings_file}")
|
||||||
|
|
||||||
if image_metadata_embeddings is None:
|
if image_metadata_embeddings is None:
|
||||||
image_metadata_embeddings = model.encode(
|
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names],
|
||||||
[extract_metadata(image_name, verbose) for image_name in image_names],
|
image_metadata_embeddings = model.encode(image_metadata, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
|
||||||
batch_size=128, convert_to_tensor=True, show_progress_bar=verbose > 0)
|
|
||||||
|
|
||||||
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
|
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
|
||||||
|
|
||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user