Create BaseEncoder class. Make OpenAI encoder its child. Use for typing

- Set type of all bi_encoders to BaseEncoder

- Make load_model return type Union of CrossEncoder and BaseEncoder
This commit is contained in:
Debanjum Singh Solanky
2023-01-09 18:08:37 -03:00
parent cf7400759b
commit e5254a8e56
4 changed files with 25 additions and 6 deletions

View File

@@ -13,6 +13,7 @@ from src.search_filter.base_filter import BaseFilter
from src.utils import state
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
from src.utils.config import TextSearchModel
+from src.utils.models import BaseEncoder
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
from src.utils.jsonl import load_jsonl
@@ -56,7 +57,7 @@ def extract_entries(jsonl_file) -> list[Entry]:
    return list(map(Entry.from_dict, load_jsonl(jsonl_file)))

-def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
+def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file, regenerate=False):
    "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
    new_entries = []
    # Load pre-computed embeddings from file if exists and update them if required

View File

@@ -9,6 +9,7 @@ import torch
# Internal Packages
from src.utils.rawconfig import ConversationProcessorConfig, Entry
from src.search_filter.base_filter import BaseFilter
+from src.utils.models import BaseEncoder

class SearchType(str, Enum):
@@ -24,7 +25,7 @@ class ProcessorType(str, Enum):
class TextSearchModel():
-    def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
+    def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder, filters: list[BaseFilter], top_k):
        self.entries = entries
        self.corpus_embeddings = corpus_embeddings
        self.bi_encoder = bi_encoder
@@ -34,7 +35,7 @@ class TextSearchModel():
class ImageSearchModel():
-    def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
+    def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
        self.image_encoder = image_encoder
        self.image_names = image_names
        self.image_embeddings = image_embeddings

View File

@@ -7,6 +7,12 @@ from collections import OrderedDict
from typing import Optional, Union
import logging

+# External Packages
+from sentence_transformers import CrossEncoder
+
+# Internal Packages
+from src.utils.models import BaseEncoder

def is_none_or_empty(item):
    return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == ''
@@ -45,7 +51,7 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
    return merged_dict

-def load_model(model_name: str, model_type, model_dir=None, device:str=None):
+def load_model(model_name: str, model_type, model_dir=None, device:str=None) -> Union[BaseEncoder, CrossEncoder]:
    "Load model from disk or huggingface"
    # Construct model path
    model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None

View File

@@ -1,3 +1,6 @@
+# Standard Packages
+from abc import ABC, abstractmethod
+
# External Packages
import openai
import torch
@@ -7,7 +10,15 @@ from tqdm import trange
from src.utils.state import processor_config, config_file

-class OpenAI:
+class BaseEncoder(ABC):
+    @abstractmethod
+    def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
+
+    @abstractmethod
+    def encode(self, entries: list[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
+
+class OpenAI(BaseEncoder):
    def __init__(self, model_name, device=None):
        self.model_name = model_name
        if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
@@ -15,7 +26,7 @@ class OpenAI:
        openai.api_key = processor_config.conversation.openai_api_key
        self.embedding_dimensions = None

-    def encode(self, entries: list[str], device=None, **kwargs):
+    def encode(self, entries, device=None, **kwargs):
        embedding_tensors = []
        for index in trange(0, len(entries)):