diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py
index 3ac52745..a7d1b7dd 100644
--- a/src/search_type/text_search.py
+++ b/src/search_type/text_search.py
@@ -13,6 +13,7 @@ from src.search_filter.base_filter import BaseFilter
 from src.utils import state
 from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
 from src.utils.config import TextSearchModel
+from src.utils.models import BaseEncoder
 from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
 from src.utils.jsonl import load_jsonl
 
@@ -56,7 +57,7 @@ def extract_entries(jsonl_file) -> list[Entry]:
     return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
 
 
-def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
+def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder: BaseEncoder, embeddings_file, regenerate=False):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
     new_entries = []
     # Load pre-computed embeddings from file if exists and update them if required
diff --git a/src/utils/config.py b/src/utils/config.py
index 8ca9e526..ee999ba6 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -9,6 +9,7 @@ import torch
 # Internal Packages
 from src.utils.rawconfig import ConversationProcessorConfig, Entry
 from src.search_filter.base_filter import BaseFilter
+from src.utils.models import BaseEncoder
 
 
 class SearchType(str, Enum):
@@ -24,7 +25,7 @@ class ProcessorType(str, Enum):
 
 
 class TextSearchModel():
-    def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder, cross_encoder, filters: list[BaseFilter], top_k):
+    def __init__(self, entries: list[Entry], corpus_embeddings: torch.Tensor, bi_encoder: BaseEncoder, cross_encoder, filters: list[BaseFilter], top_k):
         self.entries = entries
         self.corpus_embeddings = corpus_embeddings
         self.bi_encoder = bi_encoder
@@ -34,7 +35,7 @@ class TextSearchModel():
 
 
 class ImageSearchModel():
-    def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder):
+    def __init__(self, image_names, image_embeddings, image_metadata_embeddings, image_encoder: BaseEncoder):
         self.image_encoder = image_encoder
         self.image_names = image_names
         self.image_embeddings = image_embeddings
diff --git a/src/utils/helpers.py b/src/utils/helpers.py
index 38181ad6..4dd6a78e 100644
--- a/src/utils/helpers.py
+++ b/src/utils/helpers.py
@@ -7,6 +7,12 @@
 from collections import OrderedDict
 from typing import Optional, Union
 import logging
+# External Packages
+from sentence_transformers import CrossEncoder
+
+# Internal Packages
+from src.utils.models import BaseEncoder
+
 
 def is_none_or_empty(item):
     return item == None or (hasattr(item, '__iter__') and len(item) == 0) or item == ''
@@ -45,7 +51,7 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
     return merged_dict
 
 
-def load_model(model_name: str, model_type, model_dir=None, device:str=None):
+def load_model(model_name: str, model_type, model_dir=None, device:str=None) -> Union[BaseEncoder, CrossEncoder]:
     "Load model from disk or huggingface"
     # Construct model path
     model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
diff --git a/src/utils/models.py b/src/utils/models.py
index 9ba204b7..43e02639 100644
--- a/src/utils/models.py
+++ b/src/utils/models.py
@@ -1,3 +1,6 @@
+# Standard Packages
+from abc import ABC, abstractmethod
+
 # External Packages
 import openai
 import torch
@@ -7,7 +10,15 @@ from tqdm import trange
 from src.utils.state import processor_config, config_file
 
 
-class OpenAI:
+class BaseEncoder(ABC):
+    @abstractmethod
+    def __init__(self, model_name: str, device: torch.device=None, **kwargs): ...
+
+    @abstractmethod
+    def encode(self, entries: list[str], device:torch.device=None, **kwargs) -> torch.Tensor: ...
+
+
+class OpenAI(BaseEncoder):
     def __init__(self, model_name, device=None):
         self.model_name = model_name
         if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
@@ -15,7 +26,7 @@ class OpenAI:
         openai.api_key = processor_config.conversation.openai_api_key
         self.embedding_dimensions = None
 
-    def encode(self, entries: list[str], device=None, **kwargs):
+    def encode(self, entries, device=None, **kwargs):
         embedding_tensors = []
         for index in trange(0, len(entries)):