diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index c446f31c..c53048b6 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -9,6 +9,7 @@ import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util # Internal Packages +from src.utils import state from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model from src.utils.config import TextSearchModel from src.utils.rawconfig import TextSearchConfig, TextContentConfig @@ -32,13 +33,15 @@ def initialize_model(search_config: TextSearchConfig): bi_encoder = load_model( model_dir = search_config.model_directory, model_name = search_config.encoder, - model_type = SentenceTransformer) + model_type = SentenceTransformer, + device=f'{state.device}') # The cross-encoder re-ranks the results to improve quality cross_encoder = load_model( model_dir = search_config.model_directory, model_name = search_config.cross_encoder, - model_type = CrossEncoder) + model_type = CrossEncoder, + device=f'{state.device}') return bi_encoder, cross_encoder, top_k @@ -54,13 +57,12 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" # Load pre-computed embeddings from file if exists if embeddings_file.exists() and not regenerate: - corpus_embeddings = torch.load(get_absolute_path(embeddings_file)) + corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=device) if verbose > 0: print(f"Loaded embeddings from {embeddings_file}") else: # Else compute the corpus_embeddings from scratch, which can take a while - corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, show_progress_bar=True) - corpus_embeddings.to(device) + corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=device, show_progress_bar=True) corpus_embeddings = util.normalize_embeddings(corpus_embeddings) torch.save(corpus_embeddings, embeddings_file) if verbose > 0: @@ -99,8 +101,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp # Encode the query using the bi-encoder start = time.time() - question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True) - question_embedding.to(device) + question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=device) question_embedding = util.normalize_embeddings(question_embedding) end = time.time() if verbose > 1: diff --git a/src/utils/helpers.py b/src/utils/helpers.py index e77e656d..66e9d8fc 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -41,17 +41,17 @@ def merge_dicts(priority_dict: dict, default_dict: dict): return merged_dict -def load_model(model_name, model_dir, model_type): +def load_model(model_name, model_dir, model_type, device:str=None): "Load model from disk or huggingface" # Construct model path model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None # Load model from model_path if it exists there if model_path is not None and resolve_absolute_path(model_path).exists(): - model = model_type(get_absolute_path(model_path)) + model = model_type(get_absolute_path(model_path), device=device) # Else load the model from the model_name else: - model = model_type(model_name) + model = model_type(model_name, device=device) if model_path is not None: model.save(model_path)