Use regexes to check if any explicit filters in query. Test can_filter

This commit is contained in:
Debanjum Singh Solanky
2022-09-03 23:47:28 +03:00
parent 546fad570d
commit 858d86075b
2 changed files with 10 additions and 2 deletions

View File

@@ -52,8 +52,8 @@ class ExplicitFilter:
def can_filter(self, raw_query):
"Check if query contains explicit filters"
# Extract explicit query portion with required, blocked words to filter from natural query
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
required_words = re.findall(self.required_regex, raw_query)
blocked_words = re.findall(self.blocked_regex, raw_query)
return len(required_words) != 0 or len(blocked_words) != 0

View File

@@ -13,9 +13,11 @@ def test_no_explicit_filter(tmp_path):
q_with_no_filter = 'head tail'
# Act
can_filter = explicit_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert
assert can_filter == False
assert ret_query == 'head tail'
assert len(ret_emb) == 4
assert ret_entries == entries
@@ -28,9 +30,11 @@ def test_explicit_exclude_filter(tmp_path):
q_with_exclude_filter = 'head -exclude_word tail'
# Act
can_filter = explicit_filter.can_filter(q_with_exclude_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert len(ret_emb) == 2
assert ret_entries == [entries[0], entries[2]]
@@ -43,9 +47,11 @@ def test_explicit_include_filter(tmp_path):
query_with_include_filter = 'head +include_word tail'
# Act
can_filter = explicit_filter.can_filter(query_with_include_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert len(ret_emb) == 2
assert ret_entries == [entries[2], entries[3]]
@@ -58,9 +64,11 @@ def test_explicit_include_and_exclude_filter(tmp_path):
query_with_include_and_exclude_filter = 'head +include_word -exclude_word tail'
# Act
can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert len(ret_emb) == 1
assert ret_entries == [entries[2]]