Improve Indexing Text Entries (#535)

Major
- Ensure search results logic is consistent across migration to DB, multi-user
- Manually verified search results for sample queries look the same across migration
- Flatten indexing code for better indexing progress tracking and code readability

Minor
- a4f407f Test memory leak on MPS device when generating vector embeddings
- ef24485 Improve Khoj with DB setup instructions in the Django app readme (for now)
- f212cc7 Arrange remaining text search tests in arrange, act, assert order
- 022017d Fix text search tests to test updated indexing log messages
This commit is contained in:
Debanjum
2023-11-06 16:01:53 -08:00
committed by GitHub
11 changed files with 199 additions and 134 deletions

View File

@@ -1,3 +1,14 @@
# Standard Packages
import numpy as np
import psutil
from scipy.stats import linregress
import secrets
# External Packages
import pytest
# Internal Packages
from khoj.processor.embeddings import EmbeddingsModel
from khoj.utils import helpers
@@ -44,3 +55,29 @@ def test_lru_cache():
cache["b"] # accessing 'b' makes it the most recently used item
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
assert cache == {"b": 2, "d": 4}
@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
def test_encode_docs_memory_leak():
    """Check that repeatedly embedding documents does not steadily grow process memory.

    Embeds a fresh batch of random documents on every iteration, samples the
    resident set size (RSS) of the current process after each batch, and fits a
    straight line through the samples. The test fails when the fitted slope is
    2 MB per iteration or more.
    """
    # Arrange
    iterations = 50
    batch_size = 20
    bytes_per_mb = 1024 * 1024
    embeddings_model = EmbeddingsModel()
    # Handle to the current process; created once and reused for every sample
    process = psutil.Process()
    memory_usage_trend = []

    # Act
    # Encode random strings repeatedly and record memory usage trend
    for iteration in range(iterations):
        random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
        # The embeddings themselves are discarded; only the memory footprint matters here
        embeddings_model.embed_documents(random_docs)
        memory_usage_trend.append(process.memory_info().rss / bytes_per_mb)
        print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)

    # Calculate slope of line fitting memory usage history
    memory_usage_mb = np.array(memory_usage_trend)
    slope = linregress(np.arange(len(memory_usage_mb)), memory_usage_mb).slope

    # Assert
    # If slope is positive memory utilization is increasing
    # Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
    assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"

View File

@@ -48,10 +48,11 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
user=default_user,
)
# Act
org_files = collect_files(user=default_user)["org"]
# Act
# should not raise IsADirectoryError and return orgfile
# Assert
# should return orgfile and not raise IsADirectoryError
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
@@ -62,12 +63,14 @@ def test_text_search_setup_with_empty_file_raises_error(
):
# Arrange
data = get_org_files(org_config_with_only_new_file)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# Assert
assert "Deleted 3 entries. Created 0 new entries for user " in caplog.records[-1].message
verify_embeddings(0, default_user)
@@ -79,12 +82,15 @@ def test_text_indexer_deletes_embedding_before_regenerate(
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert "Deleting all embeddings for file type org" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
assert "Deleting all entries for file type org" in caplog.text
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@@ -93,13 +99,14 @@ def test_text_search_setup_batch_processes(content_config: ContentConfig, defaul
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert "Created 4 new embeddings" in caplog.text
assert "Created 6 new embeddings" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@@ -122,8 +129,8 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
final_logs = caplog.text
# Assert
assert "Deleting all embeddings for file type org" in initial_logs
assert "Deleting all embeddings for file type org" not in final_logs
assert "Deleting all entries for file type org" in initial_logs
assert "Deleting all entries for file type org" not in final_logs
# ----------------------------------------------------------------------------------------------------
@@ -135,7 +142,6 @@ async def test_text_search(search_config: SearchConfig):
default_user = await KhojUser.objects.acreate(
username="test_user", password="test_password", email="test@example.com"
)
# Arrange
org_config = await LocalOrgConfig.objects.acreate(
input_files=None,
input_filter=["tests/data/org/*.org"],
@@ -159,13 +165,12 @@ async def test_text_search(search_config: SearchConfig):
# Act
hits = await text_search.query(default_user, query)
# Assert
results = text_search.collate_results(hits)
results = sorted(results, key=lambda x: float(x.score))[:1]
# search results should contain "git clone" entry
# Assert
search_result = results[0].entry
assert "git clone" in search_result
assert "git clone" in search_result, 'search result did not contain "git clone" entry'
# ----------------------------------------------------------------------------------------------------
@@ -188,8 +193,9 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
# Assert
# verify newly added org-mode entry is split by max tokens
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"
# ----------------------------------------------------------------------------------------------------
@@ -245,8 +251,9 @@ conda activate khoj
)
# Assert
# verify newly added org-mode entry is split by max tokens
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"
# ----------------------------------------------------------------------------------------------------
@@ -256,27 +263,29 @@ def test_regenerate_index_with_new_entry(
):
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
initial_data = get_org_files(org_config)
# append org-mode entry to first org input file in config
org_config.input_files = [f"{new_org_file}"]
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
data = get_org_files(org_config)
final_data = get_org_files(org_config)
# Act
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# regenerate notes jsonl, model embeddings and model to include entry from new file
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
final_logs = caplog.text
# Assert
assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
assert "Deleted 10 entries. Created 11 new entries for user " in final_logs
verify_embeddings(11, default_user)
@@ -311,8 +320,8 @@ def test_update_index_with_duplicate_entries_in_stable_order(
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in final_logs
assert "Deleted 3 entries. Created 1 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
verify_embeddings(1, default_user)
@@ -327,29 +336,29 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}{new_entry} -- Tatooine")
data = get_org_files(org_config_with_only_new_file)
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
initial_data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}")
data = get_org_files(org_config_with_only_new_file)
final_data = get_org_files(org_config_with_only_new_file)
# Act
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
# verify only 1 entry remains indexed after the other entry is deleted from the org file
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in final_logs
assert "Deleted 3 entries. Created 2 new entries for user " in initial_logs
assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
verify_embeddings(1, default_user)
@@ -379,9 +388,8 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
final_logs = caplog.text
# Assert
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in initial_logs
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in final_logs
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
verify_embeddings(11, default_user)
@@ -390,6 +398,7 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
# Arrange
github_config = GithubConfig.objects.filter(user=default_user).first()
# Act
# Regenerate github embeddings to test asymmetric setup without caching
text_search.setup(