Extract explicit pre-search filter function into a separate module

Details
--
- Move explicit_filters function into separate module under search_filter
- Update signature of explicit filter to take and return query, entries, embeddings
- Use this explicit_filter func from search_filters module in query

Reason
--
Abstraction will simplify adding other pre-search filters. E.g datetime filter
This commit is contained in:
Debanjum Singh Solanky
2022-07-13 16:07:45 +04:00
parent 589bfa9424
commit c92789d20a
2 changed files with 49 additions and 42 deletions

View File

@@ -0,0 +1,46 @@
# Standard Packages
import re
# External Packages
import torch
def explicit_filter(raw_query, entries, embeddings):
# Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
if len(required_words) == 0 and len(blocked_words) == 0:
return query, entries, embeddings
# convert each entry to a set of words
entries_by_word_set = [set(word.lower()
for word
in re.split(
r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets
entry[0])
if word != "")
for entry in entries]
# track id of entries to exclude
entries_to_exclude = set()
# mark entries that do not contain all required_words for exclusion
if len(required_words) > 0:
for id, words_in_entry in enumerate(entries_by_word_set):
if not required_words.issubset(words_in_entry):
entries_to_exclude.add(id)
# mark entries that contain any blocked_words for exclusion
if len(blocked_words) > 0:
for id, words_in_entry in enumerate(entries_by_word_set):
if words_in_entry.intersection(blocked_words):
entries_to_exclude.add(id)
# delete entries (and their embeddings) marked for exclusion
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
return query, entries, embeddings

View File

@@ -3,7 +3,6 @@
# Standard Packages
import json
import gzip
import re
import argparse
import pathlib
from copy import deepcopy
@@ -15,6 +14,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Internal Packages
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.search_filter.explicit_filter import explicit_filter
from src.utils.config import TextSearchModel
from src.utils.rawconfig import AsymmetricSearchConfig, TextContentConfig
from src.utils.constants import empty_escape_sequences
@@ -96,17 +96,14 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
"Search all notes for entries that answer the query"
# Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
# Copy original embeddings, entries to filter them for query
query = raw_query
corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries)
# Filter to entries that contain all required_words and no blocked_words
entries, corpus_embeddings = explicit_filter(entries, corpus_embeddings, required_words, blocked_words)
query, entries, corpus_embeddings = explicit_filter(query, entries, corpus_embeddings)
if entries is None or len(entries) == 0:
return {}
@@ -133,42 +130,6 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')):
return hits, entries
def explicit_filter(entries, embeddings, required_words, blocked_words):
if len(required_words) == 0 and len(blocked_words) == 0:
return entries, embeddings
# convert each entry to a set of words
entries_by_word_set = [set(word.lower()
for word
in re.split(
r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets
entry[0])
if word != "")
for entry in entries]
# track id of entries to exclude
entries_to_exclude = set()
# mark entries that do not contain all required_words for exclusion
if len(required_words) > 0:
for id, words_in_entry in enumerate(entries_by_word_set):
if not required_words.issubset(words_in_entry):
entries_to_exclude.add(id)
# mark entries that contain any blocked_words for exclusion
if len(blocked_words) > 0:
for id, words_in_entry in enumerate(entries_by_word_set):
if words_in_entry.intersection(blocked_words):
entries_to_exclude.add(id)
# delete entries (and their embeddings) marked for exclusion
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
return entries, embeddings
def render_results(hits, entries, count=5, display_biencoder_results=False):
"Render the Results returned by Search for the Query"
if display_biencoder_results: