Improve search speed. Only apply filter if filter keywords in query

- Formalize filters into class with can_filter() and filter() methods

- Use can_filter() method to decide whether to apply filter and
  create deep copies of entries and embeddings for it

- Improve search speed for queries with no filters
  as deep copying entries, embeddings takes the most time
  after cross-encodes scoring when calling the /search API

  Earlier we would create deep copies of entries, embeddings
  even if the query did not contain any filter keywords
This commit is contained in:
Debanjum Singh Solanky
2022-07-26 22:47:26 +04:00
parent f094c86204
commit b1e64fd4a8
5 changed files with 223 additions and 200 deletions

View File

@@ -21,8 +21,8 @@ from src.utils.cli import cli
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
from src.utils.rawconfig import FullConfig from src.utils.rawconfig import FullConfig
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.search_filter.explicit_filter import explicit_filter from src.search_filter.explicit_filter import ExplicitFilter
from src.search_filter.date_filter import date_filter from src.search_filter.date_filter import DateFilter
# Application Global State # Application Global State
config = FullConfig() config = FullConfig()
@@ -72,7 +72,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Org or t == None) and model.orgmode_search: if (t == SearchType.Org or t == None) and model.orgmode_search:
# query org-mode notes # query org-mode notes
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, model.orgmode_search, device=device, filters=[explicit_filter, date_filter], verbose=verbose) hits, entries = text_search.query(user_query, model.orgmode_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results
@@ -83,7 +83,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Music or t == None) and model.music_search: if (t == SearchType.Music or t == None) and model.music_search:
# query music library # query music library
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter], verbose=verbose) hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results
@@ -94,7 +94,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Markdown or t == None) and model.orgmode_search: if (t == SearchType.Markdown or t == None) and model.orgmode_search:
# query markdown files # query markdown files
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, model.markdown_search, device=device, filters=[explicit_filter, date_filter], verbose=verbose) hits, entries = text_search.query(user_query, model.markdown_search, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results
@@ -105,7 +105,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Ledger or t == None) and model.ledger_search: if (t == SearchType.Ledger or t == None) and model.ledger_search:
# query transactions # query transactions
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, model.ledger_search, filters=[explicit_filter, date_filter], verbose=verbose) hits, entries = text_search.query(user_query, model.ledger_search, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results

View File

@@ -9,6 +9,7 @@ import torch
import dateparser as dtparse import dateparser as dtparse
class DateFilter:
# Date Range Filter Regexes # Date Range Filter Regexes
# Example filter queries: # Example filter queries:
# - dt>="yesterday" dt<"tomorrow" # - dt>="yesterday" dt<"tomorrow"
@@ -16,18 +17,22 @@ import dateparser as dtparse
# - dt:"2 years ago" # - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\"" date_regex = r"dt([:><=]{1,2})\"(.*?)\""
def can_filter(self, raw_query):
"Check if query contains date filters"
return self.extract_date_range(raw_query) is not None
def date_filter(query, entries, embeddings, entry_key='raw'):
def filter(self, query, entries, embeddings, entry_key='raw'):
"Find entries containing any dates that fall within date range specified in query" "Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query # extract date range specified in date filter of query
query_daterange = extract_date_range(query) query_daterange = self.extract_date_range(query)
# if no date in query, return all entries # if no date in query, return all entries
if query_daterange is None: if query_daterange is None:
return query, entries, embeddings return query, entries, embeddings
# remove date range filter from query # remove date range filter from query
query = re.sub(f'\s+{date_regex}', ' ', query) query = re.sub(f'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# find entries containing any dates that fall with date range specified in query # find entries containing any dates that fall with date range specified in query
@@ -54,9 +59,9 @@ def date_filter(query, entries, embeddings, entry_key='raw'):
return query, entries, embeddings return query, entries, embeddings
def extract_date_range(query): def extract_date_range(self, query):
# find date range filter in query # find date range filter in query
date_range_matches = re.findall(date_regex, query) date_range_matches = re.findall(self.date_regex, query)
if len(date_range_matches) == 0: if len(date_range_matches) == 0:
return None return None
@@ -65,8 +70,8 @@ def extract_date_range(query):
# e.g today maps to (start_of_day, start_of_tomorrow) # e.g today maps to (start_of_day, start_of_tomorrow)
date_ranges_from_filter = [] date_ranges_from_filter = []
for (cmp, date_str) in date_range_matches: for (cmp, date_str) in date_range_matches:
if parse(date_str): if self.parse(date_str):
dt_start, dt_end = parse(date_str) dt_start, dt_end = self.parse(date_str)
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]] date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
# Combine dates with their comparators to form date range intervals # Combine dates with their comparators to form date range intervals
@@ -103,7 +108,7 @@ def extract_date_range(query):
return effective_date_range return effective_date_range
def parse(date_str, relative_base=None): def parse(self, date_str, relative_base=None):
"Parse date string passed in date filter of query to datetime object" "Parse date string passed in date filter of query to datetime object"
# clean date string to handle future date parsing by date parser # clean date string to handle future date parsing by date parser
future_strings = ['later', 'from now', 'from today'] future_strings = ['later', 'from now', 'from today']
@@ -122,10 +127,10 @@ def parse(date_str, relative_base=None):
if parsed_date is None: if parsed_date is None:
return None return None
return date_to_daterange(parsed_date, date_str) return self.date_to_daterange(parsed_date, date_str)
def date_to_daterange(parsed_date, date_str): def date_to_daterange(self, parsed_date, date_str):
"Convert parsed date to date ranges at natural granularity (day, week, month or year)" "Convert parsed date to date ranges at natural granularity (day, week, month or year)"
start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0) start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)

