Improve Indexing Text Entries (#535)

Major
- Ensure search results logic consistent across migration to DB, multi-user
- Manually verified search results for sample queries look the same across migration
 - Flatten indexing code for better indexing progress tracking and code readability

Minor
- a4f407f Test memory leak on MPS device when generating vector embeddings
- ef24485 Improve Khoj with DB setup instructions in the Django app readme (for now)
- f212cc7 Arrange remaining text search tests in arrange, act, assert order
- 022017d Fix text search tests to test updated indexing log messages
This commit is contained in:
Debanjum
2023-11-06 16:01:53 -08:00
committed by GitHub
11 changed files with 199 additions and 134 deletions

View File

@@ -93,6 +93,7 @@ test = [
"factory-boy >= 3.2.1", "factory-boy >= 3.2.1",
"trio >= 0.22.0", "trio >= 0.22.0",
"pytest-xdist", "pytest-xdist",
"psutil >= 5.8.0",
] ]
dev = [ dev = [
"khoj-assistant[test]", "khoj-assistant[test]",

View File

@@ -17,16 +17,26 @@ docker-compose up
## Setup (Local) ## Setup (Local)
### Install dependencies ### Install Postgres (with PgVector)
#### MacOS
- Install the [Postgres.app](https://postgresapp.com/).
#### Debian, Ubuntu
From [official instructions](https://wiki.postgresql.org/wiki/Apt)
```bash ```bash
pip install -e '.[dev]' sudo apt install -y postgresql-common
sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh
sudo apt install postgres-16 postgresql-16-pgvector
``` ```
### Setup the database #### Windows
- Use the [recommended installer](https://www.postgresql.org/download/windows/)
1. Ensure you have Postgres installed. For MacOS, you can use [Postgres.app](https://postgresapp.com/). #### From Source
2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. Reproduced instructions below for convenience. 1. Follow instructions to [Install Postgres](https://www.postgresql.org/download/)
2. Follow instructions to [Install PgVector](https://github.com/pgvector/pgvector#installation) in case you need to manually install it. Reproduced instructions below for convenience.
```bash ```bash
cd /tmp cd /tmp
@@ -35,32 +45,50 @@ cd pgvector
make make
make install # may need sudo make install # may need sudo
``` ```
3. Create a database
### Create the khoj database ### Create the Khoj database
#### MacOS
```bash ```bash
createdb khoj -U postgres createdb khoj -U postgres
``` ```
### Make migrations #### Debian, Ubuntu
```bash
sudo -u postgres createdb khoj
```
This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted). - [Optional] To set default postgres user's password
- Execute `ALTER USER postgres PASSWORD 'my_secure_password';` using `psql`
- Run `export $POSTGRES_PASSWORD=my_secure_password` in your terminal for Khoj to use it later
### Install Khoj
```bash
pip install -e '.[dev]'
```
### Make Khoj DB migrations
This command will create the migrations for the database app. This command should be run whenever a new db model is added to the database app or an existing db model is modified (updated or deleted).
```bash ```bash
python3 src/manage.py makemigrations python3 src/manage.py makemigrations
``` ```
### Run migrations ### Run Khoj DB migrations
This command will run any pending migrations in your application. This command will run any pending migrations in your application.
```bash ```bash
python3 src/manage.py migrate python3 src/manage.py migrate
``` ```
### Run the server ### Start Khoj Server
While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend. While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.
*Note: Anonymous mode bypasses authentication for local, single-user usage.*
```bash ```bash
python3 src/khoj/main.py python3 src/khoj/main.py --anonymous-mode
``` ```

View File

@@ -1,4 +1,3 @@
import secrets
from typing import Type, TypeVar, List from typing import Type, TypeVar, List
from datetime import date from datetime import date
import secrets import secrets
@@ -36,9 +35,6 @@ from database.models import (
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
) )
from khoj.utils.helpers import generate_random_name from khoj.utils.helpers import generate_random_name
from khoj.utils.rawconfig import (
ConversationProcessorConfig as UserConversationProcessorConfig,
)
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter

View File

@@ -1,7 +1,6 @@
from typing import List from typing import List
from langchain.embeddings import HuggingFaceEmbeddings from sentence_transformers import SentenceTransformer, CrossEncoder
from sentence_transformers import CrossEncoder
from khoj.utils.helpers import get_device from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse from khoj.utils.rawconfig import SearchResponse
@@ -9,18 +8,16 @@ from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel: class EmbeddingsModel:
def __init__(self): def __init__(self):
self.encode_kwargs = {"normalize_embeddings": True}
self.model_kwargs = {"device": get_device()}
self.model_name = "thenlper/gte-small" self.model_name = "thenlper/gte-small"
encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True} self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
model_kwargs = {"device": get_device()}
self.embeddings_model = HuggingFaceEmbeddings(
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
)
def embed_query(self, query): def embed_query(self, query):
return self.embeddings_model.embed_query(query) return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
def embed_documents(self, docs): def embed_documents(self, docs):
return self.embeddings_model.embed_documents(docs) return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
class CrossEncoderModel: class CrossEncoderModel:

View File

@@ -24,7 +24,7 @@ class OrgToEntries(TextToEntries):
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]: ) -> Tuple[int, int]:
# Extract required fields from config # Extract required fields from config
index_heading_entries = True index_heading_entries = False
if not full_corpus: if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""]) deletion_file_names = set([file for file in files if files[file] == ""])

View File

@@ -1,11 +1,12 @@
# Standard Packages # Standard Packages
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import hashlib import hashlib
from itertools import repeat
import logging import logging
import uuid import uuid
from tqdm import tqdm from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer, batcher from khoj.utils.helpers import is_none_or_empty, timer, batcher
# Internal Packages # Internal Packages
@@ -83,92 +84,88 @@ class TextToEntries(ABC):
user: KhojUser = None, user: KhojUser = None,
regenerate: bool = False, regenerate: bool = False,
): ):
with timer("Construct current entry hashes", logger): with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]() hashes_by_file = dict[str, set[str]]()
current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries)) current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
for entry in tqdm(current_entries, desc="Hashing Entries"): for entry in tqdm(current_entries, desc="Hashing Entries"):
hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry)) hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))
num_deleted_embeddings = 0 num_deleted_entries = 0
with timer("Preparing dataset for regeneration", logger): if regenerate:
if regenerate: with timer("Prepared dataset for regeneration in", logger):
logger.debug(f"Deleting all embeddings for file type {file_type}") logger.debug(f"Deleting all entries for file type {file_type}")
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type) num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
num_new_embeddings = 0 hashes_to_process = set()
with timer("Identify hashes for adding new entries", logger): with timer("Identified entries to add to database in", logger):
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"): for file in tqdm(hashes_by_file, desc="Identify new entries"):
hashes_for_file = hashes_by_file[file] hashes_for_file = hashes_by_file[file]
hashes_to_process = set()
existing_entries = DbEntry.objects.filter( existing_entries = DbEntry.objects.filter(
user=user, hashed_value__in=hashes_for_file, file_type=file_type user=user, hashed_value__in=hashes_for_file, file_type=file_type
) )
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries]) existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
hashes_to_process = hashes_for_file - existing_entry_hashes hashes_to_process |= hashes_for_file - existing_entry_hashes
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process] embeddings = []
data_to_embed = [getattr(entry, key) for entry in entries_to_process] with timer("Generated embeddings for entries to add to database in", logger):
embeddings = self.embeddings_model.embed_documents(data_to_embed) entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings += self.embeddings_model.embed_documents(data_to_embed)
with timer("Update the database with new vector embeddings", logger): added_entries: list[DbEntry] = []
num_items = len(hashes_to_process) with timer("Added entries to database in", logger):
assert num_items == len(embeddings) num_items = len(hashes_to_process)
batch_size = min(200, num_items) assert num_items == len(embeddings)
entry_batches = zip(hashes_to_process, embeddings) batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)
for entry_batch in tqdm( for entry_batch in tqdm(batcher(entry_batches, batch_size), desc="Add entries to database"):
batcher(entry_batches, batch_size), desc="Processing embeddings in batches" batch_embeddings_to_create = []
): for entry_hash, new_entry in entry_batch:
batch_embeddings_to_create = [] entry = hash_to_current_entries[entry_hash]
for entry_hash, new_entry in entry_batch: batch_embeddings_to_create.append(
entry = hash_to_current_entries[entry_hash] DbEntry(
batch_embeddings_to_create.append( user=user,
DbEntry( embeddings=new_entry,
user=user, raw=entry.raw,
embeddings=new_entry, compiled=entry.compiled,
raw=entry.raw, heading=entry.heading[:1000], # Truncate to max chars of field allowed
compiled=entry.compiled, file_path=entry.file,
heading=entry.heading[:1000], # Truncate to max chars of field allowed file_type=file_type,
file_path=entry.file, hashed_value=entry_hash,
file_type=file_type, corpus_id=entry.corpus_id,
hashed_value=entry_hash, )
corpus_id=entry.corpus_id, )
) added_entries += DbEntry.objects.bulk_create(batch_embeddings_to_create)
) logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
new_entries = DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Created {len(new_entries)} new embeddings")
num_new_embeddings += len(new_entries)
dates_to_create = [] new_dates = []
with timer("Create new date associations for new embeddings", logger): with timer("Indexed dates from added entries in", logger):
for new_entry in new_entries: for added_entry in added_entries:
dates = self.date_filter.extract_dates(new_entry.raw) dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
for date in dates: dates_to_create = [
dates_to_create.append( EntryDates(date=date, entry=added_entry)
EntryDates( for date, added_entry in dates_in_entries
date=date, if not is_none_or_empty(date)
entry=new_entry, ]
) new_dates += EntryDates.objects.bulk_create(dates_to_create)
) logger.debug(f"Indexed {len(new_dates)} dates from added {file_type} entries")
new_dates = EntryDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.debug(f"Created {len(new_dates)} new date entries")
with timer("Identify hashes for removed entries", logger): with timer("Deleted entries identified by server from database in", logger):
for file in hashes_by_file: for file in hashes_by_file:
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file) existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file] to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
num_deleted_embeddings += len(to_delete_entry_hashes) num_deleted_entries += len(to_delete_entry_hashes)
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes)) EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
with timer("Identify hashes for deleting entries", logger): with timer("Deleted entries requested by clients from database in", logger):
if deletion_filenames is not None: if deletion_filenames is not None:
for file_path in deletion_filenames: for file_path in deletion_filenames:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path) deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_embeddings += deleted_count num_deleted_entries += deleted_count
return num_new_embeddings, num_deleted_embeddings return len(added_entries), num_deleted_entries
@staticmethod @staticmethod
def mark_entries_for_update( def mark_entries_for_update(

View File

@@ -321,7 +321,6 @@ def load_content(
content_index: Optional[ContentIndex], content_index: Optional[ContentIndex],
search_models: SearchModels, search_models: SearchModels,
): ):
logger.info(f"Loading content from existing embeddings...")
if content_config is None: if content_config is None:
logger.warning("🚨 No Content configuration available.") logger.warning("🚨 No Content configuration available.")
return None return None

View File

@@ -207,7 +207,7 @@ def setup(
file_names = [file_name for file_name in files] file_names = [file_name for file_name in files]
logger.info( logger.info(
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}" f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
) )

View File

@@ -51,7 +51,8 @@ def cli(args=None):
args, remaining_args = parser.parse_known_args(args) args, remaining_args = parser.parse_known_args(args)
logger.debug(f"Ignoring unknown commandline args: {remaining_args}") if len(remaining_args) > 0:
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
# Set default values for arguments # Set default values for arguments
args.chat_on_gpu = not args.disable_chat_on_gpu args.chat_on_gpu = not args.disable_chat_on_gpu

View File

@@ -1,3 +1,14 @@
# Standard Packages
import numpy as np
import psutil
from scipy.stats import linregress
import secrets
# External Packages
import pytest
# Internal Packages
from khoj.processor.embeddings import EmbeddingsModel
from khoj.utils import helpers from khoj.utils import helpers
@@ -44,3 +55,29 @@ def test_lru_cache():
cache["b"] # accessing 'b' makes it the most recently used item cache["b"] # accessing 'b' makes it the most recently used item
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b' cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
assert cache == {"b": 2, "d": 4} assert cache == {"b": 2, "d": 4}
@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
def test_encode_docs_memory_leak():
# Arrange
iterations = 50
batch_size = 20
embeddings_model = EmbeddingsModel()
memory_usage_trend = []
# Act
# Encode random strings repeatedly and record memory usage trend
for iteration in range(iterations):
random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
a = [embeddings_model.embed_documents(random_docs)]
memory_usage_trend += [psutil.Process().memory_info().rss / (1024 * 1024)]
print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)
# Calculate slope of line fitting memory usage history
memory_usage_trend = np.array(memory_usage_trend)
slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
# Assert
# If slope is positive memory utilization is increasing
# Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"

