mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Do not pass embeddings as an argument to the filter.apply method
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Tuple
|
from typing import List, Set, Tuple
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
@@ -16,5 +16,5 @@ class BaseFilter(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, query:str, raw_entries:List[str], raw_embeddings: torch.Tensor) -> Tuple[str, List[str], torch.Tensor]:
|
def apply(self, query:str, raw_entries:List[str]) -> Tuple[str, Set[int]]:
|
||||||
pass
|
pass
|
||||||
@@ -35,7 +35,7 @@ class DateFilter(BaseFilter):
|
|||||||
return self.extract_date_range(raw_query) is not None
|
return self.extract_date_range(raw_query) is not None
|
||||||
|
|
||||||
|
|
||||||
def apply(self, query, raw_entries, raw_embeddings):
|
def apply(self, query, raw_entries):
|
||||||
"Find entries containing any dates that fall within date range specified in query"
|
"Find entries containing any dates that fall within date range specified in query"
|
||||||
# extract date range specified in date filter of query
|
# extract date range specified in date filter of query
|
||||||
query_daterange = self.extract_date_range(query)
|
query_daterange = self.extract_date_range(query)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class FileFilter(BaseFilter):
|
|||||||
def can_filter(self, raw_query):
|
def can_filter(self, raw_query):
|
||||||
return re.search(self.file_filter_regex, raw_query) is not None
|
return re.search(self.file_filter_regex, raw_query) is not None
|
||||||
|
|
||||||
def apply(self, raw_query, raw_entries, raw_embeddings):
|
def apply(self, raw_query, raw_entries):
|
||||||
# Extract file filters from raw query
|
# Extract file filters from raw query
|
||||||
start = time.time()
|
start = time.time()
|
||||||
raw_files_to_search = re.findall(self.file_filter_regex, raw_query)
|
raw_files_to_search = re.findall(self.file_filter_regex, raw_query)
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class WordFilter(BaseFilter):
|
|||||||
return len(required_words) != 0 or len(blocked_words) != 0
|
return len(required_words) != 0 or len(blocked_words) != 0
|
||||||
|
|
||||||
|
|
||||||
def apply(self, raw_query, raw_entries, raw_embeddings):
|
def apply(self, raw_query, raw_entries):
|
||||||
"Find entries containing required and not blocked words specified in query"
|
"Find entries containing required and not blocked words specified in query"
|
||||||
# Separate natural query from required, blocked words filters
|
# Separate natural query from required, blocked words filters
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|||||||
@@ -76,11 +76,11 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
|||||||
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
||||||
|
|
||||||
# Filter query, entries and embeddings before semantic search
|
# Filter query, entries and embeddings before semantic search
|
||||||
start = time.time()
|
start_filter = time.time()
|
||||||
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
|
|
||||||
included_entry_indices = set(range(len(entries)))
|
included_entry_indices = set(range(len(entries)))
|
||||||
|
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
|
||||||
for filter in filters_in_query:
|
for filter in filters_in_query:
|
||||||
query, included_entry_indices_by_filter = filter.apply(query, entries, corpus_embeddings)
|
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
||||||
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
||||||
|
|
||||||
# Get entries (and associated embeddings) satisfying all filters
|
# Get entries (and associated embeddings) satisfying all filters
|
||||||
@@ -91,10 +91,10 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
|||||||
entries = [entries[id] for id in included_entry_indices]
|
entries = [entries[id] for id in included_entry_indices]
|
||||||
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
|
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug(f"Keep entries satisfying all filter: {end - start} seconds")
|
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
|
||||||
|
|
||||||
end = time.time()
|
end_filter = time.time()
|
||||||
logger.debug(f"Total Filter Time: {end - start:.3f} seconds")
|
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds")
|
||||||
|
|
||||||
if entries is None or len(entries) == 0:
|
if entries is None or len(entries) == 0:
|
||||||
return [], []
|
return [], []
|
||||||
|
|||||||
@@ -18,32 +18,32 @@ def test_date_filter():
|
|||||||
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
|
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
|
||||||
|
|
||||||
q_with_no_date_filter = 'head tail'
|
q_with_no_date_filter = 'head tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries, embeddings)
|
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == 'head tail'
|
||||||
assert entry_indices == {0, 1, 2}
|
assert entry_indices == {0, 1, 2}
|
||||||
|
|
||||||
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
|
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries, embeddings)
|
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == 'head tail'
|
||||||
assert entry_indices == set()
|
assert entry_indices == set()
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
|
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings)
|
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == 'head tail'
|
||||||
assert entry_indices == {2}
|
assert entry_indices == {2}
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
|
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings)
|
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == 'head tail'
|
||||||
assert entry_indices == {1}
|
assert entry_indices == {1}
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
|
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings)
|
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == 'head tail'
|
||||||
assert entry_indices == {2}
|
assert entry_indices == {2}
|
||||||
|
|
||||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
|
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
|
||||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries, embeddings)
|
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||||
assert ret_query == 'head tail'
|
assert ret_query == 'head tail'
|
||||||
assert entry_indices == {1, 2}
|
assert entry_indices == {1, 2}
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ def test_no_file_filter():
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == False
|
assert can_filter == False
|
||||||
@@ -29,7 +29,7 @@ def test_file_filter_with_non_existent_file():
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
@@ -45,7 +45,7 @@ def test_single_file_filter():
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
@@ -61,7 +61,7 @@ def test_file_filter_with_partial_match():
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
@@ -77,7 +77,7 @@ def test_file_filter_with_regex_match():
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
@@ -93,7 +93,7 @@ def test_multiple_file_filter():
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
# System Packages
|
# System Packages
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from src.utils.config import SearchType
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils.state import model
|
from src.utils.state import model
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ def test_no_word_filter(tmp_path):
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = word_filter.can_filter(q_with_no_filter)
|
can_filter = word_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == False
|
assert can_filter == False
|
||||||
@@ -30,7 +30,7 @@ def test_word_exclude_filter(tmp_path):
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = word_filter.can_filter(q_with_exclude_filter)
|
can_filter = word_filter.can_filter(q_with_exclude_filter)
|
||||||
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)
|
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
@@ -46,7 +46,7 @@ def test_word_include_filter(tmp_path):
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = word_filter.can_filter(query_with_include_filter)
|
can_filter = word_filter.can_filter(query_with_include_filter)
|
||||||
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries.copy(), embeddings)
|
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
@@ -62,7 +62,7 @@ def test_word_include_and_exclude_filter(tmp_path):
|
|||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
|
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
|
||||||
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)
|
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
|
|||||||
Reference in New Issue
Block a user