Flatten nested loops, improve progress reporting in text_to_jsonl indexer
Flatten the nested loops to improve visibility into indexing progress. Reduce spurious logs, report logs at an aggregated level, and update the log description text to improve indexing progress reporting.
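A minimal, self-contained sketch of the flattening (toy names and data, assumed for illustration; not the khoj source):

# Toy sketch of the flattening; names and data are illustrative, not khoj source.
existing_hashes = {"a", "b"}
hashes_by_file = {"notes.org": {"a", "c"}, "todo.org": {"b", "d"}}

def embed(hashes: set[str]) -> list[tuple[str, str]]:
    # Stand-in for self.embeddings_model.embed_documents(...)
    return [(h, f"vector({h})") for h in sorted(hashes)]

# Before: embedding, DB writes and logging ran once per file inside this loop.
# After: the loop only accumulates new hashes; the expensive steps run once.
hashes_to_process: set[str] = set()
for file, hashes in hashes_by_file.items():
    hashes_to_process |= hashes - existing_hashes  # cheap per-file set union

embeddings = embed(hashes_to_process)  # single batched call over all files
print(f"Added {len(embeddings)} entries")  # one aggregated report: "Added 2 entries"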
@@ -1,11 +1,12 @@
 # Standard Packages
 from abc import ABC, abstractmethod
 import hashlib
+from itertools import repeat
 import logging
 import uuid
 from tqdm import tqdm
 from typing import Callable, List, Tuple, Set, Any
-from khoj.utils.helpers import timer, batcher
+from khoj.utils.helpers import is_none_or_empty, timer, batcher


 # Internal Packages
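The database write path in the next hunk feeds (hash, embedding) pairs through the batcher helper imported above from khoj.utils.helpers. Its real implementation is not shown in this diff; a functionally similar helper, assumed here purely for illustration, could look like:

from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def batcher(iterable: Iterable[T], batch_size: int) -> Iterator[list[T]]:
    # Yield successive lists of up to batch_size items; works on one-shot
    # iterators like the zip() of hashes and embeddings used below.
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch

print(list(batcher(zip("abcd", range(4)), 2)))
# -> [[('a', 0), ('b', 1)], [('c', 2), ('d', 3)]]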
@@ -83,92 +84,88 @@ class TextToEntries(ABC):
         user: KhojUser = None,
         regenerate: bool = False,
     ):
-        with timer("Construct current entry hashes", logger):
+        with timer("Constructed current entry hashes in", logger):
             hashes_by_file = dict[str, set[str]]()
             current_entry_hashes = list(map(TextToEntries.hash_func(key), current_entries))
             hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
             for entry in tqdm(current_entries, desc="Hashing Entries"):
                 hashes_by_file.setdefault(entry.file, set()).add(TextToEntries.hash_func(key)(entry))

-        num_deleted_embeddings = 0
-        with timer("Preparing dataset for regeneration", logger):
-            if regenerate:
-                logger.debug(f"Deleting all embeddings for file type {file_type}")
-                num_deleted_embeddings = EntryAdapters.delete_all_entries(user, file_type)
+        num_deleted_entries = 0
+        if regenerate:
+            with timer("Prepared dataset for regeneration in", logger):
+                logger.debug(f"Deleting all entries for file type {file_type}")
+                num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type)

-        num_new_embeddings = 0
-        with timer("Identify hashes for adding new entries", logger):
-            for file in tqdm(hashes_by_file, desc="Processing file with hashed values"):
+        hashes_to_process = set()
+        with timer("Identified entries to add to database in", logger):
+            for file in tqdm(hashes_by_file, desc="Identify new entries"):
                 hashes_for_file = hashes_by_file[file]
-                hashes_to_process = set()
                 existing_entries = DbEntry.objects.filter(
                     user=user, hashed_value__in=hashes_for_file, file_type=file_type
                 )
                 existing_entry_hashes = set([entry.hashed_value for entry in existing_entries])
-                hashes_to_process = hashes_for_file - existing_entry_hashes
-
-                entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
-                data_to_embed = [getattr(entry, key) for entry in entries_to_process]
-                embeddings = self.embeddings_model.embed_documents(data_to_embed)
-
-                with timer("Update the database with new vector embeddings", logger):
-                    num_items = len(hashes_to_process)
-                    assert num_items == len(embeddings)
-                    batch_size = min(200, num_items)
-                    entry_batches = zip(hashes_to_process, embeddings)
-
-                    for entry_batch in tqdm(
-                        batcher(entry_batches, batch_size), desc="Processing embeddings in batches"
-                    ):
-                        batch_embeddings_to_create = []
-                        for entry_hash, new_entry in entry_batch:
-                            entry = hash_to_current_entries[entry_hash]
-                            batch_embeddings_to_create.append(
-                                DbEntry(
-                                    user=user,
-                                    embeddings=new_entry,
-                                    raw=entry.raw,
-                                    compiled=entry.compiled,
-                                    heading=entry.heading[:1000],  # Truncate to max chars of field allowed
-                                    file_path=entry.file,
-                                    file_type=file_type,
-                                    hashed_value=entry_hash,
-                                    corpus_id=entry.corpus_id,
-                                )
-                            )
-                        new_entries = DbEntry.objects.bulk_create(batch_embeddings_to_create)
-                        logger.debug(f"Created {len(new_entries)} new embeddings")
-                        num_new_embeddings += len(new_entries)
-
-                        dates_to_create = []
-                        with timer("Create new date associations for new embeddings", logger):
-                            for new_entry in new_entries:
-                                dates = self.date_filter.extract_dates(new_entry.raw)
-                                for date in dates:
-                                    dates_to_create.append(
-                                        EntryDates(
-                                            date=date,
-                                            entry=new_entry,
-                                        )
-                                    )
-                            new_dates = EntryDates.objects.bulk_create(dates_to_create)
-                            if len(new_dates) > 0:
-                                logger.debug(f"Created {len(new_dates)} new date entries")
+                hashes_to_process |= hashes_for_file - existing_entry_hashes
+
+        embeddings = []
+        with timer("Generated embeddings for entries to add to database in", logger):
+            entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
+            data_to_embed = [getattr(entry, key) for entry in entries_to_process]
+            embeddings += self.embeddings_model.embed_documents(data_to_embed)
+
+        added_entries: list[DbEntry] = []
+        with timer("Added entries to database in", logger):
+            num_items = len(hashes_to_process)
+            assert num_items == len(embeddings)
+            batch_size = min(200, num_items)
+            entry_batches = zip(hashes_to_process, embeddings)
+
+            for entry_batch in tqdm(batcher(entry_batches, batch_size), desc="Add entries to database"):
+                batch_embeddings_to_create = []
+                for entry_hash, new_entry in entry_batch:
+                    entry = hash_to_current_entries[entry_hash]
+                    batch_embeddings_to_create.append(
+                        DbEntry(
+                            user=user,
+                            embeddings=new_entry,
+                            raw=entry.raw,
+                            compiled=entry.compiled,
+                            heading=entry.heading[:1000],  # Truncate to max chars of field allowed
+                            file_path=entry.file,
+                            file_type=file_type,
+                            hashed_value=entry_hash,
+                            corpus_id=entry.corpus_id,
+                        )
+                    )
+                added_entries += DbEntry.objects.bulk_create(batch_embeddings_to_create)
+        logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
+
+        new_dates = []
+        with timer("Indexed dates from added entries in", logger):
+            for added_entry in added_entries:
+                dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
+                dates_to_create = [
+                    EntryDates(date=date, entry=added_entry)
+                    for date, added_entry in dates_in_entries
+                    if not is_none_or_empty(date)
+                ]
+                new_dates += EntryDates.objects.bulk_create(dates_to_create)
+        logger.debug(f"Indexed {len(new_dates)} dates from added {file_type} entries")

-        with timer("Identify hashes for removed entries", logger):
+        with timer("Deleted entries identified by server from database in", logger):
             for file in hashes_by_file:
                 existing_entry_hashes = EntryAdapters.get_existing_entry_hashes_by_file(user, file)
                 to_delete_entry_hashes = set(existing_entry_hashes) - hashes_by_file[file]
-                num_deleted_embeddings += len(to_delete_entry_hashes)
+                num_deleted_entries += len(to_delete_entry_hashes)
                 EntryAdapters.delete_entry_by_hash(user, hashed_values=list(to_delete_entry_hashes))

-        with timer("Identify hashes for deleting entries", logger):
+        with timer("Deleted entries requested by clients from database in", logger):
             if deletion_filenames is not None:
                 for file_path in deletion_filenames:
                     deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
-                    num_deleted_embeddings += deleted_count
+                    num_deleted_entries += deleted_count

-        return num_new_embeddings, num_deleted_embeddings
+        return len(added_entries), num_deleted_entries

     @staticmethod
     def mark_entries_for_update(
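The rewritten date-indexing loop pairs each extracted date with its entry via zip(dates, repeat(entry)). A small self-contained demo of that itertools idiom (toy data, not the khoj date filter):

from itertools import repeat

entry = "2024-01-02 standup; follow-up on 2024-01-09"
dates = ["2024-01-02", "2024-01-09"]  # stand-in for date_filter.extract_dates(entry)

# repeat(entry) is an infinite iterator; zip stops once the finite dates
# iterable is exhausted, pairing the same entry with every extracted date.
dates_in_entries = zip(dates, repeat(entry))
print([(date, text[:10]) for date, text in dates_in_entries])
# -> [('2024-01-02', '2024-01-02'), ('2024-01-09', '2024-01-02')]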
@@ -321,7 +321,6 @@ def load_content(
     content_index: Optional[ContentIndex],
     search_models: SearchModels,
 ):
-    logger.info(f"Loading content from existing embeddings...")
     if content_config is None:
         logger.warning("🚨 No Content configuration available.")
         return None
@@ -207,7 +207,7 @@ def setup(
     file_names = [file_name for file_name in files]

     logger.info(
-        f"Created {num_new_embeddings} new embeddings. Deleted {num_deleted_embeddings} embeddings for user {user} and files {file_names}"
+        f"Deleted {num_deleted_embeddings} entries. Created {num_new_embeddings} new entries for user {user} from files {file_names}"
     )

@@ -52,7 +52,8 @@ def cli(args=None):

     args, remaining_args = parser.parse_known_args(args)

-    logger.debug(f"Ignoring unknown commandline args: {remaining_args}")
+    if len(remaining_args) > 0:
+        logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")

     # Set default values for arguments
     args.chat_on_gpu = not args.disable_chat_on_gpu
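For context on the cli() change: argparse's parse_known_args returns the parsed namespace plus a list of unrecognized tokens, so the new guard only emits a warning when something was actually ignored. A standalone demo (hypothetical flag names, not the khoj CLI):

import argparse
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
parser.add_argument("--host", default="127.0.0.1")  # hypothetical flag for the demo

# parse_known_args() returns (namespace, leftover_tokens) instead of
# erroring out on flags the parser does not recognize.
args, remaining_args = parser.parse_known_args(["--host", "0.0.0.0", "--bogus"])
if len(remaining_args) > 0:
    logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
# -> INFO:__main__:⚠️ Ignoring unknown commandline args: ['--bogus']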