diff --git a/src/main.py b/src/main.py index 00e68612..b9723879 100644 --- a/src/main.py +++ b/src/main.py @@ -18,6 +18,7 @@ from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, Con from src.utils.rawconfig import FullConfig from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from src.search_filter.explicit_filter import explicit_filter +from src.search_filter.date_filter import date_filter # Application Global State config = FullConfig() @@ -59,7 +60,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Notes or t == None) and model.notes_search: # query notes - hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter]) + hits, entries = asymmetric.query(user_query, model.notes_search, device=device, filters=[explicit_filter, date_filter]) # collate and return results return asymmetric.collate_results(hits, entries, results_count) diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py new file mode 100644 index 00000000..fc04722e --- /dev/null +++ b/src/search_filter/date_filter.py @@ -0,0 +1,33 @@ +# Standard Packages +import re + +# External Packages +import torch + + +def date_filter(query, entries, embeddings): + # extract date from query + date_regex = r'\d{4}-\d{2}-\d{2}' + dates_in_query = re.findall(date_regex, query) + + # if no date in query, return all entries + if dates_in_query is None or len(dates_in_query) == 0: + return query, entries, embeddings + + # remove dates from query + query = re.sub(date_regex, '', query) + + # find entries with dates from query in them + entries_to_include = set() + for id, entry in enumerate(entries): + for date in dates_in_query: + if date in entry[1]: + entries_to_include.add(id) + + # delete entries (and their embeddings) marked for exclusion + entries_to_exclude = set(range(len(entries))) - entries_to_include + for id in sorted(list(entries_to_exclude), reverse=True): + del entries[id] + embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) + + return query, entries, embeddings \ No newline at end of file