mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Fix, add typing to Filter and TextSearchModel classes
- Changes
- Fix method signatures of BaseFilter subclasses.
Else typing information isn't translating to them
- Explicitly pass `entries: list[Entry]' as arg to `load' method
- Fix type of `raw_entries' arg to `apply' method
to list[Entry] from list[str]
- Rename `raw_entries' arg to `apply' method to `entries'
- Fix `raw_query' arg used in `apply' method of subclasses to `query'
- Set type of entries, corpus_embeddings in TextSearchModel
- Verification
Ran `mypy --config-file .mypy.ini src' to verify typing
This commit is contained in:
@@ -1,13 +1,16 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
|
from src.utils.rawconfig import Entry
|
||||||
|
|
||||||
|
|
||||||
class BaseFilter(ABC):
|
class BaseFilter(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self, *args, **kwargs): ...
|
def load(self, entries: list[Entry], *args, **kwargs): ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def can_filter(self, raw_query:str) -> bool: ...
|
def can_filter(self, raw_query:str) -> bool: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, query:str, raw_entries:list[str]) -> tuple[str, set[int]]: ...
|
def apply(self, query:str, entries: list[Entry]) -> tuple[str, set[int]]: ...
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class DateFilter(BaseFilter):
|
|||||||
self.cache = LRU()
|
self.cache = LRU()
|
||||||
|
|
||||||
|
|
||||||
def load(self, entries, **_):
|
def load(self, entries, *args, **kwargs):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
for id, entry in enumerate(entries):
|
for id, entry in enumerate(entries):
|
||||||
# Extract dates from entry
|
# Extract dates from entry
|
||||||
@@ -53,7 +53,7 @@ class DateFilter(BaseFilter):
|
|||||||
return self.extract_date_range(raw_query) is not None
|
return self.extract_date_range(raw_query) is not None
|
||||||
|
|
||||||
|
|
||||||
def apply(self, query, raw_entries):
|
def apply(self, query, entries):
|
||||||
"Find entries containing any dates that fall within date range specified in query"
|
"Find entries containing any dates that fall within date range specified in query"
|
||||||
# extract date range specified in date filter of query
|
# extract date range specified in date filter of query
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@@ -63,7 +63,7 @@ class DateFilter(BaseFilter):
|
|||||||
|
|
||||||
# if no date in query, return all entries
|
# if no date in query, return all entries
|
||||||
if query_daterange is None:
|
if query_daterange is None:
|
||||||
return query, set(range(len(raw_entries)))
|
return query, set(range(len(entries)))
|
||||||
|
|
||||||
# remove date range filter from query
|
# remove date range filter from query
|
||||||
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
|
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
|
||||||
@@ -77,7 +77,7 @@ class DateFilter(BaseFilter):
|
|||||||
return query, entries_to_include
|
return query, entries_to_include
|
||||||
|
|
||||||
if not self.date_to_entry_ids:
|
if not self.date_to_entry_ids:
|
||||||
self.load(raw_entries)
|
self.load(entries)
|
||||||
|
|
||||||
# find entries containing any dates that fall with date range specified in query
|
# find entries containing any dates that fall with date range specified in query
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|||||||
@@ -31,12 +31,12 @@ class FileFilter(BaseFilter):
|
|||||||
def can_filter(self, raw_query):
|
def can_filter(self, raw_query):
|
||||||
return re.search(self.file_filter_regex, raw_query) is not None
|
return re.search(self.file_filter_regex, raw_query) is not None
|
||||||
|
|
||||||
def apply(self, raw_query, raw_entries):
|
def apply(self, query, entries):
|
||||||
# Extract file filters from raw query
|
# Extract file filters from raw query
|
||||||
start = time.time()
|
start = time.time()
|
||||||
raw_files_to_search = re.findall(self.file_filter_regex, raw_query)
|
raw_files_to_search = re.findall(self.file_filter_regex, query)
|
||||||
if not raw_files_to_search:
|
if not raw_files_to_search:
|
||||||
return raw_query, set(range(len(raw_entries)))
|
return query, set(range(len(entries)))
|
||||||
|
|
||||||
# Convert simple file filters with no path separator into regex
|
# Convert simple file filters with no path separator into regex
|
||||||
# e.g. "file:notes.org" -> "file:.*notes.org"
|
# e.g. "file:notes.org" -> "file:.*notes.org"
|
||||||
@@ -50,7 +50,7 @@ class FileFilter(BaseFilter):
|
|||||||
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
|
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
|
||||||
|
|
||||||
# Return item from cache if exists
|
# Return item from cache if exists
|
||||||
query = re.sub(self.file_filter_regex, '', raw_query).strip()
|
query = re.sub(self.file_filter_regex, '', query).strip()
|
||||||
cache_key = tuple(files_to_search)
|
cache_key = tuple(files_to_search)
|
||||||
if cache_key in self.cache:
|
if cache_key in self.cache:
|
||||||
logger.info(f"Return file filter results from cache")
|
logger.info(f"Return file filter results from cache")
|
||||||
@@ -58,7 +58,7 @@ class FileFilter(BaseFilter):
|
|||||||
return query, included_entry_indices
|
return query, included_entry_indices
|
||||||
|
|
||||||
if not self.file_to_entry_map:
|
if not self.file_to_entry_map:
|
||||||
self.load(raw_entries, regenerate=False)
|
self.load(entries, regenerate=False)
|
||||||
|
|
||||||
# Mark entries that contain any blocked_words for exclusion
|
# Mark entries that contain any blocked_words for exclusion
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class WordFilter(BaseFilter):
|
|||||||
self.cache = LRU()
|
self.cache = LRU()
|
||||||
|
|
||||||
|
|
||||||
def load(self, entries, regenerate=False):
|
def load(self, entries, *args, **kwargs):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
self.cache = {} # Clear cache on filter (re-)load
|
self.cache = {} # Clear cache on filter (re-)load
|
||||||
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
|
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
|
||||||
@@ -47,20 +47,20 @@ class WordFilter(BaseFilter):
|
|||||||
return len(required_words) != 0 or len(blocked_words) != 0
|
return len(required_words) != 0 or len(blocked_words) != 0
|
||||||
|
|
||||||
|
|
||||||
def apply(self, raw_query, raw_entries):
|
def apply(self, query, entries):
|
||||||
"Find entries containing required and not blocked words specified in query"
|
"Find entries containing required and not blocked words specified in query"
|
||||||
# Separate natural query from required, blocked words filters
|
# Separate natural query from required, blocked words filters
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
|
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
|
||||||
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
|
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
|
||||||
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
|
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
|
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
|
||||||
|
|
||||||
if len(required_words) == 0 and len(blocked_words) == 0:
|
if len(required_words) == 0 and len(blocked_words) == 0:
|
||||||
return query, set(range(len(raw_entries)))
|
return query, set(range(len(entries)))
|
||||||
|
|
||||||
# Return item from cache if exists
|
# Return item from cache if exists
|
||||||
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
|
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
|
||||||
@@ -70,12 +70,12 @@ class WordFilter(BaseFilter):
|
|||||||
return query, included_entry_indices
|
return query, included_entry_indices
|
||||||
|
|
||||||
if not self.word_to_entry_index:
|
if not self.word_to_entry_index:
|
||||||
self.load(raw_entries, regenerate=False)
|
self.load(entries, regenerate=False)
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
# mark entries that contain all required_words for inclusion
|
# mark entries that contain all required_words for inclusion
|
||||||
entries_with_all_required_words = set(range(len(raw_entries)))
|
entries_with_all_required_words = set(range(len(entries)))
|
||||||
if len(required_words) > 0:
|
if len(required_words) > 0:
|
||||||
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
|
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,11 @@ from enum import Enum
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
# External Packages
|
||||||
|
import torch
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils.rawconfig import ConversationProcessorConfig
|
from src.utils.rawconfig import ConversationProcessorConfig, Entry
|
||||||
from src.search_filter.base_filter import BaseFilter
|
from src.search_filter.base_filter import BaseFilter
|
||||||
|
|
||||||
|
|
||||||
@@ -21,7 +24,7 @@ class ProcessorType(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class TextSearchModel():
|
class TextSearchModel():
|
||||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
|
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
|
||||||
self.entries = entries
|
self.entries = entries
|
||||||
self.corpus_embeddings = corpus_embeddings
|
self.corpus_embeddings = corpus_embeddings
|
||||||
self.bi_encoder = bi_encoder
|
self.bi_encoder = bi_encoder
|
||||||
|
|||||||
Reference in New Issue
Block a user