mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 21:29:12 +00:00
Support incremental update of org-mode entries and embeddings
- What
- Hash the entries and compare to find new/updated entries
- Reuse embeddings encoded for existing entries
- Only encode embeddings for updated or new entries
- Merge the existing and new entries and embeddings to get the updated
entries, embeddings
- Why
- Given most note text entries are expected to be unchanged
across time. Reusing their earlier encoded embeddings should
significantly speed up embeddings updates
- Previously we were regenerating embeddings for all entries,
even if they had existed in previous runs
This commit is contained in:
@@ -55,15 +55,28 @@ def extract_entries(jsonl_file):
|
||||
return load_jsonl(jsonl_file)
|
||||
|
||||
|
||||
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
|
||||
def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
# Load pre-computed embeddings from file if exists
|
||||
new_entries = []
|
||||
# Load pre-computed embeddings from file if exists and update them if required
|
||||
if embeddings_file.exists() and not regenerate:
|
||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=state.device)
|
||||
logger.info(f"Loaded embeddings from {embeddings_file}")
|
||||
|
||||
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
# Encode any new entries in the corpus and update corpus embeddings
|
||||
new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None]
|
||||
if new_entries:
|
||||
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
existing_entry_ids = [id for id, _ in entries_with_ids if id is not None]
|
||||
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
|
||||
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
||||
# Else compute the corpus embeddings from scratch
|
||||
else:
|
||||
new_entries = [entry['compiled'] for _, entry in entries_with_ids]
|
||||
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||
|
||||
# Save regenerated or updated embeddings to file
|
||||
if new_entries:
|
||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||
torch.save(corpus_embeddings, embeddings_file)
|
||||
logger.info(f"Computed embeddings and saved them to {embeddings_file}")
|
||||
@@ -169,16 +182,16 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
|
||||
|
||||
# Map notes in text files to (compressed) JSONL formatted file
|
||||
config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl)
|
||||
if not config.compressed_jsonl.exists() or regenerate:
|
||||
text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl)
|
||||
previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() else None
|
||||
entries_with_indices = text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, previous_entries)
|
||||
|
||||
# Extract Entries
|
||||
# Extract Updated Entries
|
||||
entries = extract_entries(config.compressed_jsonl)
|
||||
top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus
|
||||
|
||||
# Compute or Load Embeddings
|
||||
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
|
||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
||||
corpus_embeddings = compute_embeddings(entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
||||
|
||||
for filter in filters:
|
||||
filter.load(entries, regenerate=regenerate)
|
||||
|
||||
Reference in New Issue
Block a user