mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Improve Indexing Text Entries (#535)
Major - Ensure search results logic consistent across migration to DB, multi-user - Manually verified search results for sample queries look the same across migration - Flatten indexing code for better indexing progress tracking and code readability Minor -a4f407fTest memory leak on MPS device when generating vector embeddings -ef24485Improve Khoj with DB setup instructions in the Django app readme (for now) -f212cc7Arrange remaining text search tests in arrange, act, assert order -022017dFix text search tests to test updated indexing log messages
This commit is contained in:
@@ -17,16 +17,26 @@ docker-compose up
|
||||
|
||||
## Setup (Local)
|
||||
|
||||
### Install dependencies
|
||||
### Install Postgres (with PgVector)
|
||||
|
||||
#### MacOS
|
||||
- Install the [Postgres.app](https://postgresapp.com/).
|
||||
|
||||
#### Debian, Ubuntu
|
||||
From [official instructions](https://wiki.postgresql.org/wiki/Apt)
|
||||
|
||||
```bash
|
||||
pip install -e '.[dev]'
|
||||
sudo apt install -y postgresql-common
|
||||
sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh
|
||||
sudo apt install postgres-16 postgresql-16-pgvector
|
||||
```
|
||||
|
||||
### Setup the database
|
||||
#### Windows
|
||||
- Use the [recommended installer](https://www.postgresql.org/download/windows/)
|
||||
|
||||
1. Ensure you have Postgres installed. For MacOS, you can use [Postgres.app](https://postgresapp.com/).
|
||||
2. If you're not using Postgres.app, you may have to install the pgvector extension manually. You can find the instructions [here](https://github.com/pgvector/pgvector#installation). If you're using Postgres.app, you can skip this step. Reproduced instructions below for convenience.
|
||||
#### From Source
|
||||
1. Follow instructions to [Install Postgres](https://www.postgresql.org/download/)
|
||||
2. Follow instructions to [Install PgVector](https://github.com/pgvector/pgvector#installation) in case you need to manually install it. Reproduced instructions below for convenience.
|
||||
|
||||
```bash
|
||||
cd /tmp
|
||||
@@ -35,32 +45,50 @@ cd pgvector
|
||||
make
|
||||
make install # may need sudo
|
||||
```
|
||||
3. Create a database
|
||||
|
||||
### Create the khoj database
|
||||
### Create the Khoj database
|
||||
|
||||
#### MacOS
|
||||
```bash
|
||||
createdb khoj -U postgres
|
||||
```
|
||||
|
||||
### Make migrations
|
||||
#### Debian, Ubuntu
|
||||
```bash
|
||||
sudo -u postgres createdb khoj
|
||||
```
|
||||
|
||||
This command will create the migrations for the database app. This command should be run whenever a new model is added to the database app or an existing model is modified (updated or deleted).
|
||||
- [Optional] To set default postgres user's password
|
||||
- Execute `ALTER USER postgres PASSWORD 'my_secure_password';` using `psql`
|
||||
- Run `export $POSTGRES_PASSWORD=my_secure_password` in your terminal for Khoj to use it later
|
||||
|
||||
### Install Khoj
|
||||
|
||||
```bash
|
||||
pip install -e '.[dev]'
|
||||
```
|
||||
|
||||
### Make Khoj DB migrations
|
||||
|
||||
This command will create the migrations for the database app. This command should be run whenever a new db model is added to the database app or an existing db model is modified (updated or deleted).
|
||||
|
||||
```bash
|
||||
python3 src/manage.py makemigrations
|
||||
```
|
||||
|
||||
### Run migrations
|
||||
### Run Khoj DB migrations
|
||||
|
||||
This command will run any pending migrations in your application.
|
||||
```bash
|
||||
python3 src/manage.py migrate
|
||||
```
|
||||
|
||||
### Run the server
|
||||
### Start Khoj Server
|
||||
|
||||
While we're using Django for the ORM, we're still using the FastAPI server for the API. This command automatically scaffolds the Django application in the backend.
|
||||
|
||||
*Note: Anonymous mode bypasses authentication for local, single-user usage.*
|
||||
|
||||
```bash
|
||||
python3 src/khoj/main.py
|
||||
python3 src/khoj/main.py --anonymous-mode
|
||||
```
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import secrets
|
||||
from typing import Type, TypeVar, List
|
||||
from datetime import date
|
||||
import secrets
|
||||
@@ -36,9 +35,6 @@ from database.models import (
|
||||
OfflineChatProcessorConversationConfig,
|
||||
)
|
||||
from khoj.utils.helpers import generate_random_name
|
||||
from khoj.utils.rawconfig import (
|
||||
ConversationProcessorConfig as UserConversationProcessorConfig,
|
||||
)
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from sentence_transformers import CrossEncoder
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
|
||||
from khoj.utils.helpers import get_device
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
@@ -9,18 +8,16 @@ from khoj.utils.rawconfig import SearchResponse
|
||||
|
||||
class EmbeddingsModel:
|
||||
def __init__(self):
|
||||
self.encode_kwargs = {"normalize_embeddings": True}
|
||||
self.model_kwargs = {"device": get_device()}
|
||||
self.model_name = "thenlper/gte-small"
|
||||
encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True}
|
||||
model_kwargs = {"device": get_device()}
|
||||
self.embeddings_model = HuggingFaceEmbeddings(
|
||||
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
|
||||
)
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
|
||||
def embed_query(self, query):
|
||||
return self.embeddings_model.embed_query(query)
|
||||
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
||||
|
||||
def embed_documents(self, docs):
|
||||
return self.embeddings_model.embed_documents(docs)
|
||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||
|
||||
|
||||
class CrossEncoderModel:
|
||||
|
||||
@@ -24,7 +24,7 @@ class OrgToEntries(TextToEntries):
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
index_heading_entries = True
|
||||
index_heading_entries = False
|
||||
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# Standard Packages
|
||||
from abc import ABC, abstractmethod
|
||||
import hashlib
|
||||
from itertools import repeat
|
||||
import logging
|
||||
import uuid
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, List, Tuple, Set, Any
|
||||
from khoj.utils.helpers import timer, batcher
|
||||
from khoj.utils.helpers import is_none_or_empty, timer, batcher
|
||||
|
||||
|
||||
# Internal Packages
|
||||
@@ -83,92 +84,88 @@ class TextToEntries(ABC):
|
||||
user: KhojUser = None,
|
||||
regenerate: bool = False,
|
||||
):
|
||||
with timer("Construct current entry hashes", logger):
|
||||
with timer("Constructed current entry hashes in", logger):
|
||||
hashes_by_file = dict[str, set[str]]()
|
||||
current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
|
||||
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
|
||||
for entry in tqdm(current_entries, desc="Hashing Entries"):
|
||||
hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))
|
||||
|
||||
num_deleted_embeddings = 0
|
||||
with timer("Preparing dataset for regeneration", logger):
|
||||
if regenerate:
|
||||
logger.debug(f"Deleting all embeddings for file type {file_type}")
|
||||
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
|
||||
num_deleted_entries = 0
|
||||
if regenerate:
|
||||
with timer("Prepared dataset for regeneration in", logger):
|
||||
logger.debug(f"Deleting all entries for file type {file_type}")
|
||||
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
|
||||
|
||||
num_new_embeddings = 0
|
||||
with timer("Identify hashes for adding new entries", logger):
|
||||
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
|
||||
hashes_to_process = set()
|
||||
with timer("Identified entries to add to database in", logger):
|
||||
for file in tqdm(hashes_by_file, desc="Identify new entries"):
|
||||
hashes_for_file = hashes_by_file[file]
|
||||
hashes_to_process = set()
|
||||
existing_entries = DbEntry.objects.filter(
|
||||
user=user, hashed_value__in=hashes_for_file, file_type=file_type
|
||||
)
|
||||
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
|
||||
hashes_to_process = hashes_for_file - existing_entry_hashes
|
||||
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
||||
|
||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||
embeddings = self.embeddings_model.embed_documents(data_to_embed)
|
||||
embeddings = []
|
||||
with timer("Generated embeddings for entries to add to database in", logger):
|
||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||
embeddings += self.embeddings_model.embed_documents(data_to_embed)
|
||||
|
||||
with timer("Update the database with new vector embeddings", logger):
|
||||
num_items = len(hashes_to_process)
|
||||
assert num_items == len(embeddings)
|
||||
batch_size = min(200, num_items)
|
||||
entry_batches = zip(hashes_to_process, embeddings)
|
||||
added_entries: list[DbEntry] = []
|
||||
with timer("Added entries to database in", logger):
|
||||
num_items = len(hashes_to_process)
|
||||
assert num_items == len(embeddings)
|
||||
batch_size = min(200, num_items)
|
||||
entry_batches = zip(hashes_to_process, embeddings)
|
||||
|
||||
for entry_batch in tqdm(
|
||||
batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
|
||||
):
|
||||
batch_embeddings_to_create = []
|
||||
for entry_hash, new_entry in entry_batch:
|
||||
entry = hash_to_current_entries[entry_hash]
|
||||
batch_embeddings_to_create.append(
|
||||
DbEntry(
|
||||
user=user,
|
||||
embeddings=new_entry,
|
||||
raw=entry.raw,
|
||||
compiled=entry.compiled,
|
||||
heading=entry.heading[:1000], # Truncate to max chars of field allowed
|
||||
file_path=entry.file,
|
||||
file_type=file_type,
|
||||
hashed_value=entry_hash,
|
||||
corpus_id=entry.corpus_id,
|
||||
)
|
||||
)
|
||||
new_entries = DbEntry.objects.bulk_create(batch_embeddings_to_create)
|
||||
logger.debug(f"Created {len(new_entries)} new embeddings")
|
||||
num_new_embeddings += len(new_entries)
|
||||
for entry_batch in tqdm(batcher(entry_batches, batch_size), desc="Add entries to database"):
|
||||
batch_embeddings_to_create = []
|
||||
for entry_hash, new_entry in entry_batch:
|
||||
entry = hash_to_current_entries[entry_hash]
|
||||
batch_embeddings_to_create.append(
|
||||
DbEntry(
|
||||
user=user,
|
||||
embeddings=new_entry,
|
||||
raw=entry.raw,
|
||||
compiled=entry.compiled,
|
||||
heading=entry.heading[:1000], # Truncate to max chars of field allowed
|
||||
file_path=entry.file,
|
||||
file_type=file_type,
|
||||
hashed_value=entry_hash,
|
||||
corpus_id=entry.corpus_id,
|
||||
)
|
||||
)
|
||||
added_entries += DbEntry.objects.bulk_create(batch_embeddings_to_create)
|
||||
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
|
||||
|
||||
dates_to_create = []
|
||||
with timer("Create new date associations for new embeddings", logger):
|
||||
for new_entry in new_entries:
|
||||
dates = self.date_filter.extract_dates(new_entry.raw)
|
||||
for date in dates:
|
||||
dates_to_create.append(
|
||||
EntryDates(
|
||||
date=date,
|
||||
entry=new_entry,
|
||||
)
|
||||
)
|
||||
new_dates = EntryDates.objects.bulk_create(dates_to_create)
|
||||
if len(new_dates) > 0:
|
||||
logger.debug(f"Created {len(new_dates)} new date entries")
|
||||
new_dates = []
|
||||
with timer("Indexed dates from added entries in", logger):
|
||||
for added_entry in added_entries:
|
||||
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
|
||||
dates_to_create = [
|
||||
EntryDates(date=date, entry=added_entry)
|
||||
for date, added_entry in dates_in_entries
|
||||
if not is_none_or_empty(date)
|
||||
]
|
||||
new_dates += EntryDates.objects.bulk_create(dates_to_create)
|
||||
logger.debug(f"Indexed {len(new_dates)} dates from added {file_type} entries")
|
||||
|
||||
with timer("Identify hashes for removed entries", logger):
|
||||
with timer("Deleted entries identified by server from database in", logger):
|
||||
for file in hashes_by_file:
|
||||
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
|
||||
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
|
||||
num_deleted_embeddings += len(to_delete_entry_hashes)
|
||||
num_deleted_entries += len(to_delete_entry_hashes)
|
||||
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
|
||||
|
||||
with timer("Identify hashes for deleting entries", logger):
|
||||
with timer("Deleted entries requested by clients from database in", logger):
|
||||
if deletion_filenames is not None:
|
||||
for file_path in deletion_filenames:
|
||||
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
||||
num_deleted_embeddings += deleted_count
|
||||
num_deleted_entries += deleted_count
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
return len(added_entries), num_deleted_entries
|
||||
|
||||
@staticmethod
|
||||
def mark_entries_for_update(
|
||||
|
||||
@@ -321,7 +321,6 @@ def load_content(
|
||||
content_index: Optional[ContentIndex],
|
||||
search_models: SearchModels,
|
||||
):
|
||||
logger.info(f"Loading content from existing embeddings...")
|
||||
if content_config is None:
|
||||
logger.warning("🚨 No Content configuration available.")
|
||||
return None
|
||||
|
||||
@@ -207,7 +207,7 @@ def setup(
|
||||
file_names = [file_name for file_name in files]
|
||||
|
||||
logger.info(
|
||||
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
|
||||
f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -51,7 +51,8 @@ def cli(args=None):
|
||||
|
||||
args, remaining_args = parser.parse_known_args(args)
|
||||
|
||||
logger.debug(f"Ignoring unknown commandline args: {remaining_args}")
|
||||
if len(remaining_args) > 0:
|
||||
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
|
||||
|
||||
# Set default values for arguments
|
||||
args.chat_on_gpu = not args.disable_chat_on_gpu
|
||||
|
||||
Reference in New Issue
Block a user