mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 05:39:12 +00:00
Load models, corpus embeddings onto GPU device for text search, if available
- Pass device to load models onto from app state. - SentenceTransformer models accept device to load models onto during initialization - Pass device to load corpus embeddings onto from app state
This commit is contained in:
@@ -9,6 +9,7 @@ import torch
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||
|
||||
# Internal Packages
|
||||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
|
||||
@@ -32,13 +33,15 @@ def initialize_model(search_config: TextSearchConfig):
|
||||
bi_encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.encoder,
|
||||
model_type = SentenceTransformer)
|
||||
model_type = SentenceTransformer,
|
||||
device=f'{state.device}')
|
||||
|
||||
# The cross-encoder re-ranks the results to improve quality
|
||||
cross_encoder = load_model(
|
||||
model_dir = search_config.model_directory,
|
||||
model_name = search_config.cross_encoder,
|
||||
model_type = CrossEncoder)
|
||||
model_type = CrossEncoder,
|
||||
device=f'{state.device}')
|
||||
|
||||
return bi_encoder, cross_encoder, top_k
|
||||
|
||||
@@ -54,13 +57,12 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
|
||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||
# Load pre-computed embeddings from file if exists
|
||||
if embeddings_file.exists() and not regenerate:
|
||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file))
|
||||
corpus_embeddings = torch.load(get_absolute_path(embeddings_file), map_location=device)
|
||||
if verbose > 0:
|
||||
print(f"Loaded embeddings from {embeddings_file}")
|
||||
|
||||
else: # Else compute the corpus_embeddings from scratch, which can take a while
|
||||
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, show_progress_bar=True)
|
||||
corpus_embeddings.to(device)
|
||||
corpus_embeddings = bi_encoder.encode([entry['compiled'] for entry in entries], convert_to_tensor=True, device=device, show_progress_bar=True)
|
||||
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
|
||||
torch.save(corpus_embeddings, embeddings_file)
|
||||
if verbose > 0:
|
||||
@@ -99,8 +101,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cp
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
start = time.time()
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
|
||||
question_embedding.to(device)
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
|
||||
Reference in New Issue
Block a user