Create BaseEncoder class. Make OpenAI encoder its child. Use for typing

- Set type of all bi_encoders to BaseEncoder

- Make load_model return type Union of CrossEncoder and BaseEncoder
This commit is contained in:
Debanjum Singh Solanky
2023-01-09 18:08:37 -03:00
parent cf7400759b
commit e5254a8e56
4 changed files with 25 additions and 6 deletions

View File

@@ -9,6 +9,7 @@ import torch
# Internal Packages
from src.utils.rawconfig import ConversationProcessorConfig, Entry
from src.search_filter.base_filter import BaseFilter
from src.utils.models import BaseEncoder
class SearchType(str, Enum):
@@ -24,7 +25,7 @@ class ProcessorType(str, Enum):
class TextSearchModel():
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder, filters: list[BaseFilter], top_k):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
@@ -34,7 +35,7 @@ class TextSearchModel():
class ImageSearchModel():
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
self.image_encoder = image_encoder
self.image_names = image_names
self.image_embeddings = image_embeddings

View File

@@ -7,6 +7,12 @@ from collections import OrderedDict
from typing import Optional, Union
import logging
# External Packages
from sentence_transformers import CrossEncoder
# Internal Packages
from src.utils.models import BaseEncoder
def is_none_or_empty(item):
return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == ''
@@ -45,7 +51,7 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def load_model(model_name: str, model_type, model_dir=None, device:str=None):
def load_model(model_name: str, model_type, model_dir=None, device:str=None) -> Union[BaseEncoder, CrossEncoder]:
"Load model from disk or huggingface"
# Construct model path
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None

View File

@@ -1,3 +1,6 @@
# Standard Packages
from abc import ABC, abstractmethod
# External Packages
import openai
import torch
@@ -7,7 +10,15 @@ from tqdm import trange
from src.utils.state import processor_config, config_file
class OpenAI:
class BaseEncoder(ABC):
@abstractmethod
def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
@abstractmethod
def encode(self, entries: list[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
class OpenAI(BaseEncoder):
def __init__(self, model_name, device=None):
self.model_name = model_name
if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
@@ -15,7 +26,7 @@ class OpenAI:
openai.api_key = processor_config.conversation.openai_api_key
self.embedding_dimensions = None
def encode(self, entries: list[str], device=None, **kwargs):
def encode(self, entries, device=None, **kwargs):
embedding_tensors = []
for index in trange(0, len(entries)):