Flatten nested loops, improve progress reporting in text_to_jsonl indexer

Flatten the nested loops to improve visibility into indexing progress

Reduce spurious logs, report logs at an aggregated level, and update
the logging description text to improve indexing progress reporting
This commit is contained in:
Debanjum Singh Solanky
2023-11-04 04:55:51 -07:00
parent 12b5ef6540
commit dc9946fc03
4 changed files with 61 additions and 64 deletions

View File

@@ -1,11 +1,12 @@
# Standard Packages # Standard Packages
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import hashlib import hashlib
from itertools import repeat
import logging import logging
import uuid import uuid
from tqdm import tqdm from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any from typing import Callable, List, Tuple, Set, Any
from khoj.utils.helpers import timer, batcher from khoj.utils.helpers import is_none_or_empty, timer, batcher
# Internal Packages # Internal Packages
@@ -83,92 +84,88 @@ class TextToEntries(ABC):
user: KhojUser = None, user: KhojUser = None,
regenerate: bool = False, regenerate: bool = False,
): ):
with timer("Construct current entry hashes", logger): with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]() hashes_by_file = dict[str, set[str]]()
current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries)) current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
for entry in tqdm(current_entries, desc="Hashing Entries"): for entry in tqdm(current_entries, desc="Hashing Entries"):
hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry)) hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))
num_deleted_embeddings = 0 num_deleted_entries = 0
with timer("Preparing dataset for regeneration", logger): if regenerate:
if regenerate: with timer("Prepared dataset for regeneration in", logger):
logger.debug(f"Deleting all embeddings for file type {file_type}") logger.debug(f"Deleting all entries for file type {file_type}")
num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type) num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)
num_new_embeddings = 0 hashes_to_process = set()
with timer("Identify hashes for adding new entries", logger): with timer("Identified entries to add to database in", logger):
for file in tqdm(hashes_by_file, desc="Processing file with hashed values"): for file in tqdm(hashes_by_file, desc="Identify new entries"):
hashes_for_file = hashes_by_file[file] hashes_for_file = hashes_by_file[file]
hashes_to_process = set()
existing_entries = DbEntry.objects.filter( existing_entries = DbEntry.objects.filter(
user=user, hashed_value__in=hashes_for_file, file_type=file_type user=user, hashed_value__in=hashes_for_file, file_type=file_type
) )
existing_entry_hashes = set([entry.hashed_value for entry in existing_entries]) existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
hashes_to_process = hashes_for_file - existing_entry_hashes hashes_to_process |= hashes_for_file - existing_entry_hashes
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process] embeddings = []
data_to_embed = [getattr(entry, key) for entry in entries_to_process] with timer("Generated embeddings for entries to add to database in", logger):
embeddings = self.embeddings_model.embed_documents(data_to_embed) entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
embeddings += self.embeddings_model.embed_documents(data_to_embed)
with timer("Update the database with new vector embeddings", logger): added_entries: list[DbEntry] = []
num_items = len(hashes_to_process) with timer("Added entries to database in", logger):
assert num_items == len(embeddings) num_items = len(hashes_to_process)
batch_size = min(200, num_items) assert num_items == len(embeddings)
entry_batches = zip(hashes_to_process, embeddings) batch_size = min(200, num_items)
entry_batches = zip(hashes_to_process, embeddings)
for entry_batch in tqdm( for entry_batch in tqdm(batcher(entry_batches, batch_size), desc="Add entries to database"):
batcher(entry_batches, batch_size), desc="Processing embeddings in batches" batch_embeddings_to_create = []
): for entry_hash, new_entry in entry_batch:
batch_embeddings_to_create = [] entry = hash_to_current_entries[entry_hash]
for entry_hash, new_entry in entry_batch: batch_embeddings_to_create.append(
entry = hash_to_current_entries[entry_hash] DbEntry(
batch_embeddings_to_create.append( user=user,
DbEntry( embeddings=new_entry,
user=user, raw=entry.raw,
embeddings=new_entry, compiled=entry.compiled,
raw=entry.raw, heading=entry.heading[:1000], # Truncate to max chars of field allowed
compiled=entry.compiled, file_path=entry.file,
heading=entry.heading[:1000], # Truncate to max chars of field allowed file_type=file_type,
file_path=entry.file, hashed_value=entry_hash,
file_type=file_type, corpus_id=entry.corpus_id,
hashed_value=entry_hash, )
corpus_id=entry.corpus_id, )
) added_entries += DbEntry.objects.bulk_create(batch_embeddings_to_create)
) logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
new_entries = DbEntry.objects.bulk_create(batch_embeddings_to_create)
logger.debug(f"Created {len(new_entries)} new embeddings")
num_new_embeddings += len(new_entries)
dates_to_create = [] new_dates = []
with timer("Create new date associations for new embeddings", logger): with timer("Indexed dates from added entries in", logger):
for new_entry in new_entries: for added_entry in added_entries:
dates = self.date_filter.extract_dates(new_entry.raw) dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
for date in dates: dates_to_create = [
dates_to_create.append( EntryDates(date=date, entry=added_entry)
EntryDates( for date, added_entry in dates_in_entries
date=date, if not is_none_or_empty(date)
entry=new_entry, ]
) new_dates += EntryDates.objects.bulk_create(dates_to_create)
) logger.debug(f"Indexed {len(new_dates)} dates from added {file_type} entries")
new_dates = EntryDates.objects.bulk_create(dates_to_create)
if len(new_dates) > 0:
logger.debug(f"Created {len(new_dates)} new date entries")
with timer("Identify hashes for removed entries", logger): with timer("Deleted entries identified by server from database in", logger):
for file in hashes_by_file: for file in hashes_by_file:
existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file) existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file] to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
num_deleted_embeddings += len(to_delete_entry_hashes) num_deleted_entries += len(to_delete_entry_hashes)
EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes)) EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))
with timer("Identify hashes for deleting entries", logger): with timer("Deleted entries requested by clients from database in", logger):
if deletion_filenames is not None: if deletion_filenames is not None:
for file_path in deletion_filenames: for file_path in deletion_filenames:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path) deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_embeddings += deleted_count num_deleted_entries += deleted_count
return num_new_embeddings, num_deleted_embeddings return len(added_entries), num_deleted_entries
@staticmethod @staticmethod
def mark_entries_for_update( def mark_entries_for_update(

View File

@@ -321,7 +321,6 @@ def load_content(
content_index: Optional[ContentIndex], content_index: Optional[ContentIndex],
search_models: SearchModels, search_models: SearchModels,
): ):
logger.info(f"Loading content from existing embeddings...")
if content_config is None: if content_config is None:
logger.warning("🚨 No Content configuration available.") logger.warning("🚨 No Content configuration available.")
return None return None

View File

@@ -207,7 +207,7 @@ def setup(
file_names = [file_name for file_name in files] file_names = [file_name for file_name in files]
logger.info( logger.info(
f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}" f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
) )

View File

@@ -52,7 +52,8 @@ def cli(args=None):
args, remaining_args = parser.parse_known_args(args) args, remaining_args = parser.parse_known_args(args)
logger.debug(f"Ignoring unknown commandline args: {remaining_args}") if len(remaining_args) > 0:
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
# Set default values for arguments # Set default values for arguments
args.chat_on_gpu = not args.disable_chat_on_gpu args.chat_on_gpu = not args.disable_chat_on_gpu