Load models, corpus embeddings onto GPU device for text search, if available

- Pass device to load models onto from app state.
  - SentenceTransformer models accept device to load models onto during initialization
- Pass device to load corpus embeddings onto from app state
This commit is contained in:
Debanjum Singh Solanky
2022-08-20 13:18:31 +03:00
parent 7fe3e844d2
commit 7de9c58a1c
2 changed files with 11 additions and 10 deletions

View File

@@ -9,6 +9,7 @@ import torch
 from sentence_transformers import SentenceTransformer, CrossEncoder, util

 # Internal Packages
+from src.utils import state
 from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
 from src.utils.config import TextSearchModel
 from src.utils.rawconfig import TextSearchConfig, TextContentConfig
@@ -32,13 +33,15 @@ def initialize_model(search_config: TextSearchConfig):
     bi_encoder = load_model(
         model_dir = search_config.model_directory,
         model_name = search_config.encoder,
-        model_type = SentenceTransformer)
+        model_type = SentenceTransformer,
+        device=f'{state.device}')

     # The cross-encoder re-ranks the results to improve quality
     cross_encoder = load_model(
         model_dir = search_config.model_directory,
         model_name = search_config.cross_encoder,
-        model_type = CrossEncoder)
+        model_type = CrossEncoder,
+        device=f'{state.device}')

     return bi_encoder, cross_encoder, top_k
@@ -54,13 +57,12 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
     # Load pre-computed embeddings from file if exists
     if embeddings_file.exists() and not regenerate:
-        corpus_embeddings = torch.load(get_absolute_path(embeddings_file))
+        corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=device)
         if verbose > 0:
             print(f"Loaded embeddings from {embeddings_file}")

     else:  # Else compute the corpus_embeddings from scratch, which can take a while
-        corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
-        corpus_embeddings.to(device)
+        corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=device, show_progress_bar=True)
         corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
         torch.save(corpus_embeddings, embeddings_file)
         if verbose > 0:
@@ -99,8 +101,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
     # Encode the query using the bi-encoder
     start = time.time()
-    question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
-    question_embedding.to(device)
+    question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=device)
     question_embedding = util.normalize_embeddings(question_embedding)
     end = time.time()
     if verbose > 1:

View File

@@ -41,17 +41,17 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
     return merged_dict

-def load_model(model_name, model_dir, model_type):
+def load_model(model_name, model_dir, model_type, device:str=None):
     "Load model from disk or huggingface"
     # Construct model path
     model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None

     # Load model from model_path if it exists there
     if model_path is not None and resolve_absolute_path(model_path).exists():
-        model = model_type(get_absolute_path(model_path))
+        model = model_type(get_absolute_path(model_path), device=device)
     # Else load the model from the model_name
     else:
-        model = model_type(model_name)
+        model = model_type(model_name, device=device)
         if model_path is not None:
             model.save(model_path)