diff --git a/src/search_filter/base_filter.py b/src/search_filter/base_filter.py
new file mode 100644
index 00000000..dc079b45
--- /dev/null
+++ b/src/search_filter/base_filter.py
@@ -0,0 +1,27 @@
+# Standard Packages
+from abc import ABC, abstractmethod
+from typing import List, Tuple
+
+# External Packages
+import torch
+
+
+class BaseFilter(ABC):
+    """Contract shared by search filters such as DateFilter and WordFilter."""
+
+    @abstractmethod
+    def load(self, *args, **kwargs):
+        """Prepare the filter for use; arguments are implementation-specific."""
+
+    @abstractmethod
+    def can_filter(self, raw_query: str) -> bool:
+        """Return whether this filter can handle raw_query."""
+
+    @abstractmethod
+    def apply(self, query: str, raw_entries: List[str], raw_embeddings: torch.Tensor) -> Tuple[str, List[str], torch.Tensor]:
+        """Run the filter and return a (query, entries, embeddings) tuple.
+
+        NOTE(review): presumably the returned query has the filter terms
+        removed and the entries/embeddings are narrowed to the matches;
+        confirm against the concrete implementations.
+        """
diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py
index cab47cbb..54a8b625 100644
--- a/src/search_filter/date_filter.py
+++ b/src/search_filter/date_filter.py
@@ -1,7 +1,7 @@
 # Standard Packages
 import re
 from datetime import timedelta, datetime
-from dateutil.relativedelta import relativedelta, MO
+from dateutil.relativedelta import relativedelta
 from math import inf
 from copy import deepcopy
 
@@ -9,8 +9,11 @@ from copy import deepcopy
 import torch
 import dateparser as dtparse
 
+# Internal Packages
+from src.search_filter.base_filter import BaseFilter
 
-class DateFilter:
+
+class DateFilter(BaseFilter):
     # Date Range Filter Regexes
     # Example filter queries:
     # - dt>="yesterday" dt<"tomorrow"
diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py
index f47ae6b7..9f46edd2 100644
--- a/src/search_filter/word_filter.py
+++ b/src/search_filter/word_filter.py
@@ -8,6 +8,7 @@ import logging
 import torch
 
 # Internal Packages
+from src.search_filter.base_filter import BaseFilter
 from src.utils.helpers import LRU, resolve_absolute_path
 from src.utils.config import SearchType
 
@@ -15,7 +16,7 @@ from src.utils.config import SearchType
 logger = logging.getLogger(__name__)
 
 
-class WordFilter:
+class WordFilter(BaseFilter):
     # Filter Regex
     required_regex = r'\+"(\w+)" ?'
     blocked_regex = r'\-"(\w+)" ?'
diff --git a/src/utils/config.py b/src/utils/config.py
index a4de6b81..c163e22f 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -2,9 +2,11 @@
 from enum import Enum
 from dataclasses import dataclass
 from pathlib import Path
+from typing import List
 
 # Internal Packages
 from src.utils.rawconfig import ConversationProcessorConfig
+from src.search_filter.base_filter import BaseFilter
 
 
 class SearchType(str, Enum):
@@ -20,7 +22,7 @@ class ProcessorType(str, Enum):
 
 
 class TextSearchModel():
-    def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
+    def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: List[BaseFilter], top_k):
         self.entries = entries
         self.corpus_embeddings = corpus_embeddings
         self.bi_encoder = bi_encoder