mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Test memory leak on MPS device when generating vector embeddings
Slope threshold of 2.0 determined qualitatively on local Mac device. Minor unused import removal and clean-up.
This commit is contained in:
@@ -1,3 +1,14 @@
|
||||
# Standard Packages
|
||||
import numpy as np
|
||||
import psutil
|
||||
from scipy.stats import linregress
|
||||
import secrets
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
from khoj.utils import helpers
|
||||
|
||||
|
||||
@@ -44,3 +55,29 @@ def test_lru_cache():
|
||||
cache["b"] # accessing 'b' makes it the most recently used item
|
||||
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
|
||||
assert cache == {"b": 2, "d": 4}
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
def test_encode_docs_memory_leak():
    """Check that repeatedly embedding documents does not steadily grow process memory.

    Embeds a new batch of random strings on every iteration, samples the
    resident set size (RSS) of the current process in MB after each batch, and
    fits a straight line through the samples. The slope of that line
    approximates memory growth in MB per iteration; the test fails when the
    slope reaches the threshold asserted below.
    """
    # Arrange
    iterations = 50
    batch_size = 20
    embeddings_model = EmbeddingsModel()
    # The process handle is loop-invariant; memory_info() is queried anew on each call
    process = psutil.Process()
    memory_usage_trend = []

    # Act
    # Encode random strings repeatedly and record memory usage trend
    for iteration in range(iterations):
        random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
        # Only the memory footprint of encoding matters here, so the embeddings are discarded
        embeddings_model.embed_documents(random_docs)
        memory_usage_trend.append(process.memory_info().rss / (1024 * 1024))
        print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)

    # Calculate slope of line fitting memory usage history
    memory_usage_trend = np.array(memory_usage_trend)
    slope = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend).slope

    # Assert
    # If slope is positive memory utilization is increasing
    # Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
    assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"
|
||||
|
||||
Reference in New Issue
Block a user