Create Abstract Base Class for Filters. Make Word and Date Filters Children of BaseFilter

This commit is contained in:
Debanjum Singh Solanky
2022-09-04 18:05:38 +03:00
parent c9f6200007
commit e4418746f2
4 changed files with 30 additions and 4 deletions

View File

@@ -0,0 +1,20 @@
# Standard Packages
from abc import ABC, abstractmethod
from typing import List, Tuple
# External Packages
import torch
class BaseFilter(ABC):
    """Abstract interface shared by the search filters.

    Declares the three methods a concrete filter must provide: ``load``,
    ``can_filter`` and ``apply``. All three are abstract, so this class
    cannot be instantiated directly; a subclass becomes instantiable only
    once it overrides every one of them.
    """

    @abstractmethod
    def load(self, *args, **kwargs):
        """Prepare the filter for use.

        The signature is deliberately open-ended: positional and keyword
        arguments are accepted as-is and their meaning is left to each
        subclass.
        NOTE(review): presumably this is where a filter builds or loads
        whatever state it needs before ``apply`` is called -- confirm
        against the concrete filters.
        """

    @abstractmethod
    def can_filter(self, raw_query: str) -> bool:
        """Return whether this filter has anything to do for ``raw_query``.

        Args:
            raw_query: The query string to inspect.

        Returns:
            A boolean. NOTE(review): presumably ``True`` when the query
            contains terms this filter recognises -- confirm against the
            concrete filters.
        """

    @abstractmethod
    def apply(
        self,
        query: str,
        raw_entries: List[str],
        raw_embeddings: torch.Tensor,
    ) -> Tuple[str, List[str], torch.Tensor]:
        """Run the filter over a query and its candidate entries.

        Args:
            query: The query string.
            raw_entries: The entries to filter.
            raw_embeddings: A tensor accompanying ``raw_entries``.
                NOTE(review): presumably one embedding per entry, in the
                same order -- shape is not established here, confirm.

        Returns:
            A ``(query, entries, embeddings)`` tuple with the same types
            as the inputs. NOTE(review): presumably the query with the
            filter terms removed plus the entries and embeddings that
            survived filtering -- confirm against the concrete filters.
        """

View File

@@ -1,7 +1,7 @@
# Standard Packages # Standard Packages
import re import re
from datetime import timedelta, datetime from datetime import timedelta, datetime
from dateutil.relativedelta import relativedelta, MO from dateutil.relativedelta import relativedelta
from math import inf from math import inf
from copy import deepcopy from copy import deepcopy
@@ -9,8 +9,11 @@ from copy import deepcopy
import torch import torch
import dateparser as dtparse import dateparser as dtparse
# Internal Packages
from src.search_filter.base_filter import BaseFilter
class DateFilter:
class DateFilter(BaseFilter):
# Date Range Filter Regexes # Date Range Filter Regexes
# Example filter queries: # Example filter queries:
# - dt>="yesterday" dt<"tomorrow" # - dt>="yesterday" dt<"tomorrow"

View File

@@ -8,6 +8,7 @@ import logging
import torch import torch
# Internal Packages # Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU, resolve_absolute_path from src.utils.helpers import LRU, resolve_absolute_path
from src.utils.config import SearchType from src.utils.config import SearchType
@@ -15,7 +16,7 @@ from src.utils.config import SearchType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WordFilter: class WordFilter(BaseFilter):
# Filter Regex # Filter Regex
required_regex = r'\+"(\w+)" ?' required_regex = r'\+"(\w+)" ?'
blocked_regex = r'\-"(\w+)" ?' blocked_regex = r'\-"(\w+)" ?'

View File

@@ -2,9 +2,11 @@
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import List
# Internal Packages # Internal Packages
from src.utils.rawconfig import ConversationProcessorConfig from src.utils.rawconfig import ConversationProcessorConfig
from src.search_filter.base_filter import BaseFilter
class SearchType(str, Enum): class SearchType(str, Enum):
@@ -20,7 +22,7 @@ class ProcessorType(str, Enum):
class TextSearchModel(): class TextSearchModel():
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k): def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: List[BaseFilter], top_k):
self.entries = entries self.entries = entries
self.corpus_embeddings = corpus_embeddings self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder self.bi_encoder = bi_encoder