mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Improve Indexing Text Entries (#535)
Major - Ensure search results logic consistent across migration to DB, multi-user - Manually verified search results for sample queries look the same across migration - Flatten indexing code for better indexing progress tracking and code readability Minor -a4f407fTest memory leak on MPS device when generating vector embeddings -ef24485Improve Khoj with DB setup instructions in the Django app readme (for now) -f212cc7Arrange remaining text search tests in arrange, act, assert order -022017dFix text search tests to test updated indexing log messages
This commit is contained in:
@@ -93,6 +93,7 @@ test = [
|
|||||||
"factory-boy >= 3.2.1",
|
"factory-boy >= 3.2.1",
|
||||||
"trio >= 0.22.0",
|
"trio >= 0.22.0",
|
||||||
"pytest-xdist",
|
"pytest-xdist",
|
||||||
|
"psutil >= 5.8.0",
|
||||||
]
|
]
|
||||||
dev = [
|
dev = [
|
||||||
"khoj-assistant[test]",
|
"khoj-assistant[test]",
|
||||||
|
|||||||
@@ -17,16 +17,26 @@ docker-compose up
|
|||||||
|
|
||||||
## Setup (Local)
|
## Setup (Local)
|
||||||
|
|
||||||
### Install dependencies
|
### Install Postgres (with PgVector)
|
||||||
|
|
||||||
|
#### MacOS
|
||||||
|
- Install the [Postgres.app](https://postgresapp.com/).
|
||||||
|
|
||||||
|
#### Debian, Ubuntu
|
||||||
|
From [official instructions](https://wiki.postgresql.org/wiki/Apt)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -e '.[dev]'
|
sudo apt install -y postgresql-common
|
||||||
|
sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh
|
||||||
|
sudo apt install postgres-16 postgresql-16-pgvector
|
||||||
```
|
```
|
||||||
|
|
||||||
### Setup the database
|
#### Windows
|
||||||
|
- Use the [recommended installer](https://www.postgresql.org/download/windows/)
|
||||||
|
|
||||||
1. Ensure you have Postgres installed. For MacOS, you can use [Postgres.app](https://postgresapp.com/).
|
#### From Source
|
||||||
2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. Reproduced instructions below for convenience.
|
1. Follow instructions to [Install Postgres](https://www.postgresql.org/download/)
|
||||||
|
2. Follow instructions to [Install PgVector](https://github.com/pgvector/pgvector#installation) in case you need to manually install it. Reproduced instructions below for convenience.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd /tmp
|
cd /tmp
|
||||||
@@ -35,32 +45,50 @@ cd pgvector
|
|||||||
make
|
make
|
||||||
make install # may need sudo
|
make install # may need sudo
|
||||||
```
|
```
|
||||||
3. Create a database
|
|
||||||
|
|
||||||
### Create the khoj database
|
### Create the Khoj database
|
||||||
|
|
||||||
|
#### MacOS
|
||||||
```bash
|
```bash
|
||||||
createdb khoj -U postgres
|
createdb khoj -U postgres
|
||||||
```
|
```
|
||||||
|
|
||||||
### Make migrations
|
#### Debian, Ubuntu
|
||||||
|
```bash
|
||||||
|
sudo -u postgres createdb khoj
|
||||||
|
```
|
||||||
|
|
||||||
This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted).
|
- [Optional] To set default postgres user's password
|
||||||
|
- Execute `ALTER USER postgres PASSWORD 'my_secure_password';` using `psql`
|
||||||
|
- Run `export $POSTGRES_PASSWORD=my_secure_password` in your terminal for Khoj to use it later
|
||||||
|
|
||||||
|
### Install Khoj
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e '.[dev]'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Make Khoj DB migrations
|
||||||
|
|
||||||
|
This command will create the migrations for the database app. This command should be run whenever a new db model is added to the database app or an existing db model is modified (updated or deleted).
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 src/manage.py makemigrations
|
python3 src/manage.py makemigrations
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run migrations
|
### Run Khoj DB migrations
|
||||||
|
|
||||||
This command will run any pending migrations in your application.
|
This command will run any pending migrations in your application.
|
||||||
```bash
|
```bash
|
||||||
python3 src/manage.py migrate
|
python3 src/manage.py migrate
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run the server
|
### Start Khoj Server
|
||||||
|
|
||||||
While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.
|
While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.
|
||||||
|
|
||||||
|
*Note: Anonymous mode bypasses authentication for local, single-user usage.*
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 src/khoj/main.py
|
python3 src/khoj/main.py --anonymous-mode
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import secrets
|
|
||||||
from typing import Type, TypeVar, List
|
from typing import Type, TypeVar, List
|
||||||
from datetime import date
|
from datetime import date
|
||||||
import secrets
|
import secrets
|
||||||
@@ -36,9 +35,6 @@ from database.models import (
|
|||||||
OfflineChatProcessorConversationConfig,
|
OfflineChatProcessorConversationConfig,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import generate_random_name
|
from khoj.utils.helpers import generate_random_name
|
||||||
from khoj.utils.rawconfig import (
|
|
||||||
ConversationProcessorConfig as UserConversationProcessorConfig,
|
|
||||||
)
|
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||||
from sentence_transformers import CrossEncoder
|
|
||||||
|
|
||||||
from khoj.utils.helpers import get_device
|
from khoj.utils.helpers import get_device
|
||||||
from khoj.utils.rawconfig import SearchResponse
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
@@ -9,18 +8,16 @@ from khoj.utils.rawconfig import SearchResponse
|
|||||||
|
|
||||||
class EmbeddingsModel:
|
class EmbeddingsModel:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
self.encode_kwargs = {"normalize_embeddings": True}
|
||||||
|
self.model_kwargs = {"device": get_device()}
|
||||||
self.model_name = "thenlper/gte-small"
|
self.model_name = "thenlper/gte-small"
|
||||||
encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True}
|
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||||
model_kwargs = {"device": get_device()}
|
|
||||||
self.embeddings_model = HuggingFaceEmbeddings(
|
|
||||||
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def embed_query(self, query):
|
def embed_query(self, query):
|
||||||
return self.embeddings_model.embed_query(query)
|
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
||||||
|
|
||||||
def embed_documents(self, docs):
|
def embed_documents(self, docs):
|
||||||
return self.embeddings_model.embed_documents(docs)
|
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||||
|
|
||||||
|
|
||||||
class CrossEncoderModel:
|
class CrossEncoderModel:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class OrgToEntries(TextToEntries):
|
|||||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||||
) -> Tuple[int, int]:
|
) -> Tuple[int, int]:
|
||||||
# Extract required fields from config
|
# Extract required fields from config
|
||||||
index_heading_entries = True
|
index_heading_entries = False
|
||||||
|
|
||||||
if not full_corpus:
|
if not full_corpus:
|
||||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from itertools import repeat
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing import Callable, List, Tuple, Set, Any
|
from typing import Callable, List, Tuple, Set, Any
|
||||||
from khoj.utils.helpers import timer, batcher
|
from khoj.utils.helpers import is_none_or_empty, timer, batcher
|
||||||
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
@@ -83,92 +84,88 @@ class TextToEntries(ABC):
|
|||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
regenerate: bool = False,
|
regenerate: bool = False,
|
||||||
):
|
):
|
||||||
with timer("Construct current entry hashes", logger):
|
with timer("Constructed current entry hashes in", logger):
|
||||||
hashes_by_file = dict[str, set[str]]()
|
hashes_by_file = dict[str, set[str]]()
|
||||||
current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
|
current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
|
||||||
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
||||||
for entry in tqdm(current_entries, desc="Hashing Entries"):
|
for entry in tqdm(current_entries, desc="Hashing Entries"):
|
||||||
hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))
|
hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))
|
||||||
|
|
||||||
num_deleted_embeddings = 0
|
num_deleted_entries = 0
|
||||||
with timer("Preparing dataset for regeneration", logger):
|
if regenerate:
|
||||||
if regenerate:
|
with timer("Prepared dataset for regeneration in", logger):
|
||||||
logger.debug(f"Deleting all embeddings for file type {file_type}")
|
logger.debug(f"Deleting all entries for file type {file_type}")
|
||||||
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
|
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
|
||||||
|
|
||||||
num_new_embeddings = 0
|
hashes_to_process = set()
|
||||||
with timer("Identify hashes for adding new entries", logger):
|
with timer("Identified entries to add to database in", logger):
|
||||||
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
|
for file in tqdm(hashes_by_file, desc="Identify new entries"):
|
||||||
hashes_for_file = hashes_by_file[file]
|
hashes_for_file = hashes_by_file[file]
|
||||||
hashes_to_process = set()
|
|
||||||
existing_entries = DbEntry.objects.filter(
|
existing_entries = DbEntry.objects.filter(
|
||||||
user=user, hashed_value__in=hashes_for_file, file_type=file_type
|
user=user, hashed_value__in=hashes_for_file, file_type=file_type
|
||||||
)
|
)
|
||||||
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
|
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
|
||||||
hashes_to_process = hashes_for_file - existing_entry_hashes
|
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
||||||
|
|
||||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
embeddings = []
|
||||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
with timer("Generated embeddings for entries to add to database in", logger):
|
||||||
embeddings = self.embeddings_model.embed_documents(data_to_embed)
|
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||||
|
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||||
|
embeddings += self.embeddings_model.embed_documents(data_to_embed)
|
||||||
|
|
||||||
with timer("Update the database with new vector embeddings", logger):
|
added_entries: list[DbEntry] = []
|
||||||
num_items = len(hashes_to_process)
|
with timer("Added entries to database in", logger):
|
||||||
assert num_items == len(embeddings)
|
num_items = len(hashes_to_process)
|
||||||
batch_size = min(200, num_items)
|
assert num_items == len(embeddings)
|
||||||
entry_batches = zip(hashes_to_process, embeddings)
|
batch_size = min(200, num_items)
|
||||||
|
entry_batches = zip(hashes_to_process, embeddings)
|
||||||
|
|
||||||
for entry_batch in tqdm(
|
for entry_batch in tqdm(batcher(entry_batches, batch_size), desc="Add entries to database"):
|
||||||
batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
|
batch_embeddings_to_create = []
|
||||||
):
|
for entry_hash, new_entry in entry_batch:
|
||||||
batch_embeddings_to_create = []
|
entry = hash_to_current_entries[entry_hash]
|
||||||
for entry_hash, new_entry in entry_batch:
|
batch_embeddings_to_create.append(
|
||||||
entry = hash_to_current_entries[entry_hash]
|
DbEntry(
|
||||||
batch_embeddings_to_create.append(
|
user=user,
|
||||||
DbEntry(
|
embeddings=new_entry,
|
||||||
user=user,
|
raw=entry.raw,
|
||||||
embeddings=new_entry,
|
compiled=entry.compiled,
|
||||||
raw=entry.raw,
|
heading=entry.heading[:1000], # Truncate to max chars of field allowed
|
||||||
compiled=entry.compiled,
|
file_path=entry.file,
|
||||||
heading=entry.heading[:1000], # Truncate to max chars of field allowed
|
file_type=file_type,
|
||||||
file_path=entry.file,
|
hashed_value=entry_hash,
|
||||||
file_type=file_type,
|
corpus_id=entry.corpus_id,
|
||||||
hashed_value=entry_hash,
|
)
|
||||||
corpus_id=entry.corpus_id,
|
)
|
||||||
)
|
added_entries += DbEntry.objects.bulk_create(batch_embeddings_to_create)
|
||||||
)
|
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
|
||||||
new_entries = DbEntry.objects.bulk_create(batch_embeddings_to_create)
|
|
||||||
logger.debug(f"Created {len(new_entries)} new embeddings")
|
|
||||||
num_new_embeddings += len(new_entries)
|
|
||||||
|
|
||||||
dates_to_create = []
|
new_dates = []
|
||||||
with timer("Create new date associations for new embeddings", logger):
|
with timer("Indexed dates from added entries in", logger):
|
||||||
for new_entry in new_entries:
|
for added_entry in added_entries:
|
||||||
dates = self.date_filter.extract_dates(new_entry.raw)
|
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
|
||||||
for date in dates:
|
dates_to_create = [
|
||||||
dates_to_create.append(
|
EntryDates(date=date, entry=added_entry)
|
||||||
EntryDates(
|
for date, added_entry in dates_in_entries
|
||||||
date=date,
|
if not is_none_or_empty(date)
|
||||||
entry=new_entry,
|
]
|
||||||
)
|
new_dates += EntryDates.objects.bulk_create(dates_to_create)
|
||||||
)
|
logger.debug(f"Indexed {len(new_dates)} dates from added {file_type} entries")
|
||||||
new_dates = EntryDates.objects.bulk_create(dates_to_create)
|
|
||||||
if len(new_dates) > 0:
|
|
||||||
logger.debug(f"Created {len(new_dates)} new date entries")
|
|
||||||
|
|
||||||
with timer("Identify hashes for removed entries", logger):
|
with timer("Deleted entries identified by server from database in", logger):
|
||||||
for file in hashes_by_file:
|
for file in hashes_by_file:
|
||||||
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
|
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
|
||||||
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
|
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
|
||||||
num_deleted_embeddings += len(to_delete_entry_hashes)
|
num_deleted_entries += len(to_delete_entry_hashes)
|
||||||
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
|
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
|
||||||
|
|
||||||
with timer("Identify hashes for deleting entries", logger):
|
with timer("Deleted entries requested by clients from database in", logger):
|
||||||
if deletion_filenames is not None:
|
if deletion_filenames is not None:
|
||||||
for file_path in deletion_filenames:
|
for file_path in deletion_filenames:
|
||||||
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
||||||
num_deleted_embeddings += deleted_count
|
num_deleted_entries += deleted_count
|
||||||
|
|
||||||
return num_new_embeddings, num_deleted_embeddings
|
return len(added_entries), num_deleted_entries
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def mark_entries_for_update(
|
def mark_entries_for_update(
|
||||||
|
|||||||
@@ -321,7 +321,6 @@ def load_content(
|
|||||||
content_index: Optional[ContentIndex],
|
content_index: Optional[ContentIndex],
|
||||||
search_models: SearchModels,
|
search_models: SearchModels,
|
||||||
):
|
):
|
||||||
logger.info(f"Loading content from existing embeddings...")
|
|
||||||
if content_config is None:
|
if content_config is None:
|
||||||
logger.warning("🚨 No Content configuration available.")
|
logger.warning("🚨 No Content configuration available.")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ def setup(
|
|||||||
file_names = [file_name for file_name in files]
|
file_names = [file_name for file_name in files]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
|
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,8 @@ def cli(args=None):
|
|||||||
|
|
||||||
args, remaining_args = parser.parse_known_args(args)
|
args, remaining_args = parser.parse_known_args(args)
|
||||||
|
|
||||||
logger.debug(f"Ignoring unknown commandline args: {remaining_args}")
|
if len(remaining_args) > 0:
|
||||||
|
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
|
||||||
|
|
||||||
# Set default values for arguments
|
# Set default values for arguments
|
||||||
args.chat_on_gpu = not args.disable_chat_on_gpu
|
args.chat_on_gpu = not args.disable_chat_on_gpu
|
||||||
|
|||||||
@@ -1,3 +1,14 @@
|
|||||||
|
# Standard Packages
|
||||||
|
import numpy as np
|
||||||
|
import psutil
|
||||||
|
from scipy.stats import linregress
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
# External Packages
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
|
from khoj.processor.embeddings import EmbeddingsModel
|
||||||
from khoj.utils import helpers
|
from khoj.utils import helpers
|
||||||
|
|
||||||
|
|
||||||
@@ -44,3 +55,29 @@ def test_lru_cache():
|
|||||||
cache["b"] # accessing 'b' makes it the most recently used item
|
cache["b"] # accessing 'b' makes it the most recently used item
|
||||||
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
|
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
|
||||||
assert cache == {"b": 2, "d": 4}
|
assert cache == {"b": 2, "d": 4}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
|
||||||
|
def test_encode_docs_memory_leak():
|
||||||
|
# Arrange
|
||||||
|
iterations = 50
|
||||||
|
batch_size = 20
|
||||||
|
embeddings_model = EmbeddingsModel()
|
||||||
|
memory_usage_trend = []
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# Encode random strings repeatedly and record memory usage trend
|
||||||
|
for iteration in range(iterations):
|
||||||
|
random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
|
||||||
|
a = [embeddings_model.embed_documents(random_docs)]
|
||||||
|
memory_usage_trend += [psutil.Process().memory_info().rss / (1024 * 1024)]
|
||||||
|
print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)
|
||||||
|
|
||||||
|
# Calculate slope of line fitting memory usage history
|
||||||
|
memory_usage_trend = np.array(memory_usage_trend)
|
||||||
|
slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# If slope is positive memory utilization is increasing
|
||||||
|
# Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
|
||||||
|
assert slope < 2, f"Memory usage increasing at ~{slope:.2f} MB per iteration"
|
||||||
|
|||||||
@@ -48,10 +48,11 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
|
|||||||
user=default_user,
|
user=default_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
org_files = collect_files(user=default_user)["org"]
|
org_files = collect_files(user=default_user)["org"]
|
||||||
|
|
||||||
# Act
|
# Assert
|
||||||
# should not raise IsADirectoryError and return orgfile
|
# should return orgfile and not raise IsADirectoryError
|
||||||
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
|
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
|
||||||
|
|
||||||
|
|
||||||
@@ -62,12 +63,14 @@ def test_text_search_setup_with_empty_file_raises_error(
|
|||||||
):
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
data = get_org_files(org_config_with_only_new_file)
|
data = get_org_files(org_config_with_only_new_file)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# Generate notes embeddings during asymmetric setup
|
# Generate notes embeddings during asymmetric setup
|
||||||
with caplog.at_level(logging.INFO):
|
with caplog.at_level(logging.INFO):
|
||||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||||
|
|
||||||
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
|
# Assert
|
||||||
|
assert "Deleted 3 entries. Created 0 new entries for user " in caplog.records[-1].message
|
||||||
verify_embeddings(0, default_user)
|
verify_embeddings(0, default_user)
|
||||||
|
|
||||||
|
|
||||||
@@ -79,12 +82,15 @@ def test_text_indexer_deletes_embedding_before_regenerate(
|
|||||||
# Arrange
|
# Arrange
|
||||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||||
data = get_org_files(org_config)
|
data = get_org_files(org_config)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# Generate notes embeddings during asymmetric setup
|
||||||
with caplog.at_level(logging.DEBUG):
|
with caplog.at_level(logging.DEBUG):
|
||||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert "Deleting all embeddings for file type org" in caplog.text
|
assert "Deleting all entries for file type org" in caplog.text
|
||||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
|
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@@ -93,13 +99,14 @@ def test_text_search_setup_batch_processes(content_config: ContentConfig, defaul
|
|||||||
# Arrange
|
# Arrange
|
||||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||||
data = get_org_files(org_config)
|
data = get_org_files(org_config)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
# Generate notes embeddings during asymmetric setup
|
||||||
with caplog.at_level(logging.DEBUG):
|
with caplog.at_level(logging.DEBUG):
|
||||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert "Created 4 new embeddings" in caplog.text
|
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
|
||||||
assert "Created 6 new embeddings" in caplog.text
|
|
||||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@@ -122,8 +129,8 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
|
|||||||
final_logs = caplog.text
|
final_logs = caplog.text
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert "Deleting all embeddings for file type org" in initial_logs
|
assert "Deleting all entries for file type org" in initial_logs
|
||||||
assert "Deleting all embeddings for file type org" not in final_logs
|
assert "Deleting all entries for file type org" not in final_logs
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@@ -135,7 +142,6 @@ async def test_text_search(search_config: SearchConfig):
|
|||||||
default_user = await KhojUser.objects.acreate(
|
default_user = await KhojUser.objects.acreate(
|
||||||
username="test_user", password="test_password", email="test@example.com"
|
username="test_user", password="test_password", email="test@example.com"
|
||||||
)
|
)
|
||||||
# Arrange
|
|
||||||
org_config = await LocalOrgConfig.objects.acreate(
|
org_config = await LocalOrgConfig.objects.acreate(
|
||||||
input_files=None,
|
input_files=None,
|
||||||
input_filter=["tests/data/org/*.org"],
|
input_filter=["tests/data/org/*.org"],
|
||||||
@@ -159,13 +165,12 @@ async def test_text_search(search_config: SearchConfig):
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
hits = await text_search.query(default_user, query)
|
hits = await text_search.query(default_user, query)
|
||||||
|
|
||||||
# Assert
|
|
||||||
results = text_search.collate_results(hits)
|
results = text_search.collate_results(hits)
|
||||||
results = sorted(results, key=lambda x: float(x.score))[:1]
|
results = sorted(results, key=lambda x: float(x.score))[:1]
|
||||||
# search results should contain "git clone" entry
|
|
||||||
|
# Assert
|
||||||
search_result = results[0].entry
|
search_result = results[0].entry
|
||||||
assert "git clone" in search_result
|
assert "git clone" in search_result, 'search result did not contain "git clone" entry'
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@@ -188,8 +193,9 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
|
|||||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# verify newly added org-mode entry is split by max tokens
|
assert (
|
||||||
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
|
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||||
|
), "new entry not split by max tokens"
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@@ -245,8 +251,9 @@ conda activate khoj
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# verify newly added org-mode entry is split by max tokens
|
assert (
|
||||||
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in caplog.records[-1].message
|
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||||
|
), "new entry not split by max tokens"
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@@ -256,27 +263,29 @@ def test_regenerate_index_with_new_entry(
|
|||||||
):
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||||
data = get_org_files(org_config)
|
initial_data = get_org_files(org_config)
|
||||||
|
|
||||||
with caplog.at_level(logging.INFO):
|
|
||||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
|
||||||
|
|
||||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[-1].message
|
|
||||||
|
|
||||||
# append org-mode entry to first org input file in config
|
# append org-mode entry to first org input file in config
|
||||||
org_config.input_files = [f"{new_org_file}"]
|
org_config.input_files = [f"{new_org_file}"]
|
||||||
with open(new_org_file, "w") as f:
|
with open(new_org_file, "w") as f:
|
||||||
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
||||||
|
|
||||||
data = get_org_files(org_config)
|
final_data = get_org_files(org_config)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
with caplog.at_level(logging.INFO):
|
||||||
|
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||||
|
initial_logs = caplog.text
|
||||||
|
caplog.clear() # Clear logs
|
||||||
|
|
||||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||||
with caplog.at_level(logging.INFO):
|
with caplog.at_level(logging.INFO):
|
||||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
|
||||||
|
final_logs = caplog.text
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
|
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
|
||||||
|
assert "Deleted 10 entries. Created 11 new entries for user " in final_logs
|
||||||
verify_embeddings(11, default_user)
|
verify_embeddings(11, default_user)
|
||||||
|
|
||||||
|
|
||||||
@@ -311,8 +320,8 @@ def test_update_index_with_duplicate_entries_in_stable_order(
|
|||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# verify only 1 entry added even if there are multiple duplicate entries
|
# verify only 1 entry added even if there are multiple duplicate entries
|
||||||
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in initial_logs
|
assert "Deleted 3 entries. Created 1 new entries for user " in initial_logs
|
||||||
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in final_logs
|
assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
|
||||||
|
|
||||||
verify_embeddings(1, default_user)
|
verify_embeddings(1, default_user)
|
||||||
|
|
||||||
@@ -327,29 +336,29 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
|
|||||||
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||||
with open(new_file_to_index, "w") as f:
|
with open(new_file_to_index, "w") as f:
|
||||||
f.write(f"{new_entry}{new_entry} -- Tatooine")
|
f.write(f"{new_entry}{new_entry} -- Tatooine")
|
||||||
data = get_org_files(org_config_with_only_new_file)
|
initial_data = get_org_files(org_config_with_only_new_file)
|
||||||
|
|
||||||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
|
||||||
with caplog.at_level(logging.INFO):
|
|
||||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
|
||||||
initial_logs = caplog.text
|
|
||||||
caplog.clear() # Clear logs
|
|
||||||
|
|
||||||
# update embeddings, entries, notes model after removing an entry from the org file
|
# update embeddings, entries, notes model after removing an entry from the org file
|
||||||
with open(new_file_to_index, "w") as f:
|
with open(new_file_to_index, "w") as f:
|
||||||
f.write(f"{new_entry}")
|
f.write(f"{new_entry}")
|
||||||
|
|
||||||
data = get_org_files(org_config_with_only_new_file)
|
final_data = get_org_files(org_config_with_only_new_file)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||||
with caplog.at_level(logging.INFO):
|
with caplog.at_level(logging.INFO):
|
||||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||||
|
initial_logs = caplog.text
|
||||||
|
caplog.clear() # Clear logs
|
||||||
|
|
||||||
|
with caplog.at_level(logging.INFO):
|
||||||
|
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
|
||||||
final_logs = caplog.text
|
final_logs = caplog.text
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
# verify only 1 entry added even if there are multiple duplicate entries
|
# verify only 1 entry added even if there are multiple duplicate entries
|
||||||
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in initial_logs
|
assert "Deleted 3 entries. Created 2 new entries for user " in initial_logs
|
||||||
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in final_logs
|
assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
|
||||||
|
|
||||||
verify_embeddings(1, default_user)
|
verify_embeddings(1, default_user)
|
||||||
|
|
||||||
@@ -379,9 +388,8 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
|
|||||||
final_logs = caplog.text
|
final_logs = caplog.text
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in initial_logs
|
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
|
||||||
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in final_logs
|
assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
|
||||||
|
|
||||||
verify_embeddings(11, default_user)
|
verify_embeddings(11, default_user)
|
||||||
|
|
||||||
|
|
||||||
@@ -390,6 +398,7 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
|
|||||||
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
|
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
|
||||||
# Arrange
|
# Arrange
|
||||||
github_config = GithubConfig.objects.filter(user=default_user).first()
|
github_config = GithubConfig.objects.filter(user=default_user).first()
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
# Regenerate github embeddings to test asymmetric setup without caching
|
# Regenerate github embeddings to test asymmetric setup without caching
|
||||||
text_search.setup(
|
text_search.setup(
|
||||||
|
|||||||
Reference in New Issue
Block a user