Make search filters return entry ids satisfying filter

- Filter entries, embeddings by ids satisfying all filters in query
  func, after each filter has returned entry ids satisfying their
  individual acceptance criteria

- Previously each filter would return a filtered list of entries.
  Each filter was applied on the entries already filtered by previous
  filters. This made the filtering order-dependent

- Benefits
  - Filters can be applied independent of their order of execution
  - Precomputed indexes for each filter are not in danger of running
    into index out-of-bound errors, as filters run on the original
    entries instead of on entries filtered by previously run filters
  - Extract entries satisfying filter only once instead of doing
    this for each filter

- Costs
  - Each filter has to process all entries, even those that previous
    filters have already marked as non-satisfactory
This commit is contained in:
Debanjum Singh Solanky
2022-09-05 03:17:41 +03:00
parent 7dd20d764c
commit 965bd052f1
7 changed files with 64 additions and 93 deletions

View File

@@ -42,19 +42,15 @@ class DateFilter(BaseFilter):
# if no date in query, return all entries # if no date in query, return all entries
if query_daterange is None: if query_daterange is None:
return query, raw_entries, raw_embeddings return query, set(range(len(raw_entries)))
# remove date range filter from query # remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(rf'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# deep copy original embeddings, entries before filtering
embeddings= deepcopy(raw_embeddings)
entries = deepcopy(raw_entries)
# find entries containing any dates that fall with date range specified in query # find entries containing any dates that fall with date range specified in query
entries_to_include = set() entries_to_include = set()
for id, entry in enumerate(entries): for id, entry in enumerate(raw_entries):
# Extract dates from entry # Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
# Convert date string in entry to unix timestamp # Convert date string in entry to unix timestamp
@@ -67,13 +63,7 @@ class DateFilter(BaseFilter):
entries_to_include.add(id) entries_to_include.add(id)
break break
# delete entries (and their embeddings) marked for exclusion return query, entries_to_include
entries_to_exclude = set(range(len(entries))) - entries_to_include
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
return query, entries, embeddings
def extract_date_range(self, query): def extract_date_range(self, query):

View File

@@ -5,9 +5,6 @@ import time
import logging import logging
from collections import defaultdict from collections import defaultdict
# External Packages
import torch
# Internal Packages # Internal Packages
from src.search_filter.base_filter import BaseFilter from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU from src.utils.helpers import LRU
@@ -39,7 +36,7 @@ class FileFilter(BaseFilter):
start = time.time() start = time.time()
raw_files_to_search = re.findall(self.file_filter_regex, raw_query) raw_files_to_search = re.findall(self.file_filter_regex, raw_query)
if not raw_files_to_search: if not raw_files_to_search:
return raw_query, raw_entries, raw_embeddings return raw_query, set(range(len(raw_entries)))
# Convert simple file filters with no path separator into regex # Convert simple file filters with no path separator into regex
# e.g. "file:notes.org" -> "file:.*notes.org" # e.g. "file:notes.org" -> "file:.*notes.org"
@@ -57,8 +54,11 @@ class FileFilter(BaseFilter):
cache_key = tuple(files_to_search) cache_key = tuple(files_to_search)
if cache_key in self.cache: if cache_key in self.cache:
logger.info(f"Return file filter results from cache") logger.info(f"Return file filter results from cache")
entries, embeddings = self.cache[cache_key] included_entry_indices = self.cache[cache_key]
return query, entries, embeddings return query, included_entry_indices
if not self.file_to_entry_map:
self.load(raw_entries, regenerate=False)
# Mark entries that contain any blocked_words for exclusion # Mark entries that contain any blocked_words for exclusion
start = time.time() start = time.time()
@@ -68,21 +68,12 @@ class FileFilter(BaseFilter):
for search_file in files_to_search for search_file in files_to_search
if fnmatch.fnmatch(entry_file, search_file)], set()) if fnmatch.fnmatch(entry_file, search_file)], set())
if not included_entry_indices: if not included_entry_indices:
return query, [], torch.empty(0) return query, set()
end = time.time() end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds") logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# Get entries (and associated embeddings) satisfying file filters
start = time.time()
entries = [raw_entries[id] for id in included_entry_indices]
embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
end = time.time()
logger.debug(f"Keep entries satisfying filter: {end - start} seconds")
# Cache results # Cache results
self.cache[cache_key] = entries, embeddings self.cache[cache_key] = included_entry_indices
return query, entries, embeddings return query, included_entry_indices

View File

@@ -78,14 +78,14 @@ class WordFilter(BaseFilter):
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
if len(required_words) == 0 and len(blocked_words) == 0: if len(required_words) == 0 and len(blocked_words) == 0:
return query, raw_entries, raw_embeddings return query, set(range(len(raw_entries)))
# Return item from cache if exists # Return item from cache if exists
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
if cache_key in self.cache: if cache_key in self.cache:
logger.info(f"Return word filter results from cache") logger.info(f"Return word filter results from cache")
entries, embeddings = self.cache[cache_key] included_entry_indices = self.cache[cache_key]
return query, entries, embeddings return query, included_entry_indices
if not self.word_to_entry_index: if not self.word_to_entry_index:
self.load(raw_entries, regenerate=False) self.load(raw_entries, regenerate=False)
@@ -105,17 +105,10 @@ class WordFilter(BaseFilter):
end = time.time() end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds") logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# get entries (and their embeddings) satisfying inclusion and exclusion filters # get entries satisfying inclusion and exclusion filters
start = time.time()
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices]
embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
end = time.time()
logger.debug(f"Keep entries satisfying filter: {end - start} seconds")
# Cache results # Cache results
self.cache[cache_key] = entries, embeddings self.cache[cache_key] = included_entry_indices
return query, entries, embeddings return query, included_entry_indices

