diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 065badc0..674d88ed 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -22,10 +22,19 @@ class FileFilter(BaseFilter): return re.search(self.file_filter_regex, raw_query) is not None def apply(self, raw_query, raw_entries, raw_embeddings): - files_to_search = re.findall(self.file_filter_regex, raw_query) - if not files_to_search: + # Extract file filters from raw query + raw_files_to_search = re.findall(self.file_filter_regex, raw_query) + if not raw_files_to_search: return raw_query, raw_entries, raw_embeddings + # Convert simple file filters with no path separator into regex + # e.g. "file:notes.org" -> "file:.*notes.org" + files_to_search = [] + for file in sorted(raw_files_to_search): + if '/' not in file and '\\' not in file and '*' not in file: + files_to_search += [f'*{file}'] + else: + files_to_search += [file] query = re.sub(self.file_filter_regex, '', raw_query).strip() included_entry_indices = [id for id, entry in enumerate(raw_entries) for search_file in files_to_search if fnmatch.fnmatch(entry[self.entry_key], search_file)] if not included_entry_indices: diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index 401adfc7..b15b8a69 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -22,23 +22,6 @@ def test_no_file_filter(): assert ret_entries == entries -def test_file_filter_with_partial_match(): - # Arrange - file_filter = FileFilter() - embeddings, entries = arrange_content() - q_with_no_filter = 'head file:"*.org" tail' - - # Act - can_filter = file_filter.can_filter(q_with_no_filter) - ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) - - # Assert - assert can_filter == True - assert ret_query == 'head tail' - assert len(ret_emb) == 4 - assert ret_entries == entries - - def test_file_filter_with_non_existent_file(): # Arrange file_filter = FileFilter() @@ -60,7 +43,7 @@ def test_single_file_filter(): # Arrange file_filter = FileFilter() embeddings, entries = arrange_content() - q_with_no_filter = 'head file:"file1.org" tail' + q_with_no_filter = 'head file:"file 1.org" tail' # Act can_filter = file_filter.can_filter(q_with_no_filter) @@ -73,11 +56,45 @@ def test_single_file_filter(): assert ret_entries == [entries[0], entries[2]] +def test_file_filter_with_partial_match(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"1.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert len(ret_emb) == 2 + assert ret_entries == [entries[0], entries[2]] + + +def test_file_filter_with_regex_match(): + # Arrange + file_filter = FileFilter() + embeddings, entries = arrange_content() + q_with_no_filter = 'head file:"*.org" tail' + + # Act + can_filter = file_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert len(ret_emb) == 4 + assert ret_entries == entries + + def test_multiple_file_filter(): # Arrange file_filter = FileFilter() embeddings, entries = arrange_content() - q_with_no_filter = 'head tail file:"file1.org" file:"file2.org"' + q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"' # Act can_filter = file_filter.can_filter(q_with_no_filter) @@ -93,9 +110,9 @@ def test_multiple_file_filter(): def arrange_content(): embeddings = torch.randn(4, 10) entries = [ - {'compiled': '', 'raw': 'First Entry', 'file': 'file1.org'}, + {'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'}, {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'}, - {'compiled': '', 'raw': 'Third Entry', 'file': 'file1.org'}, + {'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'}, {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}] return embeddings, entries