diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 54a8b625..73feaeed 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -42,19 +42,15 @@ class DateFilter(BaseFilter): # if no date in query, return all entries if query_daterange is None: - return query, raw_entries, raw_embeddings + return query, set(range(len(raw_entries))) # remove date range filter from query query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces - # deep copy original embeddings, entries before filtering - embeddings= deepcopy(raw_embeddings) - entries = deepcopy(raw_entries) - # find entries containing any dates that fall with date range specified in query entries_to_include = set() - for id, entry in enumerate(entries): + for id, entry in enumerate(raw_entries): # Extract dates from entry for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): # Convert date string in entry to unix timestamp @@ -67,13 +63,7 @@ class DateFilter(BaseFilter): entries_to_include.add(id) break - # delete entries (and their embeddings) marked for exclusion - entries_to_exclude = set(range(len(entries))) - entries_to_include - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) - - return query, entries, embeddings + return query, entries_to_include def extract_date_range(self, query): diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 45677bf6..3af67705 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -5,9 +5,6 @@ import time import logging from collections import defaultdict -# External Packages -import torch - # Internal Packages from src.search_filter.base_filter import BaseFilter from src.utils.helpers import LRU @@ -39,7 +36,7 @@ class FileFilter(BaseFilter): start = time.time() 
raw_files_to_search = re.findall(self.file_filter_regex, raw_query) if not raw_files_to_search: - return raw_query, raw_entries, raw_embeddings + return raw_query, set(range(len(raw_entries))) # Convert simple file filters with no path separator into regex # e.g. "file:notes.org" -> "file:.*notes.org" @@ -57,8 +54,11 @@ class FileFilter(BaseFilter): cache_key = tuple(files_to_search) if cache_key in self.cache: logger.info(f"Return file filter results from cache") - entries, embeddings = self.cache[cache_key] - return query, entries, embeddings + included_entry_indices = self.cache[cache_key] + return query, included_entry_indices + + if not self.file_to_entry_map: + self.load(raw_entries, regenerate=False) # Mark entries that contain any blocked_words for exclusion start = time.time() @@ -68,21 +68,12 @@ class FileFilter(BaseFilter): for search_file in files_to_search if fnmatch.fnmatch(entry_file, search_file)], set()) if not included_entry_indices: - return query, [], torch.empty(0) + return query, set() end = time.time() logger.debug(f"Mark entries satisfying filter: {end - start} seconds") - # Get entries (and associated embeddings) satisfying file filters - start = time.time() - - entries = [raw_entries[id] for id in included_entry_indices] - embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) - - end = time.time() - logger.debug(f"Keep entries satisfying filter: {end - start} seconds") - # Cache results - self.cache[cache_key] = entries, embeddings + self.cache[cache_key] = included_entry_indices - return query, entries, embeddings + return query, included_entry_indices diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index 9f46edd2..bae6764e 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -78,14 +78,14 @@ class WordFilter(BaseFilter): logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") if len(required_words) == 0
and len(blocked_words) == 0: - return query, raw_entries, raw_embeddings + return query, set(range(len(raw_entries))) # Return item from cache if exists cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) if cache_key in self.cache: logger.info(f"Return word filter results from cache") - entries, embeddings = self.cache[cache_key] - return query, entries, embeddings + included_entry_indices = self.cache[cache_key] + return query, included_entry_indices if not self.word_to_entry_index: self.load(raw_entries, regenerate=False) @@ -105,17 +105,10 @@ class WordFilter(BaseFilter): end = time.time() logger.debug(f"Mark entries satisfying filter: {end - start} seconds") - # get entries (and their embeddings) satisfying inclusion and exclusion filters - start = time.time() - + # get entries satisfying inclusion and exclusion filters included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words - entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices] - embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) - - end = time.time() - logger.debug(f"Keep entries satisfying filter: {end - start} seconds") # Cache results - self.cache[cache_key] = entries, embeddings + self.cache[cache_key] = included_entry_indices - return query, entries, embeddings + return query, included_entry_indices diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 3b050bf0..eb7b0d34 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -78,8 +78,21 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): # Filter query, entries and embeddings before semantic search start = time.time() filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] + included_entry_indices = set(range(len(entries))) for filter in filters_in_query: - query, entries, corpus_embeddings = filter.apply(query, 
entries, corpus_embeddings) + query, included_entry_indices_by_filter = filter.apply(query, entries, corpus_embeddings) + included_entry_indices.intersection_update(included_entry_indices_by_filter) + + # Get entries (and associated embeddings) satisfying all filters + if not included_entry_indices: + return [], [] + else: + keep_start = time.time() + entries = [entries[id] for id in included_entry_indices] + corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices))) + keep_end = time.time() + logger.debug(f"Keep entries satisfying all filter: {keep_end - keep_start} seconds") + end = time.time() logger.debug(f"Total Filter Time: {end - start:.3f} seconds") diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index ddb1fcf0..0ac444fd 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -18,40 +18,34 @@ def test_date_filter(): {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] q_with_no_date_filter = 'head tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries, embeddings) assert ret_query == 'head tail' - assert len(ret_emb) == 3 - assert ret_entries == entries + assert entry_indices == {0, 1, 2} q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries, embeddings) assert ret_query == 'head tail' - assert len(ret_emb) == 0 - assert ret_entries == [] + assert entry_indices == set() query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices =
DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[2]] - assert len(ret_emb) == 1 + assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[1]] - assert len(ret_emb) == 1 + assert entry_indices == {1} query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[2]] - assert len(ret_emb) == 1 + assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[1], entries[2]] - assert len(ret_emb) == 2 + assert entry_indices == {1, 2} def test_extract_date_range(): diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index b15b8a69..bde53ae0 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -13,13 +13,12 @@ def test_no_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == 
False assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def test_file_filter_with_non_existent_file(): @@ -30,13 +29,12 @@ def test_file_filter_with_non_existent_file(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 0 - assert ret_entries == [] + assert entry_indices == set() def test_single_file_filter(): @@ -47,13 +45,12 @@ # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[0], entries[2]] + assert entry_indices == {0, 2} def test_file_filter_with_partial_match(): @@ -64,13 +61,12 @@ # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[0], entries[2]] + assert entry_indices == {0, 2} def test_file_filter_with_regex_match(): @@ -81,13 +77,12 @@ # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices =
file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def test_multiple_file_filter(): @@ -98,13 +93,12 @@ def test_multiple_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def arrange_content(): diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index 3d584077..95743bfd 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -14,13 +14,12 @@ def test_no_word_filter(tmp_path): # Act can_filter = word_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == False assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def test_word_exclude_filter(tmp_path): @@ -31,13 +30,12 @@ def test_word_exclude_filter(tmp_path): # Act can_filter = word_filter.can_filter(q_with_exclude_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[0], entries[2]] + assert entry_indices == {0, 2} def test_word_include_filter(tmp_path): @@ -48,13 +46,12 @@ 
def test_word_include_filter(tmp_path): # Act can_filter = word_filter.can_filter(query_with_include_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[2], entries[3]] + assert entry_indices == {2, 3} def test_word_include_and_exclude_filter(tmp_path): @@ -65,13 +62,12 @@ def test_word_include_and_exclude_filter(tmp_path): # Act can_filter = word_filter.can_filter(query_with_include_and_exclude_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 1 - assert ret_entries == [entries[2]] + assert entry_indices == {2} def arrange_content():