Fix, add typing to Filter and TextSearchModel classes

- Changes
  - Fix method signatures of BaseFilter subclasses.
    Else typing information isn't translating to them
  - Explicitly pass `entries: list[Entry]' as arg to `load' method
  - Fix type of `raw_entries' arg to `apply' method
    to list[Entry] from list[str]
  - Rename `raw_entries' arg to `apply' method to `entries'
  - Fix `raw_query' arg used in `apply' method of subclasses to `query'
  - Set type of entries, corpus_embeddings in TextSearchModel

- Verification
  Ran `mypy --config-file .mypy.ini src' to verify typing
This commit is contained in:
Debanjum Singh Solanky
2023-01-09 16:53:18 -03:00
parent d40076fcd6
commit 8498903641
5 changed files with 27 additions and 21 deletions

View File

@@ -3,8 +3,11 @@ from enum import Enum
from dataclasses import dataclass
from pathlib import Path
# External Packages
import torch
# Internal Packages
from src.utils.rawconfig import ConversationProcessorConfig
from src.utils.rawconfig import ConversationProcessorConfig, Entry
from src.search_filter.base_filter import BaseFilter
@@ -21,7 +24,7 @@ class ProcessorType(str, Enum):
class TextSearchModel():
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder