diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index cfbe7ca6..20019c56 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1012,7 +1012,7 @@ class EntryAdapters: return deleted_count @staticmethod - def get_entries_by_batch(user: KhojUser, batch_size: int, file_type: str = None, file_source: str = None): + def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None): queryset = Entry.objects.filter(user=user) if file_type is not None: @@ -1021,14 +1021,15 @@ class EntryAdapters: if file_source is not None: queryset = queryset.filter(file_source=file_source) - while queryset.exists(): - batch_ids = list(queryset.values_list("id", flat=True)[:batch_size]) - yield Entry.objects.filter(id__in=batch_ids) + return queryset @staticmethod def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): deleted_count = 0 - for batch in EntryAdapters.get_entries_by_batch(user, batch_size, file_type, file_source): + queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) + while queryset.exists(): + batch_ids = list(queryset.values_list("id", flat=True)[:batch_size]) + batch = Entry.objects.filter(id__in=batch_ids, user=user) count, _ = batch.delete() deleted_count += count return deleted_count @@ -1036,7 +1037,10 @@ class EntryAdapters: @staticmethod async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000): deleted_count = 0 - async for batch in EntryAdapters.get_entries_by_batch(user, batch_size, file_type, file_source): + queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source) + while await queryset.aexists(): + batch_ids = await sync_to_async(list)(queryset.values_list("id", flat=True)[:batch_size]) + batch = Entry.objects.filter(id__in=batch_ids, user=user) count, _ = await batch.adelete() deleted_count += count return deleted_count