diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index 09580e4a..797c007d 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -16,6 +16,10 @@ logger = logging.getLogger(__name__) class ExplicitFilter: + # Filter Regex + required_regex = r'\+([^\s]+) ?' + blocked_regex = r'\-([^\s]+) ?' + def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'): self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") self.entry_key = entry_key @@ -58,11 +62,13 @@ class ExplicitFilter: "Find entries containing required and not blocked words specified in query" # Separate natural query from explicit required, blocked words filters start = time.time() - query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) - required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) - blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) + + required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)]) + blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)]) + query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)) + end = time.time() - logger.debug(f"Time to extract required, blocked words: {end - start} seconds") + logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") if len(required_words) == 0 and len(blocked_words) == 0: return query, entries, embeddings @@ -86,7 +92,7 @@ class ExplicitFilter: if words_in_entry.intersection(blocked_words): entries_to_exclude.add(id) end = time.time() - logger.debug(f"Mark entries to filter: {end - start} seconds") + logger.debug(f"Mark entries not satisfying filter: {end - start} seconds") # delete entries (and their embeddings) marked for exclusion start = time.time() @@ -94,6 +100,6 @@ class ExplicitFilter: del entries[id] embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) end = time.time() - logger.debug(f"Remove entries to filter from embeddings: {end - start} seconds") + logger.debug(f"Delete entries not satisfying filter: {end - start} seconds") return query, entries, embeddings