diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 6aa7db78..45677bf6 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -3,6 +3,7 @@ import re import fnmatch import time import logging +from collections import defaultdict # External Packages import torch @@ -20,10 +21,15 @@ class FileFilter(BaseFilter): def __init__(self, entry_key='file'): self.entry_key = entry_key + self.file_to_entry_map = defaultdict(set) self.cache = LRU() - def load(self, *args, **kwargs): - pass + def load(self, entries, *args, **kwargs): + start = time.time() + for id, entry in enumerate(entries): + self.file_to_entry_map[entry[self.entry_key]].add(id) + end = time.time() + logger.debug(f"Created file filter index: {end - start} seconds") def can_filter(self, raw_query): return re.search(self.file_filter_regex, raw_query) is not None @@ -57,7 +63,10 @@ class FileFilter(BaseFilter): # Mark entries that contain any blocked_words for exclusion start = time.time() - included_entry_indices = [id for id, entry in enumerate(raw_entries) for search_file in files_to_search if fnmatch.fnmatch(entry[self.entry_key], search_file)] + included_entry_indices = set.union(*[self.file_to_entry_map[entry_file] + for entry_file in self.file_to_entry_map.keys() + for search_file in files_to_search + if fnmatch.fnmatch(entry_file, search_file)], set()) if not included_entry_indices: return query, [], torch.empty(0)