Use SentenceTransformer to disable progress bar when encoding query

The Langchain HuggingFaceEmbeddings wrapper doesn't support disabling
progressbar, not especially for only query but not documents.

This makes the logs noisy with encoding progressbar for each
incremental queries

No features of the Langchain wrapper for SentenceTransformer was
currently being used anyway for now, and we can always switch back to
it if required
This commit is contained in:
Debanjum Singh Solanky
2023-11-04 05:19:50 -07:00
parent dc9946fc03
commit 34b5a86d1d

View File

@@ -1,7 +1,6 @@
from typing import List
from langchain.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer, CrossEncoder
from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse
@@ -10,17 +9,15 @@ from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel:
def __init__(self):
self.model_name = "thenlper/gte-small"
encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True}
self.encode_kwargs = {"normalize_embeddings": True}
model_kwargs = {"device": get_device()}
self.embeddings_model = HuggingFaceEmbeddings(
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
)
self.embeddings_model = SentenceTransformer(self.model_name, **model_kwargs)
def embed_query(self, query):
return self.embeddings_model.embed_query(query)
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
def embed_documents(self, docs):
return self.embeddings_model.embed_documents(docs)
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
class CrossEncoderModel: