Make embeddings, jsonl paths absolute. Create directories if non-existent

Debanjum Singh Solanky
2022-08-05 02:51:49 +03:00
parent d5b43eb836
commit 675e821d95
3 changed files with 31 additions and 7 deletions
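
Both file diffs below route paths through a resolve_absolute_path helper before touching the filesystem. The helper itself is defined elsewhere in the repository and is not part of the hunks shown; as a minimal sketch, assuming it is a small pathlib wrapper, it could look like this (only the name comes from the diff, the body is an assumption):

    from pathlib import Path

    def resolve_absolute_path(filepath):
        # Expand a leading ~ to the user's home directory and
        # normalize relative segments into an absolute path
        return Path(filepath).expanduser().absolute()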

View File

@@ -22,6 +22,12 @@ def initialize_model(search_config: ImageSearchConfig):
     # Initialize Model
     torch.set_num_threads(4)
+    # Convert model directory to absolute path
+    search_config.model_directory = resolve_absolute_path(search_config.model_directory)
+    # Create model directory if it doesn't exist
+    search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
     # Load the CLIP model
     encoder = load_model(
         model_dir = search_config.model_directory,
@@ -61,9 +67,8 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
         image_embeddings = torch.load(embeddings_file)
         if verbose > 0:
             print(f"Loaded pre-computed embeddings from {embeddings_file}")
     # Else compute the image embeddings from scratch, which can take a while
-    if image_embeddings is None:
+    elif image_embeddings is None:
         image_embeddings = []
         for index in trange(0, len(image_names), batch_size):
             images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
@@ -71,6 +76,11 @@ def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=5
             images,
             convert_to_tensor=True,
             batch_size=min(len(images), batch_size))
+        # Create directory for embeddings file, if it doesn't exist
+        embeddings_file.parent.mkdir(parents=True, exist_ok=True)
+        # Save computed image embeddings to file
         torch.save(image_embeddings, embeddings_file)
         if verbose > 0:
             print(f"Saved computed embeddings to {embeddings_file}")

View File

@@ -22,6 +22,12 @@ def initialize_model(search_config: TextSearchConfig):
     # Number of entries we want to retrieve with the bi-encoder
     top_k = 15
+    # Convert model directory to absolute path
+    search_config.model_directory = resolve_absolute_path(search_config.model_directory)
+    # Create model directory if it doesn't exist
+    search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
     # The bi-encoder encodes all entries to use for semantic search
     bi_encoder = load_model(
         model_dir = search_config.model_directory,
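
Note that mkdir in both files creates the parent of model_directory rather than the directory itself. That is consistent with load_model (another project helper not shown in this commit) being expected to create or populate the model directory when it first downloads the model. A rough sketch of load_model under that assumption (the signature, model name, and caching behavior are guesses for illustration):

    from pathlib import Path
    from sentence_transformers import SentenceTransformer

    def load_model(model_dir, model_name='multi-qa-MiniLM-L6-cos-v1'):
        # Assumed behavior: reuse the local copy if present,
        # else download the model and cache it under model_dir
        model_dir = Path(model_dir)
        if model_dir.exists():
            return SentenceTransformer(str(model_dir))
        model = SentenceTransformer(model_name)
        model.save(str(model_dir))  # save() creates model_dir itself
        return model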
@@ -47,7 +53,7 @@ def extract_entries(jsonl_file, verbose=0):
 def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, device='cpu', verbose=0):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
     # Load pre-computed embeddings from file if exists
-    if resolve_absolute_path(embeddings_file).exists() and not regenerate:
+    if embeddings_file.exists() and not regenerate:
         corpus_embeddings = torch.load(get_absolute_path(embeddings_file))
         if verbose > 0:
             print(f"Loaded embeddings from {embeddings_file}")
@@ -56,7 +62,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
         corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
         corpus_embeddings.to(device)
         corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
-        torch.save(corpus_embeddings, get_absolute_path(embeddings_file))
+        torch.save(corpus_embeddings, embeddings_file)
         if verbose > 0:
             print(f"Computed embeddings and saved them to {embeddings_file}")
@@ -165,7 +171,8 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
     bi_encoder, cross_encoder, top_k = initialize_model(search_config)
     # Map notes in text files to (compressed) JSONL formatted file
-    if not resolve_absolute_path(config.compressed_jsonl).exists() or regenerate:
+    config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
+    if not config.compressed_jsonl.exists() or regenerate:
         text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, verbose)
     # Extract Entries
@@ -173,6 +180,7 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
     top_k = min(len(entries), top_k)  # top_k hits can't be more than the total entries in corpus
     # Compute or Load Embeddings
+    config.embeddings_file = resolve_absolute_path(config.embeddings_file)
     corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate, device=device, verbose=verbose)
     return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k, verbose=verbose)
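
Resolving config.compressed_jsonl and config.embeddings_file once at the top of setup pins both paths before any helper touches them, so behavior no longer depends on the working directory at call time or on each callee remembering to resolve. A small sketch of the hazard this avoids (paths hypothetical):

    import os
    from pathlib import Path

    relative = Path('embeddings/notes.pt')

    # A relative path is re-interpreted against the current
    # working directory every time it is used
    os.chdir('/tmp')
    print(relative.resolve())   # /tmp/embeddings/notes.pt
    os.chdir('/home')
    print(relative.resolve())   # /home/embeddings/notes.pt

    # Resolving once up front fixes the target for all later uses
    pinned = relative.expanduser().resolve()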