diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index 15d03f7f..71af5b7d 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -13,7 +13,7 @@ from tenacity import ( ) from torch import nn -from khoj.utils.helpers import get_device, merge_dicts +from khoj.utils.helpers import get_device, merge_dicts, timer from khoj.utils.rawconfig import SearchResponse logger = logging.getLogger(__name__) @@ -37,7 +37,8 @@ class EmbeddingsModel: self.model_name = model_name self.inference_endpoint = embeddings_inference_endpoint self.api_key = embeddings_inference_endpoint_api_key - self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) + with timer(f"Loaded embedding model {self.model_name}", logger): + self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) def inference_server_enabled(self) -> bool: return self.api_key is not None and self.inference_endpoint is not None @@ -101,7 +102,8 @@ class CrossEncoderModel: self.inference_endpoint = cross_encoder_inference_endpoint self.api_key = cross_encoder_inference_endpoint_api_key self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()}) - self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs) + with timer(f"Loaded cross-encoder model {self.model_name}", logger): + self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs) def inference_server_enabled(self) -> bool: return self.api_key is not None and self.inference_endpoint is not None