Time embedding model load for better visibility into app startup time

Loading the embeddings model, even locally, seems to be taking much
longer. Use a timer to gain visibility into the embedding and
cross-encoder model load times.
This commit is contained in:
Debanjum Singh Solanky
2024-10-06 15:47:22 -07:00
parent 516472a8d5
commit bbbdba3093

View File

@@ -13,7 +13,7 @@ from tenacity import (
)
from torch import nn
from khoj.utils.helpers import get_device, merge_dicts
from khoj.utils.helpers import get_device, merge_dicts, timer
from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__)
@@ -37,7 +37,8 @@ class EmbeddingsModel:
self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
with timer(f"Loaded embedding model {self.model_name}", logger):
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def inference_server_enabled(self) -> bool:
return self.api_key is not None and self.inference_endpoint is not None
@@ -101,7 +102,8 @@ class CrossEncoderModel:
self.inference_endpoint = cross_encoder_inference_endpoint
self.api_key = cross_encoder_inference_endpoint_api_key
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs)
with timer(f"Loaded cross-encoder model {self.model_name}", logger):
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs)
def inference_server_enabled(self) -> bool:
return self.api_key is not None and self.inference_endpoint is not None