From 858d86075b8126fe26b4f8de725ed462ddc6fdee Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 3 Sep 2022 23:47:28 +0300 Subject: [PATCH] Use regexes to check if any explicit filters in query. Test can_filter --- src/search_filter/explicit_filter.py | 4 ++-- tests/test_explicit_filter.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index 797c007d..9d043a4d 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -52,8 +52,8 @@ class ExplicitFilter: def can_filter(self, raw_query): "Check if query contains explicit filters" # Extract explicit query portion with required, blocked words to filter from natural query - required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) - blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) + required_words = re.findall(self.required_regex, raw_query) + blocked_words = re.findall(self.blocked_regex, raw_query) return len(required_words) != 0 or len(blocked_words) != 0 diff --git a/tests/test_explicit_filter.py b/tests/test_explicit_filter.py index f3b88659..9d4c022a 100644 --- a/tests/test_explicit_filter.py +++ b/tests/test_explicit_filter.py @@ -13,9 +13,11 @@ def test_no_explicit_filter(tmp_path): q_with_no_filter = 'head tail' # Act + can_filter = explicit_filter.can_filter(q_with_no_filter) ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert + assert can_filter == False assert ret_query == 'head tail' assert len(ret_emb) == 4 assert ret_entries == entries @@ -28,9 +30,11 @@ def test_explicit_exclude_filter(tmp_path): q_with_exclude_filter = 'head -exclude_word tail' # Act + can_filter = explicit_filter.can_filter(q_with_exclude_filter) ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) # Assert + assert can_filter == True assert ret_query == 'head tail' assert len(ret_emb) == 2 assert ret_entries == [entries[0], entries[2]] @@ -43,9 +47,11 @@ def test_explicit_include_filter(tmp_path): query_with_include_filter = 'head +include_word tail' # Act + can_filter = explicit_filter.can_filter(query_with_include_filter) ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings) # Assert + assert can_filter == True assert ret_query == 'head tail' assert len(ret_emb) == 2 assert ret_entries == [entries[2], entries[3]] @@ -58,9 +64,11 @@ def test_explicit_include_and_exclude_filter(tmp_path): query_with_include_and_exclude_filter = 'head +include_word -exclude_word tail' # Act + can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter) ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) # Assert + assert can_filter == True assert ret_query == 'head tail' assert len(ret_emb) == 1 assert ret_entries == [entries[2]]