Improve Indexing Text Entries (#535)

Major
- Ensure search results logic consistent across migration to DB, multi-user
- Manually verified search results for sample queries look the same across migration
- Flatten indexing code for better indexing progress tracking and code readability

Minor
- a4f407f Test memory leak on MPS device when generating vector embeddings
- ef24485 Improve Khoj with DB setup instructions in the Django app readme (for now)
- f212cc7 Arrange remaining text search tests in arrange, act, assert order
- 022017d Fix text search tests to test updated indexing log messages
This commit is contained in:
Debanjum
2023-11-06 16:01:53 -08:00
committed by GitHub
11 changed files with 199 additions and 134 deletions

View File

@@ -17,16 +17,26 @@ docker-compose up
## Setup (Local)
### Install dependencies
### Install Postgres (with PgVector)
#### MacOS
- Install the [Postgres.app](https://postgresapp.com/).
#### Debian, Ubuntu
From [official instructions](https://wiki.postgresql.org/wiki/Apt)
```bash
pip install -e '.[dev]'
sudo apt install -y postgresql-common
sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh
sudo apt install postgresql-16 postgresql-16-pgvector
```
### Setup the database
#### Windows
- Use the [recommended installer](https://www.postgresql.org/download/windows/)
1. Ensure you have Postgres installed. For MacOS, you can use [Postgres.app](https://postgresapp.com/).
2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. Reproduced instructions below for convenience.
#### From Source
1. Follow instructions to [Install Postgres](https://www.postgresql.org/download/)
2. Follow instructions to [Install PgVector](https://github.com/pgvector/pgvector#installation) in case you need to manually install it. Reproduced instructions below for convenience.
```bash
cd /tmp
@@ -35,32 +45,50 @@ cd pgvector
make
make install # may need sudo
```
3. Create a database
### Create the khoj database
### Create the Khoj database
#### MacOS
```bash
createdb khoj -U postgres
```
### Make migrations
#### Debian, Ubuntu
```bash
sudo -u postgres createdb khoj
```
This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted).
- [Optional] To set default postgres user's password
- Execute `ALTER USER postgres PASSWORD 'my_secure_password';` using `psql`
- Run `export POSTGRES_PASSWORD=my_secure_password` in your terminal for Khoj to use it later
### Install Khoj
```bash
pip install -e '.[dev]'
```
### Make Khoj DB migrations
This command will create the migrations for the database app. This command should be run whenever a new db model is added to the database app or an existing db model is modified (updated or deleted).
```bash
python3 src/manage.py makemigrations
```
### Run migrations
### Run Khoj DB migrations
This command will run any pending migrations in your application.
```bash
python3 src/manage.py migrate
```
### Run the server
### Start Khoj Server
While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.
*Note: Anonymous mode bypasses authentication for local, single-user usage.*
```bash
python3 src/khoj/main.py
python3 src/khoj/main.py --anonymous-mode
```

View File

@@ -1,4 +1,3 @@
import secrets
from typing import Type, TypeVar, List
from datetime import date
import secrets
@@ -36,9 +35,6 @@ from database.models import (
OfflineChatProcessorConversationConfig,
)
from khoj.utils.helpers import generate_random_name
from khoj.utils.rawconfig import (
ConversationProcessorConfig as UserConversationProcessorConfig,
)
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter

View File

@@ -1,7 +1,6 @@
from typing import List
from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer, CrossEncoder
from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse
@@ -9,18 +8,16 @@ from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel:
def __init__(self):
self.encode_kwargs = {"normalize_embeddings": True}
self.model_kwargs = {"device": get_device()}
self.model_name = "thenlper/gte-small"
encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True}
model_kwargs = {"device": get_device()}
self.embeddings_model = HuggingFaceEmbeddings(
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
)
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def embed_query(self, query):
return self.embeddings_model.embed_query(query)
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
def embed_documents(self, docs):
return self.embeddings_model.embed_documents(docs)
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
class CrossEncoderModel:

View File

@@ -24,7 +24,7 @@ class OrgToEntries(TextToEntries):
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Extract required fields from config
index_heading_entries = True
index_heading_entries = False
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == ""])

View File

@@ -1,11 +1,12 @@
# Standard Packages
from abc import ABC, abstractmethod
import hashlib
from itertools import repeat
import logging
import uuid
from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer, batcher
from khoj.utils.helpers import is_none_or_empty, timer, batcher
# Internal Packages
@@ -83,92 +84,88 @@ class TextToEntries(ABC):
user: KhojUser = None,
regenerate: bool = False,
):
with timer("Construct current entry hashes", logger):
with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]()
current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
for entry in tqdm(current_entries, desc="Hashing Entries"):
hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))
num_deleted_embeddings = 0
with timer("Preparing dataset for regeneration", logger):
if regenerate:
logger.debug(f"Deleting all embeddings for file type {file_type}")
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
num_deleted_entries = 0
if regenerate:
with timer("Prepared dataset for regeneration in", logger):
logger.debug(f"Deleting all entries for file type {file_type}")
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
num_new_embeddings = 0
with timer("Identify hashes for adding new entries", logger):
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
hashes_to_process = set()
with timer("Identified entries to add to database in", logger):
for file in tqdm(hashes_by_file, desc="Identify new entries"):
hashes_for_file = hashes_by_file[file]
hashes_to_process = set()
existing_entries = DbEntry.objects.filter(
user=user, hashed_value__in=hashes_for_file, file_type=file_type
)
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
hashes_to_process = hashes_for_file - existing_entry_hashes
hashes_to_process |= hashes_for_file - existing_entry_hashes
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings = self.embeddings_model.embed_documents(data_to_embed)
embeddings = []
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings += self.embeddings_model.embed_documents(data_to_embed)
with timer("Update the database with new vector embeddings", logger):
num_items = len(hashes_to_process)
assert num_items == len(embeddings)
batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)
added_entries: list[DbEntry] = []
with timer("Added entries to database in", logger):
num_items = len(hashes_to_process)
assert num_items == len(embeddings)
batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)
for entry_batch in tqdm(
batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
):
batch_embeddings_to_create = []
for entry_hash, new_entry in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
DbEntry(
user=user,
embeddings=new_entry,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
)
)
new_entries = DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Created {len(new_entries)} new embeddings")
num_new_embeddings += len(new_entries)
for entry_batch in tqdm(batcher(entry_batches, batch_size), desc="Add entries to database"):
batch_embeddings_to_create = []
for entry_hash, new_entry in entry_batch:
entry = hash_to_current_entries[entry_hash]
batch_embeddings_to_create.append(
DbEntry(
user=user,
embeddings=new_entry,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading[:1000], # Truncate to max chars of field allowed
file_path=entry.file,
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
)
)
added_entries += DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
dates_to_create = []
with timer("Create new date associations for new embeddings", logger):
for new_entry in new_entries:
dates = self.date_filter.extract_dates(new_entry.raw)
for date in dates:
dates_to_create.append(
EntryDates(
date=date,
entry=new_entry,
)
)
new_dates = EntryDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.debug(f"Created {len(new_dates)} new date entries")
new_dates = []
with timer("Indexed dates from added entries in", logger):
for added_entry in added_entries:
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
dates_to_create = [
EntryDates(date=date, entry=added_entry)
for date, added_entry in dates_in_entries
if not is_none_or_empty(date)
]
new_dates += EntryDates.objects.bulk_create(dates_to_create)
logger.debug(f"Indexed {len(new_dates)} dates from added {file_type} entries")
with timer("Identify hashes for removed entries", logger):
with timer("Deleted entries identified by server from database in", logger):
for file in hashes_by_file:
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
num_deleted_embeddings += len(to_delete_entry_hashes)
num_deleted_entries += len(to_delete_entry_hashes)
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
with timer("Identify hashes for deleting entries", logger):
with timer("Deleted entries requested by clients from database in", logger):
if deletion_filenames is not None:
for file_path in deletion_filenames:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_embeddings += deleted_count
num_deleted_entries += deleted_count
return num_new_embeddings, num_deleted_embeddings
return len(added_entries), num_deleted_entries
@staticmethod
def mark_entries_for_update(

View File

@@ -321,7 +321,6 @@ def load_content(
content_index: Optional[ContentIndex],
search_models: SearchModels,
):
logger.info(f"Loading content from existing embeddings...")
if content_config is None:
logger.warning("🚨 No Content configuration available.")
return None

View File

@@ -207,7 +207,7 @@ def setup(
file_names = [file_name for file_name in files]
logger.info(
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
)

View File

@@ -51,7 +51,8 @@ def cli(args=None):
args, remaining_args = parser.parse_known_args(args)
logger.debug(f"Ignoring unknown commandline args: {remaining_args}")
if len(remaining_args) > 0:
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
# Set default values for arguments
args.chat_on_gpu = not args.disable_chat_on_gpu