View File

@@ -5,7 +5,18 @@ import re
import torch import torch
def explicit_filter(raw_query, entries, embeddings, entry_key='raw'): class ExplicitFilter:
def can_filter(self, raw_query):
"Check if query contains explicit filters"
# Extract explicit query portion with required, blocked words to filter from natural query
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
return len(required_words) != 0 or len(blocked_words) != 0
def filter(self, raw_query, entries, embeddings, entry_key='raw'):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from explicit required, blocked words filters # Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])

View File

@@ -65,25 +65,32 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0): def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0):
"Search for entries that answer the query" "Search for entries that answer the query"
# Copy original embeddings, entries to filter them for query
start = time.time()
query = raw_query query = raw_query
# Use deep copy of original embeddings, entries to filter if query contains filters
start = time.time()
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
if filters_in_query:
corpus_embeddings = deepcopy(model.corpus_embeddings) corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries) entries = deepcopy(model.entries)
else:
corpus_embeddings = model.corpus_embeddings
entries = model.entries
end = time.time() end = time.time()
if verbose > 1: if verbose > 1:
print(f"Copy Time: {end - start:.3f} seconds") print(f"Copy Time: {end - start:.3f} seconds")
# Filter query, entries and embeddings before semantic search # Filter query, entries and embeddings before semantic search
start = time.time() start = time.time()
for filter in filters: for filter in filters_in_query:
query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings) query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
if entries is None or len(entries) == 0:
return [], []
end = time.time() end = time.time()
if verbose > 1: if verbose > 1:
print(f"Filter Time: {end - start:.3f} seconds") print(f"Filter Time: {end - start:.3f} seconds")
if entries is None or len(entries) == 0:
return [], []
# Encode the query using the bi-encoder # Encode the query using the bi-encoder
start = time.time() start = time.time()
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True) question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)

View File

