Mirror of https://github.com/khoaliber/khoj.git, synced 2026-03-04 21:29:12 +00:00
Load models, corpus embeddings onto GPU device for text search, if available
- Pass the device to load models onto from app state.
- SentenceTransformer models accept a device to load onto during initialization.
- Pass the device to load corpus embeddings onto from app state.
This commit is contained in:
@@ -41,17 +41,17 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
|
||||
return merged_dict
|
||||
|
||||
|
||||
def load_model(model_name, model_dir, model_type):
|
||||
def load_model(model_name, model_dir, model_type, device:str=None):
|
||||
"Load model from disk or huggingface"
|
||||
# Construct model path
|
||||
model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None
|
||||
|
||||
# Load model from model_path if it exists there
|
||||
if model_path is not None and resolve_absolute_path(model_path).exists():
|
||||
model = model_type(get_absolute_path(model_path))
|
||||
model = model_type(get_absolute_path(model_path), device=device)
|
||||
# Else load the model from the model_name
|
||||
else:
|
||||
model = model_type(model_name)
|
||||
model = model_type(model_name, device=device)
|
||||
if model_path is not None:
|
||||
model.save(model_path)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user