diff --git a/src/search_filter/base_filter.py b/src/search_filter/base_filter.py index dc079b45..735b6915 100644 --- a/src/search_filter/base_filter.py +++ b/src/search_filter/base_filter.py @@ -1,6 +1,6 @@ # Standard Packages from abc import ABC, abstractmethod -from typing import List, Tuple +from typing import List, Set, Tuple # External Packages import torch @@ -16,5 +16,5 @@ class BaseFilter(ABC): pass @abstractmethod - def apply(self, query:str, raw_entries:List[str], raw_embeddings: torch.Tensor) -> Tuple[str, List[str], torch.Tensor]: + def apply(self, query:str, raw_entries:List[str]) -> Tuple[str, Set[int]]: pass \ No newline at end of file diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 73feaeed..683d7a64 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -35,7 +35,7 @@ class DateFilter(BaseFilter): return self.extract_date_range(raw_query) is not None - def apply(self, query, raw_entries, raw_embeddings): + def apply(self, query, raw_entries): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query query_daterange = self.extract_date_range(query) diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 3af67705..41f80274 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -31,7 +31,7 @@ class FileFilter(BaseFilter): def can_filter(self, raw_query): return re.search(self.file_filter_regex, raw_query) is not None - def apply(self, raw_query, raw_entries, raw_embeddings): + def apply(self, raw_query, raw_entries): # Extract file filters from raw query start = time.time() raw_files_to_search = re.findall(self.file_filter_regex, raw_query) diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index bae6764e..dcf9ca6b 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -65,7 +65,7 @@ class WordFilter(BaseFilter): return len(required_words) != 0 or len(blocked_words) != 0 - def apply(self, raw_query, raw_entries, raw_embeddings): + def apply(self, raw_query, raw_entries): "Find entries containing required and not blocked words specified in query" # Separate natural query from required, blocked words filters start = time.time() diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index eb7b0d34..8666056c 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -76,11 +76,11 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings # Filter query, entries and embeddings before semantic search - start = time.time() - filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] + start_filter = time.time() included_entry_indices = set(range(len(entries))) + filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] for filter in filters_in_query: - query, included_entry_indices_by_filter = filter.apply(query, entries, corpus_embeddings) + query, included_entry_indices_by_filter = filter.apply(query, entries) included_entry_indices.intersection_update(included_entry_indices_by_filter) # Get entries (and associated embeddings) satisfying all filters @@ -91,10 +91,10 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): entries = [entries[id] for id in included_entry_indices] corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices))) end = time.time() - logger.debug(f"Keep entries satisfying all filter: {end - start} seconds") + logger.debug(f"Keep entries satisfying all filters: {end - start} seconds") - end = time.time() - logger.debug(f"Total Filter Time: {end - start:.3f} seconds") + end_filter = time.time() + logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds") if entries is None or len(entries) == 0: return [], [] diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 0ac444fd..345c5c4f 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -18,32 +18,32 @@ def test_date_filter(): {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] q_with_no_date_filter = 'head tail' - ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries) assert ret_query == 'head tail' assert entry_indices == {0, 1, 2} q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries) assert ret_query == 'head tail' assert entry_indices == set() query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' assert entry_indices == {1} query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' assert entry_indices == {2} query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' - ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings) + ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries) assert ret_query == 'head tail' assert entry_indices == {1, 2} diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index bde53ae0..3f9c22b3 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -13,7 +13,7 @@ def test_no_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == False @@ -29,7 +29,7 @@ def test_file_filter_with_non_existent_file(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True @@ -45,7 +45,7 @@ def test_single_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True @@ -61,7 +61,7 @@ def test_file_filter_with_partial_match(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True @@ -77,7 +77,7 @@ def test_file_filter_with_regex_match(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True @@ -93,7 +93,7 @@ def test_multiple_file_filter(): # Act can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == True diff --git a/tests/test_text_search.py b/tests/test_text_search.py index d56d304d..39fed92e 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -1,6 +1,5 @@ # System Packages from pathlib import Path -from src.utils.config import SearchType # Internal Packages from src.utils.state import model diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index 95743bfd..3efe8ed9 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -14,7 +14,7 @@ def test_no_word_filter(tmp_path): # Act can_filter = word_filter.can_filter(q_with_no_filter) - ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries) # Assert assert can_filter == False @@ -30,7 +30,7 @@ def test_word_exclude_filter(tmp_path): # Act can_filter = word_filter.can_filter(q_with_exclude_filter) - ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries) # Assert assert can_filter == True @@ -46,7 +46,7 @@ def test_word_include_filter(tmp_path): # Act can_filter = word_filter.can_filter(query_with_include_filter) - ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries) # Assert assert can_filter == True @@ -62,7 +62,7 @@ def test_word_include_and_exclude_filter(tmp_path): # Act can_filter = word_filter.can_filter(query_with_include_and_exclude_filter) - ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) + ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries) # Assert assert can_filter == True