mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-04 13:20:17 +00:00
Time embedding model load for better visibility into app startup time
Loading the embeddings model, even locally, seems to be taking much longer. Use a timer to get visibility into the embedding and cross-encoder model load times
This commit is contained in:
@@ -13,7 +13,7 @@ from tenacity import (
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from khoj.utils.helpers import get_device, merge_dicts
|
||||
from khoj.utils.helpers import get_device, merge_dicts, timer
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,7 +37,8 @@ class EmbeddingsModel:
|
||||
self.model_name = model_name
|
||||
self.inference_endpoint = embeddings_inference_endpoint
|
||||
self.api_key = embeddings_inference_endpoint_api_key
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
with timer(f"Loaded embedding model {self.model_name}", logger):
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
|
||||
def inference_server_enabled(self) -> bool:
|
||||
return self.api_key is not None and self.inference_endpoint is not None
|
||||
@@ -101,7 +102,8 @@ class CrossEncoderModel:
|
||||
self.inference_endpoint = cross_encoder_inference_endpoint
|
||||
self.api_key = cross_encoder_inference_endpoint_api_key
|
||||
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
|
||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs)
|
||||
with timer(f"Loaded cross-encoder model {self.model_name}", logger):
|
||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs)
|
||||
|
||||
def inference_server_enabled(self) -> bool:
|
||||
return self.api_key is not None and self.inference_endpoint is not None
|
||||
|
||||
Reference in New Issue
Block a user