From c92789d20a2c0e1ae93bca259d6b2edcbd3e6d6a Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 13 Jul 2022 16:07:45 +0400 Subject: [PATCH] Extract explicit pre-search filter function into a separate module Details -- - Move explicit_filters function into separate module under search_filter - Update signature of explicit filter to take and return query, entries, embeddings - Use this explicit_filter func from search_filters module in query Reason -- Abstraction will simplify adding other pre-search filters. E.g datetime filter --- src/search_filter/explicit_filter.py | 46 ++++++++++++++++++++++++++++ src/search_type/asymmetric.py | 45 ++------------------------- 2 files changed, 49 insertions(+), 42 deletions(-) create mode 100644 src/search_filter/explicit_filter.py diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py new file mode 100644 index 00000000..363dbd71 --- /dev/null +++ b/src/search_filter/explicit_filter.py @@ -0,0 +1,46 @@ +# Standard Packages +import re + +# External Packages +import torch + + +def explicit_filter(raw_query, entries, embeddings): + # Separate natural query from explicit required, blocked words filters + query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) + required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) + blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) + + if len(required_words) == 0 and len(blocked_words) == 0: + return query, entries, embeddings + + # convert each entry to a set of words + entries_by_word_set = [set(word.lower() + for word + in re.split( + r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets + entry[0]) + if word != "") + for entry in entries] + + # track id of entries to exclude + entries_to_exclude = set() + + # mark entries that do not contain all required_words for exclusion + if len(required_words) > 0: + for id, words_in_entry in enumerate(entries_by_word_set): + if not required_words.issubset(words_in_entry): + entries_to_exclude.add(id) + + # mark entries that contain any blocked_words for exclusion + if len(blocked_words) > 0: + for id, words_in_entry in enumerate(entries_by_word_set): + if words_in_entry.intersection(blocked_words): + entries_to_exclude.add(id) + + # delete entries (and their embeddings) marked for exclusion + for id in sorted(list(entries_to_exclude), reverse=True): + del entries[id] + embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) + + return query, entries, embeddings \ No newline at end of file diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py index 611501f7..53d41bf4 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/asymmetric.py @@ -3,7 +3,6 @@ # Standard Packages import json import gzip -import re import argparse import pathlib from copy import deepcopy @@ -15,6 +14,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model from src.processor.org_mode.org_to_jsonl import org_to_jsonl +from src.search_filter.explicit_filter import explicit_filter from src.utils.config import TextSearchModel from src.utils.rawconfig import AsymmetricSearchConfig, TextContentConfig from src.utils.constants import empty_escape_sequences @@ -96,17 +96,14 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')): "Search all notes for entries that answer the query" - # Separate natural query from explicit required, blocked words filters - query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) - required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) - blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) # Copy original embeddings, entries to filter them for query + query = raw_query corpus_embeddings = deepcopy(model.corpus_embeddings) entries = deepcopy(model.entries) # Filter to entries that contain all required_words and no blocked_words - entries, corpus_embeddings = explicit_filter(entries, corpus_embeddings, required_words, blocked_words) + query, entries, corpus_embeddings = explicit_filter(query, entries, corpus_embeddings) if entries is None or len(entries) == 0: return {} @@ -133,42 +130,6 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')): return hits, entries -def explicit_filter(entries, embeddings, required_words, blocked_words): - if len(required_words) == 0 and len(blocked_words) == 0: - return entries, embeddings - - # convert each entry to a set of words - entries_by_word_set = [set(word.lower() - for word - in re.split( - r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets - entry[0]) - if word != "") - for entry in entries] - - # track id of entries to exclude - entries_to_exclude = set() - - # mark entries that do not contain all required_words for exclusion - if len(required_words) > 0: - for id, words_in_entry in enumerate(entries_by_word_set): - if not required_words.issubset(words_in_entry): - entries_to_exclude.add(id) - - # mark entries that contain any blocked_words for exclusion - if len(blocked_words) > 0: - for id, words_in_entry in enumerate(entries_by_word_set): - if words_in_entry.intersection(blocked_words): - entries_to_exclude.add(id) - - # delete entries (and their embeddings) marked for exclusion - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) - - return entries, embeddings - - def render_results(hits, entries, count=5, display_biencoder_results=False): "Render the Results returned by Search for the Query" if display_biencoder_results: