Load models and corpus embeddings onto GPU device for text search, if available

- Pass the device from app state when loading models
  - SentenceTransformer models accept a device argument at initialization
- Pass the device from app state when loading corpus embeddings
This commit is contained in:
Debanjum Singh Solanky
2022-08-20 13:18:31 +03:00
parent 7fe3e844d2
commit 7de9c58a1c
2 changed files with 11 additions and 10 deletions

View File

@@ -41,17 +41,17 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def load_model(model_name, model_dir, model_type):
def load_model(model_name, model_dir, model_type, device:str=None):
"Load model from disk or huggingface"
# Construct model path
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
# Load model from model_path if it exists there
if model_path is not None and resolve_absolute_path(model_path).exists():
model = model_type(get_absolute_path(model_path))
model = model_type(get_absolute_path(model_path), device=device)
# Else load the model from the model_name
else:
model = model_type(model_name)
model = model_type(model_name, device=device)
if model_path is not None:
model.save(model_path)