Make search filters return entry ids satisfying filter

- Filter entries, embeddings by ids satisfying all filters in query
  func, after each filter has returned entry ids satisfying their
  individual acceptance criteria

- Previously each filter would return a filtered list of entries.
  Each filter would be applied on entries filtered by previous filters.
  This made the filtering order dependent

- Benefits
  - Filters can be applied independent of their order of execution
  - Precomputed indexes for each filter is not in danger of running
    into index out of bound errors, as filters run on original entries
    instead of on entries filtered by filters that have run before it
  - Extract entries satisfying filter only once instead of doing
    this for each filter

- Costs
  - Each filter has to process all entries even if previous filters
    may have already marked them as non-satisfactory
This commit is contained in:
Debanjum Singh Solanky
2022-09-05 03:17:41 +03:00
parent 7dd20d764c
commit 965bd052f1
7 changed files with 64 additions and 93 deletions

View File

@@ -78,8 +78,21 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
# Filter query, entries and embeddings before semantic search
start = time.time()
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
included_entry_indices = set(range(len(entries)))
for filter in filters_in_query:
query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
query, included_entry_indices_by_filter = filter.apply(query, entries, corpus_embeddings)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return [], []
else:
start = time.time()
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
end = time.time()
logger.debug(f"Keep entries satisfying all filter: {end - start} seconds")
end = time.time()
logger.debug(f"Total Filter Time: {end - start:.3f} seconds")