mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Create File Filter to filter files to query. Add tests for file filter
This commit is contained in:
37
src/search_filter/file_filter.py
Normal file
37
src/search_filter/file_filter.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# Standard Packages
|
||||||
|
import re
|
||||||
|
import fnmatch
|
||||||
|
|
||||||
|
# External Packages
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
|
from src.search_filter.base_filter import BaseFilter
|
||||||
|
|
||||||
|
|
||||||
|
class FileFilter(BaseFilter):
|
||||||
|
file_filter_regex = r'file:"(.+?)" ?'
|
||||||
|
|
||||||
|
def __init__(self, entry_key='file'):
|
||||||
|
self.entry_key = entry_key
|
||||||
|
|
||||||
|
def load(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def can_filter(self, raw_query):
|
||||||
|
return re.search(self.file_filter_regex, raw_query) is not None
|
||||||
|
|
||||||
|
def apply(self, raw_query, raw_entries, raw_embeddings):
|
||||||
|
files_to_search = re.findall(self.file_filter_regex, raw_query)
|
||||||
|
if not files_to_search:
|
||||||
|
return raw_query, raw_entries, raw_embeddings
|
||||||
|
|
||||||
|
query = re.sub(self.file_filter_regex, '', raw_query).strip()
|
||||||
|
included_entry_indices = [id for id, entry in enumerate(raw_entries) for search_file in files_to_search if fnmatch.fnmatch(entry[self.entry_key], search_file)]
|
||||||
|
if not included_entry_indices:
|
||||||
|
return query, [], torch.empty(0)
|
||||||
|
|
||||||
|
entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices]
|
||||||
|
embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
|
||||||
|
|
||||||
|
return query, entries, embeddings
|
||||||
101
tests/test_file_filter.py
Normal file
101
tests/test_file_filter.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
# External Packages
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Application Packages
|
||||||
|
from src.search_filter.file_filter import FileFilter
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_file_filter():
|
||||||
|
# Arrange
|
||||||
|
file_filter = FileFilter()
|
||||||
|
embeddings, entries = arrange_content()
|
||||||
|
q_with_no_filter = 'head tail'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
|
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert can_filter == False
|
||||||
|
assert ret_query == 'head tail'
|
||||||
|
assert len(ret_emb) == 4
|
||||||
|
assert ret_entries == entries
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_filter_with_partial_match():
|
||||||
|
# Arrange
|
||||||
|
file_filter = FileFilter()
|
||||||
|
embeddings, entries = arrange_content()
|
||||||
|
q_with_no_filter = 'head file:"*.org" tail'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
|
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert can_filter == True
|
||||||
|
assert ret_query == 'head tail'
|
||||||
|
assert len(ret_emb) == 4
|
||||||
|
assert ret_entries == entries
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_filter_with_non_existent_file():
|
||||||
|
# Arrange
|
||||||
|
file_filter = FileFilter()
|
||||||
|
embeddings, entries = arrange_content()
|
||||||
|
q_with_no_filter = 'head file:"nonexistent.org" tail'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
|
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert can_filter == True
|
||||||
|
assert ret_query == 'head tail'
|
||||||
|
assert len(ret_emb) == 0
|
||||||
|
assert ret_entries == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_file_filter():
|
||||||
|
# Arrange
|
||||||
|
file_filter = FileFilter()
|
||||||
|
embeddings, entries = arrange_content()
|
||||||
|
q_with_no_filter = 'head file:"file1.org" tail'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
|
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert can_filter == True
|
||||||
|
assert ret_query == 'head tail'
|
||||||
|
assert len(ret_emb) == 2
|
||||||
|
assert ret_entries == [entries[0], entries[2]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_file_filter():
|
||||||
|
# Arrange
|
||||||
|
file_filter = FileFilter()
|
||||||
|
embeddings, entries = arrange_content()
|
||||||
|
q_with_no_filter = 'head tail file:"file1.org" file:"file2.org"'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
|
ret_query, ret_entries, ret_emb = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert can_filter == True
|
||||||
|
assert ret_query == 'head tail'
|
||||||
|
assert len(ret_emb) == 4
|
||||||
|
assert ret_entries == entries
|
||||||
|
|
||||||
|
|
||||||
|
def arrange_content():
|
||||||
|
embeddings = torch.randn(4, 10)
|
||||||
|
entries = [
|
||||||
|
{'compiled': '', 'raw': 'First Entry', 'file': 'file1.org'},
|
||||||
|
{'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'},
|
||||||
|
{'compiled': '', 'raw': 'Third Entry', 'file': 'file1.org'},
|
||||||
|
{'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}]
|
||||||
|
|
||||||
|
return embeddings, entries
|
||||||
Reference in New Issue
Block a user