Use SentenceTransformer to disable progress bar when encoding query

The Langchain HuggingFaceEmbeddings wrapper doesn't support disabling
progressbar, not especially for only query but not documents.

This makes the logs noisy with encoding progressbar for each
incremental queries

No features of the Langchain wrapper for SentenceTransformer was
currently being used anyway for now, and we can always switch back to
it if required
This commit is contained in:
Debanjum Singh Solanky
2023-11-04 05:19:50 -07:00
parent dc9946fc03
commit 34b5a86d1d

View File

@@ -1,7 +1,6 @@
from typing import List from typing import List
from langchain.embeddings import HuggingFaceEmbeddings from sentence_transformers import SentenceTransformer, CrossEncoder
from sentence_transformers import CrossEncoder
from khoj.utils.helpers import get_device from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse from khoj.utils.rawconfig import SearchResponse
@@ -10,17 +9,15 @@ from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel: class EmbeddingsModel:
def __init__(self): def __init__(self):
self.model_name = "thenlper/gte-small" self.model_name = "thenlper/gte-small"
encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True} self.encode_kwargs = {"normalize_embeddings": True}
model_kwargs = {"device": get_device()} model_kwargs = {"device": get_device()}
self.embeddings_model = HuggingFaceEmbeddings( self.embeddings_model = SentenceTransformer(self.model_name, **model_kwargs)
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
)
def embed_query(self, query): def embed_query(self, query):
return self.embeddings_model.embed_query(query) return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
def embed_documents(self, docs): def embed_documents(self, docs):
return self.embeddings_model.embed_documents(docs) return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
class CrossEncoderModel: class CrossEncoderModel: