Do not pass embeddings as argument to filter.apply method

This commit is contained in:
Debanjum Singh Solanky
2022-09-05 15:46:54 +03:00
parent 965bd052f1
commit 31503e7afd
9 changed files with 27 additions and 28 deletions

View File

@@ -76,11 +76,11 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
# Filter query, entries and embeddings before semantic search
- start = time.time()
- filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
+ start_filter = time.time()
  included_entry_indices = set(range(len(entries)))
+ filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
for filter in filters_in_query:
- query, included_entry_indices_by_filter = filter.apply(query, entries, corpus_embeddings)
+ query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
@@ -91,10 +91,10 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
end = time.time()
- logger.debug(f"Keep entries satisfying all filter: {end - start} seconds")
+ logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
- end = time.time()
- logger.debug(f"Total Filter Time: {end - start:.3f} seconds")
+ end_filter = time.time()
+ logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds")
if entries is None or len(entries) == 0:
return [], []