mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Create Abstract Base Class for Filters. Make Word, Date Filter Child of BaseFilter
This commit is contained in:
20
src/search_filter/base_filter.py
Normal file
20
src/search_filter/base_filter.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# Standard Packages
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
# External Packages
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFilter(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def load(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def can_filter(self, raw_query:str) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(self, query:str, raw_entries:List[str], raw_embeddings: torch.Tensor) -> Tuple[str, List[str], torch.Tensor]:
|
||||||
|
pass
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
# Standard Packages
|
# Standard Packages
|
||||||
import re
|
import re
|
||||||
from datetime import timedelta, datetime
|
from datetime import timedelta, datetime
|
||||||
from dateutil.relativedelta import relativedelta, MO
|
from dateutil.relativedelta import relativedelta
|
||||||
from math import inf
|
from math import inf
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
@@ -9,8 +9,11 @@ from copy import deepcopy
|
|||||||
import torch
|
import torch
|
||||||
import dateparser as dtparse
|
import dateparser as dtparse
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
|
from src.search_filter.base_filter import BaseFilter
|
||||||
|
|
||||||
class DateFilter:
|
|
||||||
|
class DateFilter(BaseFilter):
|
||||||
# Date Range Filter Regexes
|
# Date Range Filter Regexes
|
||||||
# Example filter queries:
|
# Example filter queries:
|
||||||
# - dt>="yesterday" dt<"tomorrow"
|
# - dt>="yesterday" dt<"tomorrow"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import logging
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
|
from src.search_filter.base_filter import BaseFilter
|
||||||
from src.utils.helpers import LRU, resolve_absolute_path
|
from src.utils.helpers import LRU, resolve_absolute_path
|
||||||
from src.utils.config import SearchType
|
from src.utils.config import SearchType
|
||||||
|
|
||||||
@@ -15,7 +16,7 @@ from src.utils.config import SearchType
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WordFilter:
|
class WordFilter(BaseFilter):
|
||||||
# Filter Regex
|
# Filter Regex
|
||||||
required_regex = r'\+"(\w+)" ?'
|
required_regex = r'\+"(\w+)" ?'
|
||||||
blocked_regex = r'\-"(\w+)" ?'
|
blocked_regex = r'\-"(\w+)" ?'
|
||||||
|
|||||||
@@ -2,9 +2,11 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils.rawconfig import ConversationProcessorConfig
|
from src.utils.rawconfig import ConversationProcessorConfig
|
||||||
|
from src.search_filter.base_filter import BaseFilter
|
||||||
|
|
||||||
|
|
||||||
class SearchType(str, Enum):
|
class SearchType(str, Enum):
|
||||||
@@ -20,7 +22,7 @@ class ProcessorType(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class TextSearchModel():
|
class TextSearchModel():
|
||||||
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
|
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: List[BaseFilter], top_k):
|
||||||
self.entries = entries
|
self.entries = entries
|
||||||
self.corpus_embeddings = corpus_embeddings
|
self.corpus_embeddings = corpus_embeddings
|
||||||
self.bi_encoder = bi_encoder
|
self.bi_encoder = bi_encoder
|
||||||
|
|||||||
Reference in New Issue
Block a user