From 34b5a86d1d5b8daa228945603fe41e3c537dccb4 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 4 Nov 2023 05:19:50 -0700 Subject: [PATCH] Use SentenceTransformer to disable progress bar when encoding query The Langchain HuggingFaceEmbeddings wrapper doesn't support disabling the progress bar, and in particular cannot disable it for queries only while keeping it for documents. This makes the logs noisy, with an encoding progress bar printed for each incremental query. No features of the Langchain wrapper for SentenceTransformer were being used anyway, and we can always switch back to it if required. --- src/khoj/processor/embeddings.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index fbcddb67..fcd88d80 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -1,7 +1,6 @@ from typing import List -from langchain.embeddings import HuggingFaceEmbeddings -from sentence_transformers import CrossEncoder +from sentence_transformers import SentenceTransformer, CrossEncoder from khoj.utils.helpers import get_device from khoj.utils.rawconfig import SearchResponse @@ -10,17 +9,15 @@ from khoj.utils.rawconfig import SearchResponse class EmbeddingsModel: def __init__(self): self.model_name = "thenlper/gte-small" - encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True} + self.encode_kwargs = {"normalize_embeddings": True} model_kwargs = {"device": get_device()} - self.embeddings_model = HuggingFaceEmbeddings( - model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs - ) + self.embeddings_model = SentenceTransformer(self.model_name, **model_kwargs) def embed_query(self, query): - return self.embeddings_model.embed_query(query) + return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] def embed_documents(self, docs): - return self.embeddings_model.embed_documents(docs) + return 
self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() class CrossEncoderModel: