Drop embeddings of deleted text entries from index

Previously the deleted embeddings would continue to be in the index,
even after the entry was deleted
This commit is contained in:
Debanjum Singh Solanky
2023-07-16 03:47:05 -07:00
parent c73feebf25
commit ef6a0044f4
2 changed files with 54 additions and 17 deletions

View File

@@ -65,7 +65,8 @@ def compute_embeddings(
normalize=True,
):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
new_entries = []
new_embeddings = torch.tensor([], device=state.device)
existing_embeddings = torch.tensor([], device=state.device)
create_index_msg = ""
# Load pre-computed embeddings from file if exists and update them if required
if embeddings_file.exists() and not regenerate:
@@ -82,22 +83,23 @@ def compute_embeddings(
new_embeddings = bi_encoder.encode(
new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
)
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
if existing_entry_ids:
existing_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
)
else:
existing_embeddings = torch.tensor([], device=state.device)
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
if normalize:
# Normalize embeddings for faster lookup via dot product when querying
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
# Extract existing embeddings from previous corpus embeddings
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
if existing_entry_ids:
existing_embeddings = torch.index_select(
corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
)
# Save regenerated or updated embeddings to file
torch.save(corpus_embeddings, embeddings_file)
logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
# Set corpus embeddings to merger of existing and new embeddings
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
if normalize:
# Normalize embeddings for faster lookup via dot product when querying
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
# Save regenerated or updated embeddings to file
torch.save(corpus_embeddings, embeddings_file)
logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
return corpus_embeddings