From da98b92dd46e6eabadcf5df54f81af5a0c1cf734 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 15 Jul 2023 14:33:15 -0700 Subject: [PATCH] Create helper function to test value, order of entries & embeddings This helper should be used to observe if the current embeddings are stable sorted on regenerate and incremental update of index in text search tests --- tests/test_text_search.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_text_search.py b/tests/test_text_search.py index c18a4c42..5809f327 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -5,6 +5,7 @@ import os # External Packages import pytest +import torch from khoj.utils.config import SearchModels # Internal Packages @@ -202,3 +203,25 @@ def test_asymmetric_setup_github(content_config: ContentConfig, search_models: S # Assert assert len(github_model.entries) > 1 + + +def compare_index(initial_notes_model, final_notes_model): + mismatched_entries, mismatched_embeddings = [], [] + for index in range(len(initial_notes_model.entries)): + if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json(): + mismatched_entries.append(index) + + # verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings + for index in range(len(initial_notes_model.corpus_embeddings)): + if not torch.equal(final_notes_model.corpus_embeddings[index], initial_notes_model.corpus_embeddings[index]): + mismatched_embeddings.append(index) + + error_details = "" + if mismatched_entries: + mismatched_entries_str = ",".join(map(str, mismatched_entries)) + error_details += f"Entries at {mismatched_entries_str} not equal\n" + if mismatched_embeddings: + mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings)) + error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n" + + return error_details