diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py
index 06cf8878..aec28e01 100644
--- a/src/search_filter/date_filter.py
+++ b/src/search_filter/date_filter.py
@@ -28,7 +28,8 @@ def date_filter(query, entries, embeddings):
         return query, entries, embeddings
 
     # remove date range filter from query
-    query = re.sub(date_range_regex, '', query)
+    query = re.sub(rf'\s+{date_regex}', ' ', query)
+    query = re.sub(r'\s{2,}', ' ', query).strip()  # remove multiple spaces
 
     # find entries containing any dates that fall with date range specified in query
     entries_to_include = set()
@@ -38,7 +39,7 @@ def date_filter(query, entries, embeddings):
             # Convert date string in entry to unix timestamp
             date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
             # Check if date in entry is within date range specified in query
-            if query_daterange[0] <= date_in_entry <= query_daterange[1]:
+            if query_daterange[0] <= date_in_entry < query_daterange[1]:
                 entries_to_include.add(id)
                 break
 
diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py
index 272643cf..615eeb5b 100644
--- a/tests/test_date_filter.py
+++ b/tests/test_date_filter.py
@@ -3,10 +3,58 @@ import re
 from datetime import datetime
 from math import inf
 
+# External Packages
+import torch
+
 # Application Packages
 from src.search_filter import date_filter
 
 
+def test_date_filter():
+    """Check date_filter strips dt filters from the query and keeps only entries dated within the range."""
+    embeddings = torch.randn(3, 10)
+    entries = [
+        ['', 'Entry with no date'],
+        ['', 'April Fools entry: 1984-04-01'],
+        ['', 'Entry with date:1984-04-02']]
+
+    q_with_no_date_filter = 'head tail'
+    ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_no_date_filter, entries.copy(), embeddings)
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 3
+    assert ret_entries == entries
+
+    q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
+    ret_query, ret_entries, ret_emb = date_filter.date_filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 0
+    assert ret_entries == []
+
+    query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
+    ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    assert ret_query == 'head tail'
+    assert ret_entries == [entries[2]]
+    assert len(ret_emb) == 1
+
+    query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
+    ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    assert ret_query == 'head tail'
+    assert ret_entries == [entries[1]]
+    assert len(ret_emb) == 1
+
+    query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
+    ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    assert ret_query == 'head tail'
+    assert ret_entries == [entries[2]]
+    assert len(ret_emb) == 1
+
+    query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
+    ret_query, ret_entries, ret_emb = date_filter.date_filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    assert ret_query == 'head tail'
+    assert ret_entries == [entries[1], entries[2]]
+    assert len(ret_emb) == 2
+
+
 def test_extract_date_range():
     assert date_filter.extract_date_range('head dt>"2020-01-04" dt<"2020-01-07" tail') == [datetime(2020, 1, 5, 0, 0, 0).timestamp(), datetime(2020, 1, 7, 0, 0, 0).timestamp()]
     assert date_filter.extract_date_range('head dt<="2020-01-01"') == [0, datetime(2020, 1, 2, 0, 0, 0).timestamp()]