diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index c123974a..09174186 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -65,7 +65,8 @@ def compute_embeddings(
     normalize=True,
 ):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
-    new_entries = []
+    new_embeddings = torch.tensor([], device=state.device)
+    existing_embeddings = torch.tensor([], device=state.device)
     create_index_msg = ""
     # Load pre-computed embeddings from file if exists and update them if required
     if embeddings_file.exists() and not regenerate:
@@ -82,22 +83,23 @@ def compute_embeddings(
         new_embeddings = bi_encoder.encode(
             new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True
         )
-        existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
-        if existing_entry_ids:
-            existing_embeddings = torch.index_select(
-                corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
-            )
-        else:
-            existing_embeddings = torch.tensor([], device=state.device)
-        corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
-        if normalize:
-            # Normalize embeddings for faster lookup via dot product when querying
-            corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
+    # Extract existing embeddings from previous corpus embeddings
+    existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
+    if existing_entry_ids:
+        existing_embeddings = torch.index_select(
+            corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)
+        )

-        # Save regenerated or updated embeddings to file
-        torch.save(corpus_embeddings, embeddings_file)
-        logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")
+    # Set corpus embeddings to merger of existing and new embeddings
+    corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
+    if normalize:
+        # Normalize embeddings for faster lookup via dot product when querying
+        corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
+
+    # Save regenerated or updated embeddings to file
+    torch.save(corpus_embeddings, embeddings_file)
+    logger.info(f"📩 Saved computed text embeddings to {embeddings_file}")

     return corpus_embeddings


diff --git a/tests/test_text_search.py b/tests/test_text_search.py
index 830b0da5..1ae7e770 100644
--- a/tests/test_text_search.py
+++ b/tests/test_text_search.py
@@ -71,8 +71,8 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, sea
     final_logs = caplog.text

     # Assert
-    assert "📩 Saved computed text embeddings to" in initial_logs
-    assert "📩 Saved computed text embeddings to" not in final_logs
+    assert "Creating index from scratch." in initial_logs
+    assert "Creating index from scratch." not in final_logs


 # ----------------------------------------------------------------------------------------------------
@@ -192,6 +192,41 @@ def test_update_index_with_duplicate_entries_in_stable_order(
         pytest.fail(error_details)


+# ----------------------------------------------------------------------------------------------------
+def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
+    # Arrange
+    new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
+
+    # Insert two org-mode entries into the new org file
+    new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
+    with open(new_file_to_index, "w") as f:
+        f.write(f"{new_entry}{new_entry} -- Tatooine")
+
+    # Load embeddings, entries, notes model after adding new org file with 2 entries
+    initial_index = text_search.setup(
+        OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
+    )
+
+    # Update embeddings, entries, notes model after removing an entry from the org file
+    with open(new_file_to_index, "w") as f:
+        f.write(f"{new_entry}")
+
+    # Act
+    updated_index = text_search.setup(
+        OrgToJsonl, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
+    )
+
+    # Assert
+    # verify one entry and its embedding are removed from the index
+    assert len(initial_index.entries) == len(updated_index.entries) + 1
+    assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
+
+    # verify the retained entries still match the initial index
+    error_details = compare_index(updated_index, initial_index)
+    if error_details:
+        pytest.fail(error_details)
+
+
 # ----------------------------------------------------------------------------------------------------
 def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
     # Arrange
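
Reviewer note: the gist of the fix is that compute_embeddings now always rebuilds the corpus as
"embeddings of surviving entries + embeddings of newly encoded entries" and saves the result, so
rows for deleted entries are dropped even when no new entries were indexed (previously the merge
and save only ran inside the new-entries branch). A minimal standalone sketch of that merge step;
corpus, keep_ids, and new_vecs are hypothetical names for illustration, not identifiers from the
codebase:

    import torch

    corpus = torch.randn(4, 8)    # embeddings for 4 previously indexed entries
    keep_ids = [0, 2, 3]          # hypothetical: entry 1 was deleted from the source file
    new_vecs = torch.randn(1, 8)  # embedding for 1 newly indexed entry

    # Keep only rows whose entries still exist, then append the new rows
    existing = torch.index_select(corpus, 0, torch.tensor(keep_ids))
    merged = torch.cat([existing, new_vecs], dim=0)
    assert merged.shape == (4, 8)  # 3 surviving rows + 1 new row

The empty 1-D torch.tensor([]) placeholders in the patch appear to lean on torch.cat's
long-standing special case of skipping empty tensors, so the merge should also hold when there
are no new entries or no surviving entries.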