Fix, add typing to Filter and TextSearchModel classes

- Changes
  - Fix method signatures of BaseFilter subclasses.
    Else typing information isn't translating to them
  - Explicitly pass `entries: list[Entry]' as arg to `load' method
  - Fix type of `raw_entries' arg to `apply' method
    to list[Entry] from list[str]
  - Rename `raw_entries' arg to `apply' method to `entries'
  - Fix `raw_query' arg used in `apply' method of subclasses to `query'
  - Set type of entries, corpus_embeddings in TextSearchModel

- Verification
  Ran `mypy --config-file .mypy.ini src' to verify typing
This commit is contained in:
Debanjum Singh Solanky
2023-01-09 16:53:18 -03:00
parent d40076fcd6
commit 8498903641
5 changed files with 27 additions and 21 deletions

View File

@@ -1,13 +1,16 @@
# Standard Packages # Standard Packages
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
# Internal Packages
from src.utils.rawconfig import Entry
class BaseFilter(ABC): class BaseFilter(ABC):
@abstractmethod @abstractmethod
def load(self, *args, **kwargs): ... def load(self, entries: list[Entry], *args, **kwargs): ...
@abstractmethod @abstractmethod
def can_filter(self, raw_query:str) -> bool: ... def can_filter(self, raw_query:str) -> bool: ...
@abstractmethod @abstractmethod
def apply(self, query:str, raw_entries:list[str]) -> tuple[str, set[int]]: ... def apply(self, query:str, entries: list[Entry]) -> tuple[str, set[int]]: ...

View File

@@ -33,7 +33,7 @@ class DateFilter(BaseFilter):
self.cache = LRU() self.cache = LRU()
def load(self, entries, **_): def load(self, entries, *args, **kwargs):
start = time.time() start = time.time()
for id, entry in enumerate(entries): for id, entry in enumerate(entries):
# Extract dates from entry # Extract dates from entry
@@ -53,7 +53,7 @@ class DateFilter(BaseFilter):
return self.extract_date_range(raw_query) is not None return self.extract_date_range(raw_query) is not None
def apply(self, query, raw_entries): def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query" "Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query # extract date range specified in date filter of query
start = time.time() start = time.time()
@@ -63,7 +63,7 @@ class DateFilter(BaseFilter):
# if no date in query, return all entries # if no date in query, return all entries
if query_daterange is None: if query_daterange is None:
return query, set(range(len(raw_entries))) return query, set(range(len(entries)))
# remove date range filter from query # remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(rf'\s+{self.date_regex}', ' ', query)
@@ -77,7 +77,7 @@ class DateFilter(BaseFilter):
return query, entries_to_include return query, entries_to_include
if not self.date_to_entry_ids: if not self.date_to_entry_ids:
self.load(raw_entries) self.load(entries)
# find entries containing any dates that fall with date range specified in query # find entries containing any dates that fall with date range specified in query
start = time.time() start = time.time()

View File

@@ -31,12 +31,12 @@ class FileFilter(BaseFilter):
def can_filter(self, raw_query): def can_filter(self, raw_query):
return re.search(self.file_filter_regex, raw_query) is not None return re.search(self.file_filter_regex, raw_query) is not None
def apply(self, raw_query, raw_entries): def apply(self, query, entries):
# Extract file filters from raw query # Extract file filters from raw query
start = time.time() start = time.time()
raw_files_to_search = re.findall(self.file_filter_regex, raw_query) raw_files_to_search = re.findall(self.file_filter_regex, query)
if not raw_files_to_search: if not raw_files_to_search:
return raw_query, set(range(len(raw_entries))) return query, set(range(len(entries)))
# Convert simple file filters with no path separator into regex # Convert simple file filters with no path separator into regex
# e.g. "file:notes.org" -> "file:.*notes.org" # e.g. "file:notes.org" -> "file:.*notes.org"
@@ -50,7 +50,7 @@ class FileFilter(BaseFilter):
logger.debug(f"Extract files_to_search from query: {end - start} seconds") logger.debug(f"Extract files_to_search from query: {end - start} seconds")
# Return item from cache if exists # Return item from cache if exists
query = re.sub(self.file_filter_regex, '', raw_query).strip() query = re.sub(self.file_filter_regex, '', query).strip()
cache_key = tuple(files_to_search) cache_key = tuple(files_to_search)
if cache_key in self.cache: if cache_key in self.cache:
logger.info(f"Return file filter results from cache") logger.info(f"Return file filter results from cache")
@@ -58,7 +58,7 @@ class FileFilter(BaseFilter):
return query, included_entry_indices return query, included_entry_indices
if not self.file_to_entry_map: if not self.file_to_entry_map:
self.load(raw_entries, regenerate=False) self.load(entries, regenerate=False)
# Mark entries that contain any blocked_words for exclusion # Mark entries that contain any blocked_words for exclusion
start = time.time() start = time.time()

View File

@@ -23,7 +23,7 @@ class WordFilter(BaseFilter):
self.cache = LRU() self.cache = LRU()
def load(self, entries, regenerate=False): def load(self, entries, *args, **kwargs):
start = time.time() start = time.time()
self.cache = {} # Clear cache on filter (re-)load self.cache = {} # Clear cache on filter (re-)load
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
@@ -47,20 +47,20 @@ class WordFilter(BaseFilter):
return len(required_words) != 0 or len(blocked_words) != 0 return len(required_words) != 0 or len(blocked_words) != 0
def apply(self, raw_query, raw_entries): def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query" "Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters # Separate natural query from required, blocked words filters
start = time.time() start = time.time()
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)]) required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)]) blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip() query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
end = time.time() end = time.time()
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
if len(required_words) == 0 and len(blocked_words) == 0: if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(raw_entries))) return query, set(range(len(entries)))
# Return item from cache if exists # Return item from cache if exists
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
@@ -70,12 +70,12 @@ class WordFilter(BaseFilter):
return query, included_entry_indices return query, included_entry_indices
if not self.word_to_entry_index: if not self.word_to_entry_index:
self.load(raw_entries, regenerate=False) self.load(entries, regenerate=False)
start = time.time() start = time.time()
# mark entries that contain all required_words for inclusion # mark entries that contain all required_words for inclusion
entries_with_all_required_words = set(range(len(raw_entries))) entries_with_all_required_words = set(range(len(entries)))
if len(required_words) > 0: if len(required_words) > 0:
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])

View File

@@ -3,8 +3,11 @@ from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
# External Packages
import torch
# Internal Packages # Internal Packages
from src.utils.rawconfig import ConversationProcessorConfig from src.utils.rawconfig import ConversationProcessorConfig, Entry
from src.search_filter.base_filter import BaseFilter from src.search_filter.base_filter import BaseFilter
@@ -21,7 +24,7 @@ class ProcessorType(str, Enum):
class TextSearchModel(): class TextSearchModel():
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k): def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
self.entries = entries self.entries = entries
self.corpus_embeddings = corpus_embeddings self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder self.bi_encoder = bi_encoder