View File

@@ -48,10 +48,11 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
user=default_user, user=default_user,
) )
# Act
org_files = collect_files(user=default_user)["org"] org_files = collect_files(user=default_user)["org"]
# Act # Assert
# should not raise IsADirectoryError and return orgfile # should return orgfile and not raise IsADirectoryError
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"} assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
@@ -62,12 +63,14 @@ def test_text_search_setup_with_empty_file_raises_error(
): ):
# Arrange # Arrange
data = get_org_files(org_config_with_only_new_file) data = get_org_files(org_config_with_only_new_file)
# Act # Act
# Generate notes embeddings during asymmetric setup # Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message # Assert
assert "Deleted 3 entries. Created 0 new entries for user " in caplog.records[-1].message
verify_embeddings(0, default_user) verify_embeddings(0, default_user)
@@ -79,12 +82,15 @@ def test_text_indexer_deletes_embedding_before_regenerate(
# Arrange # Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first() org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config) data = get_org_files(org_config)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG): with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert # Assert
assert "Deleting all embeddings for file type org" in caplog.text assert "Deleting all entries for file type org" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -93,13 +99,14 @@ def test_text_search_setup_batch_processes(content_config: ContentConfig, defaul
# Arrange # Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first() org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config) data = get_org_files(org_config)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG): with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert # Assert
assert "Created 4 new embeddings" in caplog.text assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
assert "Created 6 new embeddings" in caplog.text
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -122,8 +129,8 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
final_logs = caplog.text final_logs = caplog.text
# Assert # Assert
assert "Deleting all embeddings for file type org" in initial_logs assert "Deleting all entries for file type org" in initial_logs
assert "Deleting all embeddings for file type org" not in final_logs assert "Deleting all entries for file type org" not in final_logs
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -135,7 +142,6 @@ async def test_text_search(search_config: SearchConfig):
default_user = await KhojUser.objects.acreate( default_user = await KhojUser.objects.acreate(
username="test_user", password="test_password", email="test@example.com" username="test_user", password="test_password", email="test@example.com"
) )
# Arrange
org_config = await LocalOrgConfig.objects.acreate( org_config = await LocalOrgConfig.objects.acreate(
input_files=None, input_files=None,
input_filter=["tests/data/org/*.org"], input_filter=["tests/data/org/*.org"],
@@ -159,13 +165,12 @@ async def test_text_search(search_config: SearchConfig):
# Act # Act
hits = await text_search.query(default_user, query) hits = await text_search.query(default_user, query)
# Assert
results = text_search.collate_results(hits) results = text_search.collate_results(hits)
results = sorted(results, key=lambda x: float(x.score))[:1] results = sorted(results, key=lambda x: float(x.score))[:1]
# search results should contain "git clone" entry
# Assert
search_result = results[0].entry search_result = results[0].entry
assert "git clone" in search_result assert "git clone" in search_result, 'search result did not contain "git clone" entry'
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -188,8 +193,9 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user) text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
# Assert # Assert
# verify newly added org-mode entry is split by max tokens assert (
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message "Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -245,8 +251,9 @@ conda activate khoj
) )
# Assert # Assert
# verify newly added org-mode entry is split by max tokens assert (
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message "Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@@ -256,27 +263,29 @@ def test_regenerate_index_with_new_entry(
): ):
# Arrange # Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first() org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config) initial_data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
# append org-mode entry to first org input file in config # append org-mode entry to first org input file in config
org_config.input_files = [f"{new_org_file}"] org_config.input_files = [f"{new_org_file}"]
with open(new_org_file, "w") as f: with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
data = get_org_files(org_config) final_data = get_org_files(org_config)
# Act # Act
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# regenerate notes jsonl, model embeddings and model to include entry from new file # regenerate notes jsonl, model embeddings and model to include entry from new file
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user) text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
final_logs = caplog.text
# Assert # Assert
assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
assert "Deleted 10 entries. Created 11 new entries for user " in final_logs
verify_embeddings(11, default_user) verify_embeddings(11, default_user)
@@ -311,8 +320,8 @@ def test_update_index_with_duplicate_entries_in_stable_order(
# Assert # Assert
# verify only 1 entry added even if there are multiple duplicate entries # verify only 1 entry added even if there are multiple duplicate entries
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in initial_logs assert "Deleted 3 entries. Created 1 new entries for user " in initial_logs
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in final_logs assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
verify_embeddings(1, default_user) verify_embeddings(1, default_user)
@@ -327,29 +336,29 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n" new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
with open(new_file_to_index, "w") as f: with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}{new_entry} -- Tatooine") f.write(f"{new_entry}{new_entry} -- Tatooine")
data = get_org_files(org_config_with_only_new_file) initial_data = get_org_files(org_config_with_only_new_file)
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# update embeddings, entries, notes model after removing an entry from the org file # update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f: with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}") f.write(f"{new_entry}")
data = get_org_files(org_config_with_only_new_file) final_data = get_org_files(org_config_with_only_new_file)
# Act # Act
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user) text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
final_logs = caplog.text final_logs = caplog.text
# Assert # Assert
# verify only 1 entry added even if there are multiple duplicate entries # verify only 1 entry added even if there are multiple duplicate entries
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in initial_logs assert "Deleted 3 entries. Created 2 new entries for user " in initial_logs
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in final_logs assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
verify_embeddings(1, default_user) verify_embeddings(1, default_user)
@@ -379,9 +388,8 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
final_logs = caplog.text final_logs = caplog.text
# Assert # Assert
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in initial_logs assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in final_logs assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
verify_embeddings(11, default_user) verify_embeddings(11, default_user)
@@ -390,6 +398,7 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser): def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
# Arrange # Arrange
github_config = GithubConfig.objects.filter(user=default_user).first() github_config = GithubConfig.objects.filter(user=default_user).first()
# Act # Act
# Regenerate github embeddings to test asymmetric setup without caching # Regenerate github embeddings to test asymmetric setup without caching
text_search.setup( text_search.setup(