From 965bd052f1eccdebc2388cb7eab1dddbbedf6c7c Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 5 Sep 2022 03:17:41 +0300 Subject: [PATCH] Make search filters return entry ids satisfying filter - Filter entries, embeddings by ids satisfying all filters in query func, after each filter has returned entry ids satisfying their individual acceptance criteria - Previously each filter would return a filtered list of entries. Each filter would be applied on entries filtered by previous filters. This made the filtering order dependent - Benefits - Filters can be applied independent of their order of execution - Precomputed indexes for each filter is not in danger of running into index out of bound errors, as filters run on original entries instead of on entries filtered by filters that have run before it - Extract entries satisfying filter only once instead of doing this for each filter - Costs - Each filter has to process all entries even if previous filters may have already marked them as non-satisfactory --- src/search_filter/date_filter.py | 16 +++------------- src/search_filter/file_filter.py | 27 +++++++++------------------ src/search_filter/word_filter.py | 19 ++++++------------- src/search_type/text_search.py | 15 ++++++++++++++- tests/test_date_filter.py | 30 ++++++++++++------------------ tests/test_file_filter.py | 30 ++++++++++++------------------ tests/test_word_filter.py | 20 ++++++++------------ 7 files changed, 64 insertions(+), 93 deletions(-) diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 54a8b625..73feaeed 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -42,19 +42,15 @@ class DateFilter(BaseFilter): # if no date in query, return all entries if query_daterange is None: - return query, raw_entries, raw_embeddings + return query, set(range(len(raw_entries))) # remove date range filter from query query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces - # deep copy original embeddings, entries before filtering - embeddings= deepcopy(raw_embeddings) - entries = deepcopy(raw_entries) - # find entries containing any dates that fall with date range specified in query entries_to_include = set() - for id, entry in enumerate(entries): + for id, entry in enumerate(raw_entries): # Extract dates from entry for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): # Convert date string in entry to unix timestamp @@ -67,13 +63,7 @@ class DateFilter(BaseFilter): entries_to_include.add(id) break - # delete entries (and their embeddings) marked for exclusion - entries_to_exclude = set(range(len(entries))) - entries_to_include - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) - - return query, entries, embeddings + return query, entries_to_include def extract_date_range(self, query): diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 45677bf6..3af67705 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -5,9 +5,6 @@ import time import logging from collections import defaultdict -# External Packages -import torch - # Internal Packages from src.search_filter.base_filter import BaseFilter from src.utils.helpers import LRU @@ -39,7 +36,7 @@ class FileFilter(BaseFilter): start = time.time() raw_files_to_search = re.findall(self.file_filter_regex, raw_query) if not raw_files_to_search: - return raw_query, raw_entries, raw_embeddings + return raw_query, set(range(len(raw_entries))) # Convert simple file filters with no path separator into regex # e.g. "file:notes.org" -> "file:.*notes.org" @@ -57,8 +54,11 @@ class FileFilter(BaseFilter): cache_key = tuple(files_to_search) if cache_key in self.cache: logger.info(f"Return file filter results from cache") - entries, embeddings = self.cache[cache_key] - return query, entries, embeddings + included_entry_indices = self.cache[cache_key] + return query, included_entry_indices + + if not self.file_to_entry_map: + self.load(raw_entries, regenerate=False) # Mark entries that contain any blocked_words for exclusion start = time.time() @@ -68,21 +68,12 @@ class FileFilter(BaseFilter): for search_file in files_to_search if fnmatch.fnmatch(entry_file, search_file)], set()) if not included_entry_indices: - return query, [], torch.empty(0) + return query, {} end = time.time() logger.debug(f"Mark entries satisfying filter: {end - start} seconds") - # Get entries (and associated embeddings) satisfying file filters - start = time.time() - - entries = [raw_entries[id] for id in included_entry_indices] - embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) - - end = time.time() - logger.debug(f"Keep entries satisfying filter: {end - start} seconds") - # Cache results - self.cache[cache_key] = entries, embeddings + self.cache[cache_key] = included_entry_indices - return query, entries, embeddings + return query, included_entry_indices diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index 9f46edd2..bae6764e 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -78,14 +78,14 @@ class WordFilter(BaseFilter): logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") if len(required_words) == 0 and len(blocked_words) == 0: - return query, raw_entries, raw_embeddings + return query, set(range(len(raw_entries))) # Return item from cache if exists cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) if cache_key in self.cache: logger.info(f"Return word filter results from cache") - entries, embeddings = self.cache[cache_key] - return query, entries, embeddings + included_entry_indices = self.cache[cache_key] + return query, included_entry_indices if not self.word_to_entry_index: self.load(raw_entries, regenerate=False) @@ -105,17 +105,10 @@ class WordFilter(BaseFilter): end = time.time() logger.debug(f"Mark entries satisfying filter: {end - start} seconds") - # get entries (and their embeddings) satisfying inclusion and exclusion filters - start = time.time() - + # get entries satisfying inclusion and exclusion filters included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words - entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices] - embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) - - end = time.time() - logger.debug(f"Keep entries satisfying filter: {end - start} seconds") # Cache results - self.cache[cache_key] = entries, embeddings + self.cache[cache_key] = included_entry_indices - return query, entries, embeddings + return query, included_entry_indices diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 3b050bf0..eb7b0d34 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -78,8 +78,21 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): # Filter query, entries and embeddings before semantic search start = time.time() filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] + included_entry_indices = set(range(len(entries))) for filter in filters_in_query: - query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings) + query, included_entry_indices_by_filter = filter.apply(query, entries, corpus_embeddings) + included_entry_indices.intersection_update(included_entry_indices_by_filter) + + # Get entries (and associated embeddings) satisfying all filters + if not included_entry_indices: + return [], [] + else: + start = time.time() + entries = [entries[id] for id in included_entry_indices] + corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices))) + end = time.time() + logger.debug(f"Keep entries satisfying all filter: {end - start} seconds") + end = time.time() logger.debug(f"Total Filter Time: {end - start:.3f} seconds") diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index ddb1fcf0..0ac444fd 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -18,40 +18,34 @@ def test_date_filter(): {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] q_with_no_date_filter = 'head tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries, embeddings) assert ret_query == 'head tail' - assert len(ret_emb) == 3 - assert ret_entries == entries + assert entry_indices == {0, 1, 2} q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries, embeddings) assert ret_query == 'head tail' - assert len(ret_emb) == 0 - assert ret_entries == [] + assert entry_indices == set() query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[2]] - assert len(ret_emb) == 1 + assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[1]] - assert len(ret_emb) == 1 + assert entry_indices == {1} query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[2]] - assert len(ret_emb) == 1 + assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) assert ret_query == 'head tail' - assert ret_entries == [entries[1], entries[2]] - assert len(ret_emb) == 2 + assert entry_indices == {1, 2} def test_extract_date_range(): diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index b15b8a69..bde53ae0 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -13,13 +13,12 @@ def test_no_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == False assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def test_file_filter_with_non_existent_file(): @@ -30,13 +29,12 @@ def test_file_filter_with_non_existent_file(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 0 - assert ret_entries == [] + assert entry_indices == {} def test_single_file_filter(): @@ -47,13 +45,12 @@ def test_single_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[0], entries[2]] + assert entry_indices == {0, 2} def test_file_filter_with_partial_match(): @@ -64,13 +61,12 @@ def test_file_filter_with_partial_match(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[0], entries[2]] + assert entry_indices == {0, 2} def test_file_filter_with_regex_match(): @@ -81,13 +77,12 @@ def test_file_filter_with_regex_match(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def test_multiple_file_filter(): @@ -98,13 +93,12 @@ def test_multiple_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def arrange_content(): diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index 3d584077..95743bfd 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -14,13 +14,12 @@ def test_no_word_filter(tmp_path): # Act can_filter = word_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert assert can_filter == False assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries + assert entry_indices == {0, 1, 2, 3} def test_word_exclude_filter(tmp_path): @@ -31,13 +30,12 @@ def test_word_exclude_filter(tmp_path): # Act can_filter = word_filter.can_filter(q_with_exclude_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[0], entries[2]] + assert entry_indices == {0, 2} def test_word_include_filter(tmp_path): @@ -48,13 +46,12 @@ def test_word_include_filter(tmp_path): # Act can_filter = word_filter.can_filter(query_with_include_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 2 - assert ret_entries == [entries[2], entries[3]] + assert entry_indices == {2, 3} def test_word_include_and_exclude_filter(tmp_path): @@ -65,13 +62,12 @@ def test_word_include_and_exclude_filter(tmp_path): # Act can_filter = word_filter.can_filter(query_with_include_and_exclude_filter) - ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) # Assert assert can_filter == True assert ret_query == 'head tail' - assert len(ret_emb) == 1 - assert ret_entries == [entries[2]] + assert entry_indices == {2} def arrange_content():