Time embedding model load for better visibility into app startup time

Loading the embeddings model seems to be taking much longer, even when it is
loaded locally. Use the timer context manager to log embedding and
cross-encoder model load times for better visibility.
Author: Debanjum Singh Solanky
Date:   2024-10-06 15:47:22 -07:00
Parent: 516472a8d5
Commit: bbbdba3093


@@ -13,7 +13,7 @@ from tenacity import (
 )
 from torch import nn
 
-from khoj.utils.helpers import get_device, merge_dicts
+from khoj.utils.helpers import get_device, merge_dicts, timer
 from khoj.utils.rawconfig import SearchResponse
 
 logger = logging.getLogger(__name__)
@@ -37,6 +37,7 @@ class EmbeddingsModel:
         self.model_name = model_name
         self.inference_endpoint = embeddings_inference_endpoint
         self.api_key = embeddings_inference_endpoint_api_key
-        self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
+        with timer(f"Loaded embedding model {self.model_name}", logger):
+            self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
 
     def inference_server_enabled(self) -> bool:
@@ -101,6 +102,7 @@ class CrossEncoderModel:
         self.inference_endpoint = cross_encoder_inference_endpoint
         self.api_key = cross_encoder_inference_endpoint_api_key
         self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
-        self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs)
+        with timer(f"Loaded cross-encoder model {self.model_name}", logger):
+            self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs)
 
     def inference_server_enabled(self) -> bool: