mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Improve search speed. Only apply filter if filter keywords in query
- Formalize filters into class with can_filter() and filter() methods - Use can_filter() method to decide whether to apply filter and create deep copies of entries and embeddings for it - Improve search speed for queries with no filters as deep copying entries, embeddings takes the most time after cross-encodes scoring when calling the /search API Earlier we would create deep copies of entries, embeddings even if the query did not contain any filter keywords
This commit is contained in:
12
src/main.py
12
src/main.py
@@ -21,8 +21,8 @@ from src.utils.cli import cli
|
||||
from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, ConversationProcessorConfigModel
|
||||
from src.utils.rawconfig import FullConfig
|
||||
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
|
||||
from src.search_filter.explicit_filter import explicit_filter
|
||||
from src.search_filter.date_filter import date_filter
|
||||
from src.search_filter.explicit_filter import ExplicitFilter
|
||||
from src.search_filter.date_filter import DateFilter
|
||||
|
||||
# Application Global State
|
||||
config = FullConfig()
|
||||
@@ -72,7 +72,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||
if (t == SearchType.Org or t == None) and model.orgmode_search:
|
||||
# query org-mode notes
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, model.orgmode_search, device=device, filters=[explicit_filter, date_filter], verbose=verbose)
|
||||
hits, entries = text_search.query(user_query, model.orgmode_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
@@ -83,7 +83,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||
if (t == SearchType.Music or t == None) and model.music_search:
|
||||
# query music library
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[explicit_filter, date_filter], verbose=verbose)
|
||||
hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
@@ -94,7 +94,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||
if (t == SearchType.Markdown or t == None) and model.orgmode_search:
|
||||
# query markdown files
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, model.markdown_search, device=device, filters=[explicit_filter, date_filter], verbose=verbose)
|
||||
hits, entries = text_search.query(user_query, model.markdown_search, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
@@ -105,7 +105,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||
if (t == SearchType.Ledger or t == None) and model.ledger_search:
|
||||
# query transactions
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, model.ledger_search, filters=[explicit_filter, date_filter], verbose=verbose)
|
||||
hits, entries = text_search.query(user_query, model.ledger_search, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
|
||||
@@ -9,138 +9,143 @@ import torch
|
||||
import dateparser as dtparse
|
||||
|
||||
|
||||
# Date Range Filter Regexes
|
||||
# Example filter queries:
|
||||
# - dt>="yesterday" dt<"tomorrow"
|
||||
# - dt>="last week"
|
||||
# - dt:"2 years ago"
|
||||
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
|
||||
class DateFilter:
|
||||
# Date Range Filter Regexes
|
||||
# Example filter queries:
|
||||
# - dt>="yesterday" dt<"tomorrow"
|
||||
# - dt>="last week"
|
||||
# - dt:"2 years ago"
|
||||
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
|
||||
|
||||
def can_filter(self, raw_query):
|
||||
"Check if query contains date filters"
|
||||
return self.extract_date_range(raw_query) is not None
|
||||
|
||||
|
||||
def date_filter(query, entries, embeddings, entry_key='raw'):
|
||||
"Find entries containing any dates that fall within date range specified in query"
|
||||
# extract date range specified in date filter of query
|
||||
query_daterange = extract_date_range(query)
|
||||
def filter(self, query, entries, embeddings, entry_key='raw'):
|
||||
"Find entries containing any dates that fall within date range specified in query"
|
||||
# extract date range specified in date filter of query
|
||||
query_daterange = self.extract_date_range(query)
|
||||
|
||||
# if no date in query, return all entries
|
||||
if query_daterange is None:
|
||||
return query, entries, embeddings
|
||||
|
||||
# remove date range filter from query
|
||||
query = re.sub(f'\s+{self.date_regex}', ' ', query)
|
||||
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
|
||||
|
||||
# find entries containing any dates that fall with date range specified in query
|
||||
entries_to_include = set()
|
||||
for id, entry in enumerate(entries):
|
||||
# Extract dates from entry
|
||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
|
||||
# Convert date string in entry to unix timestamp
|
||||
try:
|
||||
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
||||
except ValueError:
|
||||
continue
|
||||
# Check if date in entry is within date range specified in query
|
||||
if query_daterange[0] <= date_in_entry < query_daterange[1]:
|
||||
entries_to_include.add(id)
|
||||
break
|
||||
|
||||
# delete entries (and their embeddings) marked for exclusion
|
||||
entries_to_exclude = set(range(len(entries))) - entries_to_include
|
||||
for id in sorted(list(entries_to_exclude), reverse=True):
|
||||
del entries[id]
|
||||
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
||||
|
||||
# if no date in query, return all entries
|
||||
if query_daterange is None:
|
||||
return query, entries, embeddings
|
||||
|
||||
# remove date range filter from query
|
||||
query = re.sub(f'\s+{date_regex}', ' ', query)
|
||||
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
|
||||
|
||||
# find entries containing any dates that fall with date range specified in query
|
||||
entries_to_include = set()
|
||||
for id, entry in enumerate(entries):
|
||||
# Extract dates from entry
|
||||
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
|
||||
# Convert date string in entry to unix timestamp
|
||||
try:
|
||||
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
|
||||
except ValueError:
|
||||
continue
|
||||
# Check if date in entry is within date range specified in query
|
||||
if query_daterange[0] <= date_in_entry < query_daterange[1]:
|
||||
entries_to_include.add(id)
|
||||
break
|
||||
def extract_date_range(self, query):
|
||||
# find date range filter in query
|
||||
date_range_matches = re.findall(self.date_regex, query)
|
||||
|
||||
# delete entries (and their embeddings) marked for exclusion
|
||||
entries_to_exclude = set(range(len(entries))) - entries_to_include
|
||||
for id in sorted(list(entries_to_exclude), reverse=True):
|
||||
del entries[id]
|
||||
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
||||
if len(date_range_matches) == 0:
|
||||
return None
|
||||
|
||||
return query, entries, embeddings
|
||||
# extract, parse natural dates ranges from date range filter passed in query
|
||||
# e.g today maps to (start_of_day, start_of_tomorrow)
|
||||
date_ranges_from_filter = []
|
||||
for (cmp, date_str) in date_range_matches:
|
||||
if self.parse(date_str):
|
||||
dt_start, dt_end = self.parse(date_str)
|
||||
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
|
||||
|
||||
# Combine dates with their comparators to form date range intervals
|
||||
# For e.g
|
||||
# >=yesterday maps to [start_of_yesterday, inf)
|
||||
# <tomorrow maps to [0, start_of_tomorrow)
|
||||
# ---
|
||||
effective_date_range = [0, inf]
|
||||
date_range_considering_comparator = []
|
||||
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
|
||||
if cmp == '>':
|
||||
date_range_considering_comparator += [[dtrange_end, inf]]
|
||||
elif cmp == '>=':
|
||||
date_range_considering_comparator += [[dtrange_start, inf]]
|
||||
elif cmp == '<':
|
||||
date_range_considering_comparator += [[0, dtrange_start]]
|
||||
elif cmp == '<=':
|
||||
date_range_considering_comparator += [[0, dtrange_end]]
|
||||
elif cmp == '=' or cmp == ':' or cmp == '==':
|
||||
date_range_considering_comparator += [[dtrange_start, dtrange_end]]
|
||||
|
||||
# Combine above intervals (via AND/intersect)
|
||||
# In the above example, this gives us [start_of_yesterday, start_of_tomorrow)
|
||||
# This is the effective date range to filter entries by
|
||||
# ---
|
||||
for date_range in date_range_considering_comparator:
|
||||
effective_date_range = [
|
||||
max(effective_date_range[0], date_range[0]),
|
||||
min(effective_date_range[1], date_range[1])]
|
||||
|
||||
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
|
||||
return None
|
||||
else:
|
||||
return effective_date_range
|
||||
|
||||
|
||||
def extract_date_range(query):
|
||||
# find date range filter in query
|
||||
date_range_matches = re.findall(date_regex, query)
|
||||
def parse(self, date_str, relative_base=None):
|
||||
"Parse date string passed in date filter of query to datetime object"
|
||||
# clean date string to handle future date parsing by date parser
|
||||
future_strings = ['later', 'from now', 'from today']
|
||||
prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])]
|
||||
clean_date_str = re.sub('|'.join(future_strings), '', date_str)
|
||||
|
||||
if len(date_range_matches) == 0:
|
||||
return None
|
||||
# parse date passed in query date filter
|
||||
parsed_date = dtparse.parse(
|
||||
clean_date_str,
|
||||
settings= {
|
||||
'RELATIVE_BASE': relative_base or datetime.now(),
|
||||
'PREFER_DAY_OF_MONTH': 'first',
|
||||
'PREFER_DATES_FROM': prefer_dates_from
|
||||
})
|
||||
|
||||
# extract, parse natural dates ranges from date range filter passed in query
|
||||
# e.g today maps to (start_of_day, start_of_tomorrow)
|
||||
date_ranges_from_filter = []
|
||||
for (cmp, date_str) in date_range_matches:
|
||||
if parse(date_str):
|
||||
dt_start, dt_end = parse(date_str)
|
||||
date_ranges_from_filter += [[cmp, (dt_start.timestamp(), dt_end.timestamp())]]
|
||||
if parsed_date is None:
|
||||
return None
|
||||
|
||||
# Combine dates with their comparators to form date range intervals
|
||||
# For e.g
|
||||
# >=yesterday maps to [start_of_yesterday, inf)
|
||||
# <tomorrow maps to [0, start_of_tomorrow)
|
||||
# ---
|
||||
effective_date_range = [0, inf]
|
||||
date_range_considering_comparator = []
|
||||
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
|
||||
if cmp == '>':
|
||||
date_range_considering_comparator += [[dtrange_end, inf]]
|
||||
elif cmp == '>=':
|
||||
date_range_considering_comparator += [[dtrange_start, inf]]
|
||||
elif cmp == '<':
|
||||
date_range_considering_comparator += [[0, dtrange_start]]
|
||||
elif cmp == '<=':
|
||||
date_range_considering_comparator += [[0, dtrange_end]]
|
||||
elif cmp == '=' or cmp == ':' or cmp == '==':
|
||||
date_range_considering_comparator += [[dtrange_start, dtrange_end]]
|
||||
|
||||
# Combine above intervals (via AND/intersect)
|
||||
# In the above example, this gives us [start_of_yesterday, start_of_tomorrow)
|
||||
# This is the effective date range to filter entries by
|
||||
# ---
|
||||
for date_range in date_range_considering_comparator:
|
||||
effective_date_range = [
|
||||
max(effective_date_range[0], date_range[0]),
|
||||
min(effective_date_range[1], date_range[1])]
|
||||
|
||||
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
|
||||
return None
|
||||
else:
|
||||
return effective_date_range
|
||||
return self.date_to_daterange(parsed_date, date_str)
|
||||
|
||||
|
||||
def parse(date_str, relative_base=None):
|
||||
"Parse date string passed in date filter of query to datetime object"
|
||||
# clean date string to handle future date parsing by date parser
|
||||
future_strings = ['later', 'from now', 'from today']
|
||||
prefer_dates_from = {True: 'future', False: 'past'}[any([True for fstr in future_strings if fstr in date_str])]
|
||||
clean_date_str = re.sub('|'.join(future_strings), '', date_str)
|
||||
def date_to_daterange(self, parsed_date, date_str):
|
||||
"Convert parsed date to date ranges at natural granularity (day, week, month or year)"
|
||||
|
||||
# parse date passed in query date filter
|
||||
parsed_date = dtparse.parse(
|
||||
clean_date_str,
|
||||
settings= {
|
||||
'RELATIVE_BASE': relative_base or datetime.now(),
|
||||
'PREFER_DAY_OF_MONTH': 'first',
|
||||
'PREFER_DATES_FROM': prefer_dates_from
|
||||
})
|
||||
start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if parsed_date is None:
|
||||
return None
|
||||
|
||||
return date_to_daterange(parsed_date, date_str)
|
||||
|
||||
|
||||
def date_to_daterange(parsed_date, date_str):
|
||||
"Convert parsed date to date ranges at natural granularity (day, week, month or year)"
|
||||
|
||||
start_of_day = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if 'year' in date_str:
|
||||
return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0))
|
||||
if 'month' in date_str:
|
||||
start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
|
||||
next_month = start_of_month + relativedelta(months=1)
|
||||
return (start_of_month, next_month)
|
||||
if 'week' in date_str:
|
||||
# if week in date string, dateparser parses it to next week start
|
||||
# so today = end of this week
|
||||
start_of_week = start_of_day - timedelta(days=7)
|
||||
return (start_of_week, start_of_day)
|
||||
else:
|
||||
next_day = start_of_day + relativedelta(days=1)
|
||||
if 'year' in date_str:
|
||||
return (datetime(parsed_date.year, 1, 1, 0, 0, 0), datetime(parsed_date.year+1, 1, 1, 0, 0, 0))
|
||||
if 'month' in date_str:
|
||||
start_of_month = datetime(parsed_date.year, parsed_date.month, 1, 0, 0, 0)
|
||||
next_month = start_of_month + relativedelta(months=1)
|
||||
return (start_of_month, next_month)
|
||||
if 'week' in date_str:
|
||||
# if week in date string, dateparser parses it to next week start
|
||||
# so today = end of this week
|
||||
start_of_week = start_of_day - timedelta(days=7)
|
||||
return (start_of_week, start_of_day)
|
||||
else:
|
||||
next_day = start_of_day + relativedelta(days=1)
|
||||
return (start_of_day, next_day)
|
||||
|
||||
@@ -5,42 +5,53 @@ import re
|
||||
import torch
|
||||
|
||||
|
||||
def explicit_filter(raw_query, entries, embeddings, entry_key='raw'):
|
||||
# Separate natural query from explicit required, blocked words filters
|
||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||
class ExplicitFilter:
|
||||
def can_filter(self, raw_query):
|
||||
"Check if query contains explicit filters"
|
||||
# Extract explicit query portion with required, blocked words to filter from natural query
|
||||
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||
|
||||
return len(required_words) != 0 or len(blocked_words) != 0
|
||||
|
||||
|
||||
def filter(self, raw_query, entries, embeddings, entry_key='raw'):
|
||||
"Find entries containing required and not blocked words specified in query"
|
||||
# Separate natural query from explicit required, blocked words filters
|
||||
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
|
||||
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
|
||||
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return query, entries, embeddings
|
||||
|
||||
# convert each entry to a set of words
|
||||
# split on fullstop, comma, colon, tab, newline or any brackets
|
||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
|
||||
entries_by_word_set = [set(word.lower()
|
||||
for word
|
||||
in re.split(entry_splitter, entry[entry_key])
|
||||
if word != "")
|
||||
for entry in entries]
|
||||
|
||||
# track id of entries to exclude
|
||||
entries_to_exclude = set()
|
||||
|
||||
# mark entries that do not contain all required_words for exclusion
|
||||
if len(required_words) > 0:
|
||||
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||
if not required_words.issubset(words_in_entry):
|
||||
entries_to_exclude.add(id)
|
||||
|
||||
# mark entries that contain any blocked_words for exclusion
|
||||
if len(blocked_words) > 0:
|
||||
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||
if words_in_entry.intersection(blocked_words):
|
||||
entries_to_exclude.add(id)
|
||||
|
||||
# delete entries (and their embeddings) marked for exclusion
|
||||
for id in sorted(list(entries_to_exclude), reverse=True):
|
||||
del entries[id]
|
||||
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
||||
|
||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||
return query, entries, embeddings
|
||||
|
||||
# convert each entry to a set of words
|
||||
# split on fullstop, comma, colon, tab, newline or any brackets
|
||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
|
||||
entries_by_word_set = [set(word.lower()
|
||||
for word
|
||||
in re.split(entry_splitter, entry[entry_key])
|
||||
if word != "")
|
||||
for entry in entries]
|
||||
|
||||
# track id of entries to exclude
|
||||
entries_to_exclude = set()
|
||||
|
||||
# mark entries that do not contain all required_words for exclusion
|
||||
if len(required_words) > 0:
|
||||
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||
if not required_words.issubset(words_in_entry):
|
||||
entries_to_exclude.add(id)
|
||||
|
||||
# mark entries that contain any blocked_words for exclusion
|
||||
if len(blocked_words) > 0:
|
||||
for id, words_in_entry in enumerate(entries_by_word_set):
|
||||
if words_in_entry.intersection(blocked_words):
|
||||
entries_to_exclude.add(id)
|
||||
|
||||
# delete entries (and their embeddings) marked for exclusion
|
||||
for id in sorted(list(entries_to_exclude), reverse=True):
|
||||
del entries[id]
|
||||
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
|
||||
|
||||
return query, entries, embeddings
|
||||
@@ -65,25 +65,32 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0):
|
||||
"Search for entries that answer the query"
|
||||
# Copy original embeddings, entries to filter them for query
|
||||
start = time.time()
|
||||
query = raw_query
|
||||
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
||||
entries = deepcopy(model.entries)
|
||||
|
||||
# Use deep copy of original embeddings, entries to filter if query contains filters
|
||||
start = time.time()
|
||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||
if filters_in_query:
|
||||
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
||||
entries = deepcopy(model.entries)
|
||||
else:
|
||||
corpus_embeddings = model.corpus_embeddings
|
||||
entries = model.entries
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Copy Time: {end - start:.3f} seconds")
|
||||
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
start = time.time()
|
||||
for filter in filters:
|
||||
query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings)
|
||||
if entries is None or len(entries) == 0:
|
||||
return [], []
|
||||
for filter in filters_in_query:
|
||||
query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Filter Time: {end - start:.3f} seconds")
|
||||
|
||||
if entries is None or len(entries) == 0:
|
||||
return [], []
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
start = time.time()
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
|
||||
|
||||
@@ -7,7 +7,7 @@ from math import inf
|
||||
import torch
|
||||
|
||||
# Application Packages
|
||||
from src.search_filter import date_filter
|
||||
from src.search_filter.date_filter import DateFilter
|
||||
|
||||
|
||||
def test_date_filter():
|
||||
@@ -18,99 +18,99 @@ def test_date_filter():
|
||||
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
|
||||
|
||||
q_with_no_date_filter = 'head tail'
|
||||
ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_no_date_filter, entries.copy(), embeddings)
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings)
|
||||
assert ret_query == 'head tail'
|
||||
assert len(ret_emb) == 3
|
||||
assert ret_entries == entries
|
||||
|
||||
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
|
||||
ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
|
||||
assert ret_query == 'head tail'
|
||||
assert len(ret_emb) == 0
|
||||
assert ret_entries == []
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
|
||||
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
assert ret_query == 'head tail'
|
||||
assert ret_entries == [entries[2]]
|
||||
assert len(ret_emb) == 1
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
|
||||
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
assert ret_query == 'head tail'
|
||||
assert ret_entries == [entries[1]]
|
||||
assert len(ret_emb) == 1
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
|
||||
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
assert ret_query == 'head tail'
|
||||
assert ret_entries == [entries[2]]
|
||||
assert len(ret_emb) == 1
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
|
||||
ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
|
||||
assert ret_query == 'head tail'
|
||||
assert ret_entries == [entries[1], entries[2]]
|
||||
assert len(ret_emb) == 2
|
||||
|
||||
|
||||
def test_extract_date_range():
|
||||
assert date_filter.extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()]
|
||||
assert date_filter.extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||
assert date_filter.extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
|
||||
assert date_filter.extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||
assert DateFilter().extract_date_range('head dt>"1984-01-04" dt<"1984-01-07" tail') == [datetime(1984, 1, 5, 0, 0, 0).timestamp(), datetime(1984, 1, 7, 0, 0, 0).timestamp()]
|
||||
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
|
||||
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||
|
||||
# Unparseable date filter specified in query
|
||||
assert date_filter.extract_date_range('head dt:"Summer of 69" tail') == None
|
||||
assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None
|
||||
|
||||
# No date filter specified in query
|
||||
assert date_filter.extract_date_range('head tail') == None
|
||||
assert DateFilter().extract_date_range('head tail') == None
|
||||
|
||||
# Non intersecting date ranges
|
||||
assert date_filter.extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None
|
||||
assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None
|
||||
|
||||
|
||||
def test_parse():
|
||||
test_now = datetime(1984, 4, 1, 21, 21, 21)
|
||||
|
||||
# day variations
|
||||
assert date_filter.parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0))
|
||||
assert date_filter.parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0))
|
||||
assert date_filter.parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0))
|
||||
assert date_filter.parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0))
|
||||
assert DateFilter().parse('today', relative_base=test_now) == (datetime(1984, 4, 1, 0, 0, 0), datetime(1984, 4, 2, 0, 0, 0))
|
||||
assert DateFilter().parse('tomorrow', relative_base=test_now) == (datetime(1984, 4, 2, 0, 0, 0), datetime(1984, 4, 3, 0, 0, 0))
|
||||
assert DateFilter().parse('yesterday', relative_base=test_now) == (datetime(1984, 3, 31, 0, 0, 0), datetime(1984, 4, 1, 0, 0, 0))
|
||||
assert DateFilter().parse('5 days ago', relative_base=test_now) == (datetime(1984, 3, 27, 0, 0, 0), datetime(1984, 3, 28, 0, 0, 0))
|
||||
|
||||
# week variations
|
||||
assert date_filter.parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0))
|
||||
assert date_filter.parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0))
|
||||
assert DateFilter().parse('last week', relative_base=test_now) == (datetime(1984, 3, 18, 0, 0, 0), datetime(1984, 3, 25, 0, 0, 0))
|
||||
assert DateFilter().parse('2 weeks ago', relative_base=test_now) == (datetime(1984, 3, 11, 0, 0, 0), datetime(1984, 3, 18, 0, 0, 0))
|
||||
|
||||
# month variations
|
||||
assert date_filter.parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0))
|
||||
assert date_filter.parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0))
|
||||
assert DateFilter().parse('next month', relative_base=test_now) == (datetime(1984, 5, 1, 0, 0, 0), datetime(1984, 6, 1, 0, 0, 0))
|
||||
assert DateFilter().parse('2 months ago', relative_base=test_now) == (datetime(1984, 2, 1, 0, 0, 0), datetime(1984, 3, 1, 0, 0, 0))
|
||||
|
||||
# year variations
|
||||
assert date_filter.parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0))
|
||||
assert date_filter.parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0))
|
||||
assert DateFilter().parse('this year', relative_base=test_now) == (datetime(1984, 1, 1, 0, 0, 0), datetime(1985, 1, 1, 0, 0, 0))
|
||||
assert DateFilter().parse('20 years later', relative_base=test_now) == (datetime(2004, 1, 1, 0, 0, 0), datetime(2005, 1, 1, 0, 0, 0))
|
||||
|
||||
# specific month/date variation
|
||||
assert date_filter.parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
|
||||
assert date_filter.parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
|
||||
assert DateFilter().parse('in august', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
|
||||
assert DateFilter().parse('on 1983-08-01', relative_base=test_now) == (datetime(1983, 8, 1, 0, 0, 0), datetime(1983, 8, 2, 0, 0, 0))
|
||||
|
||||
|
||||
def test_date_filter_regex():
|
||||
dtrange_match = re.findall(date_filter.date_regex, 'multi word head dt>"today" dt:"1984-01-01"')
|
||||
dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>"today" dt:"1984-01-01"')
|
||||
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]
|
||||
|
||||
dtrange_match = re.findall(date_filter.date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail')
|
||||
dtrange_match = re.findall(DateFilter().date_regex, 'head dt>"today" dt:"1984-01-01" multi word tail')
|
||||
assert dtrange_match == [('>', 'today'), (':', '1984-01-01')]
|
||||
|
||||
dtrange_match = re.findall(date_filter.date_regex, 'multi word head dt>="today" dt="1984-01-01"')
|
||||
dtrange_match = re.findall(DateFilter().date_regex, 'multi word head dt>="today" dt="1984-01-01"')
|
||||
assert dtrange_match == [('>=', 'today'), ('=', '1984-01-01')]
|
||||
|
||||
dtrange_match = re.findall(date_filter.date_regex, 'dt<"multi word date" multi word tail')
|
||||
dtrange_match = re.findall(DateFilter().date_regex, 'dt<"multi word date" multi word tail')
|
||||
assert dtrange_match == [('<', 'multi word date')]
|
||||
|
||||
dtrange_match = re.findall(date_filter.date_regex, 'head dt<="multi word date"')
|
||||
dtrange_match = re.findall(DateFilter().date_regex, 'head dt<="multi word date"')
|
||||
assert dtrange_match == [('<=', 'multi word date')]
|
||||
|
||||
dtrange_match = re.findall(date_filter.date_regex, 'head tail')
|
||||
dtrange_match = re.findall(DateFilter().date_regex, 'head tail')
|
||||
assert dtrange_match == []
|
||||
Reference in New Issue
Block a user