diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py
new file mode 100644
index 00000000..065badc0
--- /dev/null
+++ b/src/search_filter/file_filter.py
@@ -0,0 +1,57 @@
+# Standard Packages
+import re
+import fnmatch
+
+# External Packages
+import torch
+
+# Internal Packages
+from src.search_filter.base_filter import BaseFilter
+
+
+class FileFilter(BaseFilter):
+    """Narrow search to entries whose file matches a file:"<glob>" term in the query."""
+
+    # Captures the glob of every file:"..." term, swallowing one trailing space
+    file_filter_regex = r'file:"(.+?)" ?'
+
+    def __init__(self, entry_key='file'):
+        # Key of each entry dict whose value is matched against the globs
+        self.entry_key = entry_key
+
+    def load(self, *args, **kwargs):
+        # This filter has nothing to load; arguments are accepted and ignored
+        pass
+
+    def can_filter(self, raw_query):
+        """Return True when the query contains at least one file filter term."""
+        return re.search(self.file_filter_regex, raw_query) is not None
+
+    def apply(self, raw_query, raw_entries, raw_embeddings):
+        """Strip file filter terms from the query and keep only matching entries.
+
+        Entries are kept when the value under entry_key matches any requested glob
+        (fnmatch rules). Returns the cleaned query, the kept entries and the rows of
+        raw_embeddings at the same positions. Inputs are returned unchanged when
+        the query has no file filter term.
+        """
+        files_to_search = re.findall(self.file_filter_regex, raw_query)
+        if not files_to_search:
+            return raw_query, raw_entries, raw_embeddings
+
+        query = re.sub(self.file_filter_regex, '', raw_query).strip()
+
+        # Record each entry at most once even if several globs match it; repeated
+        # indices would duplicate embedding rows and misalign them with entries
+        included_entry_indices = [
+            index for index, entry in enumerate(raw_entries)
+            if any(fnmatch.fnmatch(entry[self.entry_key], search_file) for search_file in files_to_search)]
+        if not included_entry_indices:
+            return query, [], torch.empty(0)
+
+        entries = [raw_entries[index] for index in included_entry_indices]
+        # index_select needs the index tensor on the same device as the embeddings
+        index_tensor = torch.tensor(included_entry_indices, dtype=torch.long, device=raw_embeddings.device)
+        embeddings = torch.index_select(raw_embeddings, 0, index_tensor)
+
+        return query, entries, embeddings
diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py
new file mode 100644
index 00000000..401adfc7
--- /dev/null
+++ b/tests/test_file_filter.py
@@ -0,0 +1,121 @@
+# External Packages
+import torch
+
+# Application Packages
+from src.search_filter.file_filter import FileFilter
+
+
+def test_no_file_filter():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_no_filter = 'head tail'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_no_filter)
+    ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
+
+    # Assert
+    assert not can_filter
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 4
+    assert ret_entries == entries
+
+
+def test_file_filter_with_partial_match():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_filter = 'head file:"*.org" tail'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_filter)
+    ret_query, ret_entries, ret_emb = file_filter.apply(q_with_filter, entries.copy(), embeddings)
+
+    # Assert
+    assert can_filter
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 4
+    assert ret_entries == entries
+
+
+def test_file_filter_with_non_existent_file():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_filter = 'head file:"nonexistent.org" tail'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_filter)
+    ret_query, ret_entries, ret_emb = file_filter.apply(q_with_filter, entries.copy(), embeddings)
+
+    # Assert
+    assert can_filter
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 0
+    assert ret_entries == []
+
+
+def test_single_file_filter():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_filter = 'head file:"file1.org" tail'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_filter)
+    ret_query, ret_entries, ret_emb = file_filter.apply(q_with_filter, entries.copy(), embeddings)
+
+    # Assert
+    assert can_filter
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 2
+    assert torch.equal(ret_emb, embeddings[[0, 2]])
+    assert ret_entries == [entries[0], entries[2]]
+
+
+def test_multiple_file_filter():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_filter = 'head tail file:"file1.org" file:"file2.org"'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_filter)
+    ret_query, ret_entries, ret_emb = file_filter.apply(q_with_filter, entries.copy(), embeddings)
+
+    # Assert
+    assert can_filter
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 4
+    assert ret_entries == entries
+
+
+def test_overlapping_file_filters():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_filter = 'head tail file:"*.org" file:"file1.org"'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_filter)
+    ret_query, ret_entries, ret_emb = file_filter.apply(q_with_filter, entries.copy(), embeddings)
+
+    # Assert
+    assert can_filter
+    assert ret_query == 'head tail'
+    # Entries matched by more than one glob must be returned only once
+    assert len(ret_emb) == 4
+    assert torch.equal(ret_emb, embeddings)
+    assert ret_entries == entries
+
+
+def arrange_content():
+    embeddings = torch.randn(4, 10)
+    entries = [
+        {'compiled': '', 'raw': 'First Entry', 'file': 'file1.org'},
+        {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'},
+        {'compiled': '', 'raw': 'Third Entry', 'file': 'file1.org'},
+        {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}]
+
+    return embeddings, entries