View File

@@ -78,8 +78,21 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
# Filter query, entries and embeddings before semantic search # Filter query, entries and embeddings before semantic search
start = time.time() start = time.time()
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
included_entry_indices = set(range(len(entries)))
for filter in filters_in_query: for filter in filters_in_query:
query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings) query, included_entry_indices_by_filter = filter.apply(query, entries, corpus_embeddings)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return [], []
else:
start = time.time()
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
end = time.time()
logger.debug(f"Keep entries satisfying all filter: {end - start} seconds")
end = time.time() end = time.time()
logger.debug(f"Total Filter Time: {end - start:.3f} seconds") logger.debug(f"Total Filter Time: {end - start:.3f} seconds")

View File

@@ -18,40 +18,34 @@ def test_date_filter():
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}] {'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
q_with_no_date_filter = 'head tail' q_with_no_date_filter = 'head tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings) ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries, embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 3 assert entry_indices == {0, 1, 2}
assert ret_entries == entries
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries, embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 0 assert entry_indices == set()
assert ret_entries == []
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[2]] assert entry_indices == {2}
assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[1]] assert entry_indices == {1}
assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[2]] assert entry_indices == {2}
assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[1], entries[2]] assert entry_indices == {1, 2}
assert len(ret_emb) == 2
def test_extract_date_range(): def test_extract_date_range():

View File

@@ -13,13 +13,12 @@ def test_no_file_filter():
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == False assert can_filter == False
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 4 assert entry_indices == {0, 1, 2, 3}
assert ret_entries == entries
def test_file_filter_with_non_existent_file(): def test_file_filter_with_non_existent_file():
@@ -30,13 +29,12 @@ def test_file_filter_with_non_existent_file():
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 0 assert entry_indices == set()
assert ret_entries == []
def test_single_file_filter(): def test_single_file_filter():
@@ -47,13 +45,12 @@ def test_single_file_filter():
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 2 assert entry_indices == {0, 2}
assert ret_entries == [entries[0], entries[2]]
def test_file_filter_with_partial_match(): def test_file_filter_with_partial_match():
@@ -64,13 +61,12 @@ def test_file_filter_with_partial_match():
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 2 assert entry_indices == {0, 2}
assert ret_entries == [entries[0], entries[2]]
def test_file_filter_with_regex_match(): def test_file_filter_with_regex_match():
@@ -81,13 +77,12 @@ def test_file_filter_with_regex_match():
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 4 assert entry_indices == {0, 1, 2, 3}
assert ret_entries == entries
def test_multiple_file_filter(): def test_multiple_file_filter():
@@ -98,13 +93,12 @@ def test_multiple_file_filter():
# Act # Act
can_filter = file_filter.can_filter(q_with_no_filter) can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 4 assert entry_indices == {0, 1, 2, 3}
assert ret_entries == entries
def arrange_content(): def arrange_content():

View File

@@ -14,13 +14,12 @@ def test_no_word_filter(tmp_path):
# Act # Act
can_filter = word_filter.can_filter(q_with_no_filter) can_filter = word_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = word_filter.apply(q_with_no_filter, entries.copy(), embeddings) ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == False assert can_filter == False
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 4 assert entry_indices == {0, 1, 2, 3}
assert ret_entries == entries
def test_word_exclude_filter(tmp_path): def test_word_exclude_filter(tmp_path):
@@ -31,13 +30,12 @@ def test_word_exclude_filter(tmp_path):
# Act # Act
can_filter = word_filter.can_filter(q_with_exclude_filter) can_filter = word_filter.can_filter(q_with_exclude_filter)
ret_query, ret_entries, ret_emb = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 2 assert entry_indices == {0, 2}
assert ret_entries == [entries[0], entries[2]]
def test_word_include_filter(tmp_path): def test_word_include_filter(tmp_path):
@@ -48,13 +46,12 @@ def test_word_include_filter(tmp_path):
# Act # Act
can_filter = word_filter.can_filter(query_with_include_filter) can_filter = word_filter.can_filter(query_with_include_filter)
ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_filter, entries.copy(), embeddings) ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 2 assert entry_indices == {2, 3}
assert ret_entries == [entries[2], entries[3]]
def test_word_include_and_exclude_filter(tmp_path): def test_word_include_and_exclude_filter(tmp_path):
@@ -65,13 +62,12 @@ def test_word_include_and_exclude_filter(tmp_path):
# Act # Act
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter) can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 1 assert entry_indices == {2}
assert ret_entries == [entries[2]]
def arrange_content(): def arrange_content():