Do not pass embeddings as argument to filter.apply method

This commit is contained in:
Debanjum Singh Solanky
2022-09-05 15:46:54 +03:00
parent 965bd052f1
commit 31503e7afd
9 changed files with 27 additions and 28 deletions

View File

@@ -1,6 +1,6 @@
# Standard Packages
from abc import ABC, abstractmethod
from typing import List, Tuple
from typing import List, Set, Tuple
# External Packages
import torch
@@ -16,5 +16,5 @@ class BaseFilter(ABC):
pass
@abstractmethod
def apply(self, query:str, raw_entries:List[str], raw_embeddings: torch.Tensor) -> Tuple[str, List[str], torch.Tensor]:
def apply(self, query:str, raw_entries:List[str]) -> Tuple[str, Set[int]]:
pass

View File

@@ -35,7 +35,7 @@ class DateFilter(BaseFilter):
return self.extract_date_range(raw_query) is not None
def apply(self, query, raw_entries, raw_embeddings):
def apply(self, query, raw_entries):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
query_daterange = self.extract_date_range(query)

View File

@@ -31,7 +31,7 @@ class FileFilter(BaseFilter):
def can_filter(self, raw_query):
return re.search(self.file_filter_regex, raw_query) is not None
def apply(self, raw_query, raw_entries, raw_embeddings):
def apply(self, raw_query, raw_entries):
# Extract file filters from raw query
start = time.time()
raw_files_to_search = re.findall(self.file_filter_regex, raw_query)

View File

@@ -65,7 +65,7 @@ class WordFilter(BaseFilter):
return len(required_words) != 0 or len(blocked_words) != 0
def apply(self, raw_query, raw_entries, raw_embeddings):
def apply(self, raw_query, raw_entries):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters
start = time.time()

View File

@@ -76,11 +76,11 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
# Filter query, entries and embeddings before semantic search
start = time.time()
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
start_filter = time.time()
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
for filter in filters_in_query:
query, included_entry_indices_by_filter = filter.apply(query, entries, corpus_embeddings)
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
@@ -91,10 +91,10 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
end = time.time()
logger.debug(f"Keep entries satisfying all filter: {end - start} seconds")
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
end = time.time()
logger.debug(f"Total Filter Time: {end - start:.3f} seconds")
end_filter = time.time()
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds")
if entries is None or len(entries) == 0:
return [], []