Pre-compute entry word sets to improve explicit filter query performance

This commit is contained in:
Debanjum Singh Solanky
2022-09-03 16:01:54 +03:00
parent 094bd18e57
commit c7de57b8ea
6 changed files with 81 additions and 26 deletions

View File

@@ -17,12 +17,21 @@ class DateFilter:
# - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
def __init__(self, entry_key='raw'):
    # Which field of each entry dict to scan for dates when filtering
    # (defaults to the entry's 'raw' text)
    self.entry_key = entry_key
def load(self, *args, **kwargs):
    """No-op: the date filter needs no precomputed state before apply.

    Exists so all filters share a uniform load/can_filter/apply interface.
    The original signature omitted ``self`` and only worked because the
    bound instance was silently captured by ``*args``; make it explicit.
    """
    pass
def can_filter(self, raw_query):
    """Return True when the query contains a date filter expression."""
    date_range = self.extract_date_range(raw_query)
    return date_range is not None
def filter(self, query, entries, embeddings, entry_key='raw'):
def apply(self, query, entries, embeddings):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
query_daterange = self.extract_date_range(query)
@@ -39,7 +48,7 @@ class DateFilter:
entries_to_include = set()
for id, entry in enumerate(entries):
# Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
# Convert date string in entry to unix timestamp
try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()

View File

@@ -1,11 +1,46 @@
# Standard Packages
import re
import time
import pickle
# External Packages
import torch
# Internal Packages
from src.utils.helpers import resolve_absolute_path
from src.utils.config import SearchType
class ExplicitFilter:
def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
    """Configure the filter: cache file location, search type and entry field."""
    self.search_type = search_type
    self.entry_key = entry_key
    # Per-search-type pickle cache holding one word-set per entry
    cache_filename = f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl"
    self.filter_file = resolve_absolute_path(filter_directory / cache_filename)
def load(self, entries, regenerate=False):
    """Load cached per-entry word sets, or compute and cache them.

    Returns a list parallel to ``entries``: one lowercase word-set per
    entry, used for fast required/blocked word membership checks.
    When ``self.filter_file`` exists and ``regenerate`` is False the sets
    are unpickled from disk; otherwise they are recomputed from each
    entry's ``self.entry_key`` text and pickled for the next run.

    NOTE(review): pickle.load is only safe here because the cache file is
    written by this same code below — never point filter_file at
    untrusted data.
    """
    if self.filter_file.exists() and not regenerate:
        start = time.time()
        with self.filter_file.open('rb') as f:
            entries_by_word_set = pickle.load(f)
        end = time.time()
        print(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
    else:
        start = time.time()
        # Split entry text on punctuation, whitespace and brackets.
        # Fix: the original pattern contained r'\[\(', which only split on
        # the literal two-char sequence "[(", leaving lone "[" and "("
        # attached to words; split on each bracket individually instead.
        entry_splitter = r',|\.| |\]|\[|\(|\)|\{|\}|\t|\n|\:'
        entries_by_word_set = [set(word.lower()
                                   for word
                                   in re.split(entry_splitter, entry[self.entry_key])
                                   if word != "")
                               for entry in entries]
        with self.filter_file.open('wb') as f:
            pickle.dump(entries_by_word_set, f)
        end = time.time()
        print(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
    return entries_by_word_set
def can_filter(self, raw_query):
"Check if query contains explicit filters"
# Extract explicit query portion with required, blocked words to filter from natural query
@@ -15,26 +50,24 @@ class ExplicitFilter:
return len(required_words) != 0 or len(blocked_words) != 0
def filter(self, raw_query, entries, embeddings, entry_key='raw'):
def apply(self, raw_query, entries, embeddings):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from explicit required, blocked words filters
start = time.time()
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
end = time.time()
print(f"Time to extract required, blocked words: {end - start} seconds")
if len(required_words) == 0 and len(blocked_words) == 0:
return query, entries, embeddings
# convert each entry to a set of words
# split on fullstop, comma, colon, tab, newline or any brackets
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
entries_by_word_set = [set(word.lower()
for word
in re.split(entry_splitter, entry[entry_key])
if word != "")
for entry in entries]
# load or generate word set for each entry
entries_by_word_set = self.load(entries, regenerate=False)
# track id of entries to exclude
start = time.time()
entries_to_exclude = set()
# mark entries that do not contain all required_words for exclusion
@@ -48,10 +81,15 @@ class ExplicitFilter:
for id, words_in_entry in enumerate(entries_by_word_set):
if words_in_entry.intersection(blocked_words):
entries_to_exclude.add(id)
end = time.time()
print(f"Mark entries to filter: {end - start} seconds")
# delete entries (and their embeddings) marked for exclusion
start = time.time()
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
end = time.time()
print(f"Remove entries to filter from embeddings: {end - start} seconds")
return query, entries, embeddings