def compare_index(initial_notes_model, final_notes_model):
    """Describe where the final index diverges from the initial index.

    Every entry and every embedding row of ``initial_notes_model`` is compared
    with the item at the same position in ``final_notes_model``. Items that
    exist only in the final model (i.e. positions past the end of the initial
    model) are not inspected, so appending to the index is not a mismatch.

    Args:
        initial_notes_model: Object exposing ``entries`` (a sequence whose
            items provide ``to_json()``) and ``corpus_embeddings`` (an
            indexable sequence of tensors comparable with ``torch.equal``).
        final_notes_model: Object with the same two attributes, compared
            against the initial model position by position.

    Returns:
        A human-readable string listing the mismatched entry positions and
        mismatched embedding positions, one line per category. The string is
        empty when nothing differs.
    """
    final_entries = final_notes_model.entries
    final_entry_count = len(final_entries)
    # A position missing from the final index is reported as a mismatch rather
    # than being allowed to raise IndexError, so the caller still gets a
    # readable diagnostic when the final index shrank.
    mismatched_entries = [
        position
        for position, initial_entry in enumerate(initial_notes_model.entries)
        if position >= final_entry_count
        or initial_entry.to_json() != final_entries[position].to_json()
    ]

    # Verify new entry embeddings were appended to the embeddings tensor
    # without disrupting the order or content of the existing embeddings.
    final_embeddings = final_notes_model.corpus_embeddings
    final_embedding_count = len(final_embeddings)
    mismatched_embeddings = [
        position
        for position, initial_embedding in enumerate(initial_notes_model.corpus_embeddings)
        if position >= final_embedding_count
        or not torch.equal(final_embeddings[position], initial_embedding)
    ]

    report_lines = []
    if mismatched_entries:
        mismatched_entries_str = ",".join(map(str, mismatched_entries))
        report_lines.append(f"Entries at {mismatched_entries_str} not equal\n")
    if mismatched_embeddings:
        mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
        report_lines.append(f"Embeddings at {mismatched_embeddings_str} not equal\n")

    return "".join(report_lines)