diff --git a/src/processor/ledger/beancount_to_jsonl.py b/src/processor/ledger/beancount_to_jsonl.py
index c0136bc6..a1792735 100644
--- a/src/processor/ledger/beancount_to_jsonl.py
+++ b/src/processor/ledger/beancount_to_jsonl.py
@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
 
 
 # Define Functions
-def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file):
+def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file, previous_entries=None):
     # Input Validation
     if is_none_or_empty(beancount_files) and is_none_or_empty(beancount_file_filter):
         print("At least one of beancount-files or beancount-file-filter is required to be specified")
@@ -39,7 +39,7 @@ def beancount_to_jsonl(beancount_files, beancount_file_filter, output_file):
     elif output_file.suffix == ".jsonl":
         dump_jsonl(jsonl_data, output_file)
 
-    return entries
+    return list(enumerate(entries))
 
 
 def get_beancount_files(beancount_files=None, beancount_file_filter=None):
diff --git a/src/processor/markdown/markdown_to_jsonl.py b/src/processor/markdown/markdown_to_jsonl.py
index a0903fcb..ce022358 100644
--- a/src/processor/markdown/markdown_to_jsonl.py
+++ b/src/processor/markdown/markdown_to_jsonl.py
@@ -39,7 +39,7 @@ def markdown_to_jsonl(markdown_files, markdown_file_filter, output_file):
     elif output_file.suffix == ".jsonl":
         dump_jsonl(jsonl_data, output_file)
 
-    return entries
+    return list(enumerate(entries))
 
 
 def get_markdown_files(markdown_files=None, markdown_file_filter=None):
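
Reviewer note: the two processors above only pick up the new return convention. Every text-to-jsonl processor now returns (id, entry) tuples instead of bare entries, where id is the entry's row in the previously saved embeddings tensor, or None for entries that still need encoding. A minimal sketch of the convention with made-up entries (the dicts and names are illustrative, not from this codebase):

```python
# Illustrative only: toy entries standing in for the processors' entry dicts
entries = [{'compiled': 'Note A'}, {'compiled': 'Note B'}]

# Full regeneration path: ids are simply the entries' positions
entries_with_ids = list(enumerate(entries))
assert entries_with_ids == [(0, {'compiled': 'Note A'}), (1, {'compiled': 'Note B'})]

# Incremental path (the org processor below): new entries carry id None
entries_with_ids = [(0, {'compiled': 'Note A'}), (None, {'compiled': 'Note B'})]
new_entries = [entry for id, entry in entries_with_ids if id is None]
assert new_entries == [{'compiled': 'Note B'}]
```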
diff --git a/src/processor/org_mode/org_to_jsonl.py b/src/processor/org_mode/org_to_jsonl.py
index f1531797..1c77acf8 100644
--- a/src/processor/org_mode/org_to_jsonl.py
+++ b/src/processor/org_mode/org_to_jsonl.py
@@ -7,6 +7,7 @@ import argparse
 import pathlib
 import glob
 import logging
+import hashlib
 
 # Internal Packages
 from src.processor.org_mode import orgnode
@@ -19,7 +20,7 @@ logger = logging.getLogger(__name__)
 
 
 # Define Functions
-def org_to_jsonl(org_files, org_file_filter, output_file):
+def org_to_jsonl(org_files, org_file_filter, output_file, previous_entries=None):
     # Input Validation
     if is_none_or_empty(org_files) and is_none_or_empty(org_file_filter):
         print("At least one of org-files or org-file-filter is required to be specified")
@@ -29,10 +30,41 @@ def org_to_jsonl(org_files, org_file_filter, output_file):
     org_files = get_org_files(org_files, org_file_filter)
 
     # Extract Entries from specified Org files
-    entries, file_to_entries = extract_org_entries(org_files)
+    entry_nodes, file_to_entries = extract_org_entries(org_files)
+    current_entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries)
+
+    # Identify, mark and merge any new entries with previous entries
+    if not previous_entries:
+        entries_with_ids = list(enumerate(current_entries))
+    else:
+        # Hash all current and previous entries to identify new entries
+        current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(json.dumps(e), encoding='utf-8')).hexdigest(), current_entries))
+        previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(json.dumps(e), encoding='utf-8')).hexdigest(), previous_entries))
+
+        hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
+        hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
+
+        # All entries that did not exist in the previous set are to be added
+        new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
+        # All entries that exist in both current and previous sets are kept
+        existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
+
+        # Mark new entries with no ids for later embeddings generation
+        new_entries = [
+            (None, hash_to_current_entries[entry_hash])
+            for entry_hash in new_entry_hashes
+        ]
+        # Set id of existing entries to their previous ids to reuse their existing encoded embeddings
+        existing_entries = [
+            (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
+            for entry_hash in existing_entry_hashes
+        ]
+        existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
+        entries_with_ids = existing_entries_sorted + new_entries
 
     # Process Each Entry from All Notes Files
-    jsonl_data = convert_org_entries_to_jsonl(entries, file_to_entries)
+    entries = map(lambda entry: entry[1], entries_with_ids)
+    jsonl_data = convert_org_entries_to_jsonl(entries)
 
     # Compress JSONL formatted Data
     if output_file.suffix == ".gz":
@@ -40,7 +72,7 @@ def org_to_jsonl(org_files, org_file_filter, output_file):
     elif output_file.suffix == ".jsonl":
         dump_jsonl(jsonl_data, output_file)
 
-    return entries
+    return entries_with_ids
 
 
 def get_org_files(org_files=None, org_file_filter=None):
@@ -70,16 +102,16 @@ def extract_org_entries(org_files):
     entry_to_file_map = []
     for org_file in org_files:
         org_file_entries = orgnode.makelist(str(org_file))
-        entry_to_file_map += [org_file]*len(org_file_entries)
+        entry_to_file_map += zip(org_file_entries, [org_file]*len(org_file_entries))
         entries.extend(org_file_entries)
 
-    return entries, entry_to_file_map
+    return entries, dict(entry_to_file_map)
 
 
-def convert_org_entries_to_jsonl(entries, entry_to_file_map) -> str:
-    "Convert each Org-Mode entries to JSON and collate as JSONL"
-    jsonl = ''
-    for entry_id, entry in enumerate(entries):
+def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map) -> list[dict]:
+    "Convert Org-Mode entries into list of dictionaries"
+    entry_maps = []
+    for entry in entries:
         entry_dict = dict()
 
         # Ignore title notes i.e notes with just headings and empty body
@@ -113,14 +145,17 @@
 
         if entry_dict:
             entry_dict["raw"] = f'{entry}'
-            entry_dict["file"] = f'{entry_to_file_map[entry_id]}'
+            entry_dict["file"] = f'{entry_to_file_map[entry]}'
 
             # Convert Dictionary to JSON and Append to JSONL string
-            jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
+            entry_maps.append(entry_dict)
 
-    logger.info(f"Converted {len(entries)} to jsonl format")
+    return entry_maps
 
-    return jsonl
+
+def convert_org_entries_to_jsonl(entries) -> str:
+    "Convert each Org-Mode entry to JSON and collate as JSONL"
+    return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
 
 
 if __name__ == '__main__':
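
Reviewer note: the diffing logic added to org_to_jsonl above is easier to follow in isolation. A condensed, runnable restatement with toy dictionaries; the function name and sample entries are mine, not part of the change:

```python
import hashlib
import json

def mark_entries_for_update(current_entries, previous_entries):
    # Hash entries so identical notes can be matched across indexing runs
    entry_hash = lambda e: hashlib.md5(json.dumps(e).encode('utf-8')).hexdigest()
    current_hashes = list(map(entry_hash, current_entries))
    previous_hashes = list(map(entry_hash, previous_entries))

    hash_to_current = dict(zip(current_hashes, current_entries))
    hash_to_previous = dict(zip(previous_hashes, previous_entries))

    # Entries only in the current set are new; the intersection is unchanged
    new_hashes = set(current_hashes) - set(previous_hashes)
    existing_hashes = set(current_hashes) & set(previous_hashes)

    # Unchanged entries keep their previous id, i.e. their embedding row; new entries get None
    existing = [(previous_hashes.index(h), hash_to_previous[h]) for h in existing_hashes]
    new = [(None, hash_to_current[h]) for h in new_hashes]
    return sorted(existing, key=lambda e: e[0]) + new

previous = [{'compiled': 'Note A'}, {'compiled': 'Note B'}]
current = [{'compiled': 'Note A'}, {'compiled': 'Note C'}]
assert mark_entries_for_update(current, previous) == [(0, {'compiled': 'Note A'}), (None, {'compiled': 'Note C'})]
```

As far as I can tell, the sort by previous id is what keeps the JSONL write order aligned with the embedding tensor's row order, which the next incremental run depends on.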
{embeddings_file}") - else: # Else compute the corpus_embeddings from scratch, which can take a while - corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=state.device, show_progress_bar=True) + # Encode any new entries in the corpus and update corpus embeddings + new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None] + if new_entries: + new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) + existing_entry_ids = [id for id, _ in entries_with_ids if id is not None] + existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor() + corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) + # Else compute the corpus embeddings from scratch + else: + new_entries = [entry['compiled'] for _, entry in entries_with_ids] + corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) + + # Save regenerated or updated embeddings to file + if new_entries: corpus_embeddings = util.normalize_embeddings(corpus_embeddings) torch.save(corpus_embeddings, embeddings_file) logger.info(f"Computed embeddings and saved them to {embeddings_file}") @@ -169,16 +182,16 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon # Map notes in text files to (compressed) JSONL formatted file config.compressed_jsonl = resolve_absolute_path(config.compressed_jsonl) - if not config.compressed_jsonl.exists() or regenerate: - text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl) + previous_entries = extract_entries(config.compressed_jsonl) if config.compressed_jsonl.exists() else None + entries_with_indices = text_to_jsonl(config.input_files, config.input_filter, config.compressed_jsonl, previous_entries) - # Extract Entries + # Extract Updated Entries entries = extract_entries(config.compressed_jsonl) top_k = min(len(entries), top_k) # top_k hits can't be more than the total entries in corpus # Compute or Load Embeddings config.embeddings_file = resolve_absolute_path(config.embeddings_file) - corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate) + corpus_embeddings = compute_embeddings(entries_with_indices, bi_encoder, config.embeddings_file, regenerate=regenerate) for filter in filters: filter.load(entries, regenerate=regenerate) diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py index 6a626299..594c954f 100644 --- a/tests/test_org_to_jsonl.py +++ b/tests/test_org_to_jsonl.py @@ -3,7 +3,7 @@ import json from posixpath import split # Internal Packages -from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, extract_org_entries +from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, convert_org_nodes_to_entries, extract_org_entries from src.utils.helpers import is_none_or_empty @@ -21,10 +21,11 @@ def test_entry_with_empty_body_line_to_jsonl(tmp_path): # Act # Extract Entries from specified Org files - entries, entry_to_file_map = extract_org_entries(org_files=[orgfile]) + entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile]) # Process Each Entry from All Notes Files - jsonl_data = convert_org_entries_to_jsonl(entries, entry_to_file_map) + entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries) + jsonl_data = 
diff --git a/tests/test_org_to_jsonl.py b/tests/test_org_to_jsonl.py
index 6a626299..594c954f 100644
--- a/tests/test_org_to_jsonl.py
+++ b/tests/test_org_to_jsonl.py
@@ -3,7 +3,7 @@ import json
 from posixpath import split
 
 # Internal Packages
-from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, extract_org_entries
+from src.processor.org_mode.org_to_jsonl import convert_org_entries_to_jsonl, convert_org_nodes_to_entries, extract_org_entries
 from src.utils.helpers import is_none_or_empty
 
 
@@ -21,10 +21,11 @@ def test_entry_with_empty_body_line_to_jsonl(tmp_path):
 
     # Act
     # Extract Entries from specified Org files
-    entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
+    entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile])
 
     # Process Each Entry from All Notes Files
-    jsonl_data = convert_org_entries_to_jsonl(entries, entry_to_file_map)
+    entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries)
+    jsonl_data = convert_org_entries_to_jsonl(entries)
 
     # Assert
     assert is_none_or_empty(jsonl_data)
@@ -43,10 +44,11 @@ def test_entry_with_body_to_jsonl(tmp_path):
 
     # Act
     # Extract Entries from specified Org files
-    entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
+    entry_nodes, file_to_entries = extract_org_entries(org_files=[orgfile])
 
     # Process Each Entry from All Notes Files
-    jsonl_string = convert_org_entries_to_jsonl(entries, entry_to_file_map)
+    entries = convert_org_nodes_to_entries(entry_nodes, file_to_entries)
+    jsonl_string = convert_org_entries_to_jsonl(entries)
     jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
 
     # Assert
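
Reviewer note: the updated tests only exercise the conversion helpers. A possible follow-up test of the round trip through the new convert_org_entries_to_jsonl, with the helper copied inline so the snippet runs standalone under pytest:

```python
import json

def convert_org_entries_to_jsonl(entries) -> str:
    "Inline copy of the new helper from org_to_jsonl"
    return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])

def test_entries_round_trip_through_jsonl():
    entries = [{'compiled': 'Body line 1', 'raw': '* Heading\nBody line 1\n'}]
    jsonl_string = convert_org_entries_to_jsonl(entries)
    jsonl_data = [json.loads(line) for line in jsonl_string.splitlines()]
    assert jsonl_data == entries
```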