From a691ce4aa68225f1c300a50938ea00300d810aa3 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 27 Oct 2024 20:43:41 -0700 Subject: [PATCH] Batch entries into smaller groups to process --- .../commands/change_default_model.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/khoj/database/management/commands/change_default_model.py b/src/khoj/database/management/commands/change_default_model.py index cfa78581..d9a6359f 100644 --- a/src/khoj/database/management/commands/change_default_model.py +++ b/src/khoj/database/management/commands/change_default_model.py @@ -19,6 +19,8 @@ from khoj.processor.embeddings import EmbeddingsModel logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +BATCH_SIZE = 1000 # Define an appropriate batch size + class Command(BaseCommand): help = "Convert all existing Entry objects to use a new default Search model." @@ -42,22 +44,24 @@ class Command(BaseCommand): def handle(self, *args, **options): @transaction.atomic def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig): - entries = Entry.objects.filter(entry_filter).all() - compiled_entries = [entry.compiled for entry in entries] - updated_entries: List[Entry] = [] - try: - embeddings = embeddings_model.embed_documents(compiled_entries) + total_entries = Entry.objects.filter(entry_filter).count() + for start in tqdm(range(0, total_entries, BATCH_SIZE)): + end = start + BATCH_SIZE + entries = Entry.objects.filter(entry_filter)[start:end] + compiled_entries = [entry.compiled for entry in entries] + updated_entries: List[Entry] = [] + try: + embeddings = embeddings_model.embed_documents(compiled_entries) + except Exception as e: + logger.error(f"Error embedding documents: {e}") + return - except Exception as e: - logger.error(f"Error embedding documents: {e}") - return + for i, entry in enumerate(entries): + entry.embeddings = embeddings[i] + entry.search_model_id = search_model.id + updated_entries.append(entry) - for i, entry in enumerate(tqdm(entries)): - entry.embeddings = embeddings[i] - entry.search_model_id = search_model.id - updated_entries.append(entry) - - Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"]) + Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"]) search_model_config_id = options.get("search_model_id") apply = options.get("apply")