@@ -7,7 +7,7 @@ from math import inf
import torch import torch
# Application Packages # Application Packages
from src.search_filter import date_filter from src.search_filter.date_filter import DateFilter
def test_date_filter(): def test_date_filter():
@@ -18,99 +18,99 @@ def test_date_filter():
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}] {'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
q_with_no_date_filter = 'head tail' q_with_no_date_filter = 'head tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_no_date_filter, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 3 assert len(ret_emb) == 3
assert ret_entries == entries assert ret_entries == entries
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert len(ret_emb) == 0 assert len(ret_emb) == 0
assert ret_entries == [] assert ret_entries == []
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[2]] assert ret_entries == [entries[2]]
assert len(ret_emb) == 1 assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[1]] assert ret_entries == [entries[1]]
assert len(ret_emb) == 1 assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[2]] assert ret_entries == [entries[2]]
assert len(ret_emb) == 1 assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail' assert ret_query == 'head tail'
assert ret_entries == [entries[1], entries[2]] assert ret_entries == [entries[1], entries[2]]
assert len(ret_emb) == 2 assert len(ret_emb) == 2
def test_extract_date_range(): def test_extract_date_range():
assert date_filter.extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()] assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()]
assert date_filter.extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()] assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert date_filter.extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf] assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
assert date_filter.extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()] assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()]
# Unparseable date filter specified in query # Unparseable date filter specified in query
assert date_filter.extract_date_range('head dt:"Summer of 69" tail') == None assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None
# No date filter specified in query # No date filter specified in query
assert date_filter.extract_date_range('head tail') == None assert DateFilter().extract_date_range('head tail') == None
# Non intersecting date ranges # Non intersecting date ranges
assert date_filter.extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None
def test_parse(): def test_parse():
test_now = datetime(1984, 4, 1, 21, 21, 21) test_now = datetime(1984, 4, 1, 21, 21, 21)
# day variations # day variations
assert date_filter.parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0)) assert DateFilter().parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0))
assert date_filter.parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0)) assert DateFilter().parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0))
assert date_filter.parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0)) assert DateFilter().parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0))
assert date_filter.parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0)) assert DateFilter().parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0))
# week variations # week variations
assert date_filter.parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0)) assert DateFilter().parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0))
assert date_filter.parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0)) assert DateFilter().parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0))
# month variations # month variations
assert date_filter.parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0)) assert DateFilter().parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0))
assert date_filter.parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0)) assert DateFilter().parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0))
# year variations # year variations
assert date_filter.parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0)) assert DateFilter().parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0))
assert date_filter.parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0)) assert DateFilter().parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0))
# specific month/date variation # specific month/date variation
assert date_filter.parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)) assert DateFilter().parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
assert date_filter.parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0)) assert DateFilter().parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
def test_date_filter_regex(): def test_date_filter_regex():
dtrange_match = re.findall(date_filter.date_regex, 'multi word head dt>"today" dt:"1984-01-01"') dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>"today" dt:"1984-01-01"')
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')] assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]
dtrange_match = re.findall(date_filter.date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail') dtrange_match = re.findall(DateFilter().date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail')
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')] assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]
dtrange_match = re.findall(date_filter.date_regex, 'multi word head dt>="today" dt="1984-01-01"') dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>="today" dt="1984-01-01"')
assert dtrange_match == [('>=', 'today'), ('=', '1984-01-01')] assert dtrange_match == [('>=', 'today'), ('=', '1984-01-01')]
dtrange_match = re.findall(date_filter.date_regex, 'dt<"multi word date" multi word tail') dtrange_match = re.findall(DateFilter().date_regex, 'dt<"multi word date" multi word tail')
assert dtrange_match == [('<', 'multi word date')] assert dtrange_match == [('<', 'multi word date')]
dtrange_match = re.findall(date_filter.date_regex, 'head dt<="multi word date"') dtrange_match = re.findall(DateFilter().date_regex, 'head dt<="multi word date"')
assert dtrange_match == [('<=', 'multi word date')] assert dtrange_match == [('<=', 'multi word date')]
dtrange_match = re.findall(date_filter.date_regex, 'head tail') dtrange_match = re.findall(DateFilter().date_regex, 'head tail')
assert dtrange_match == [] assert dtrange_match == []