mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 05:39:11 +00:00
Use SentenceTransformer to disable progress bar when encoding query
The Langchain HuggingFaceEmbeddings wrapper doesn't support disabling the progress bar, and in particular it can't disable it for queries only while keeping it for documents. This makes the logs noisy, with an encoding progress bar printed for every incremental query. No features of the Langchain wrapper for SentenceTransformer were being used anyway, and we can always switch back to it if required.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from sentence_transformers import CrossEncoder
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
|
||||
from khoj.utils.helpers import get_device
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
@@ -10,17 +9,15 @@ from khoj.utils.rawconfig import SearchResponse
|
||||
class EmbeddingsModel:
|
||||
def __init__(self):
|
||||
self.model_name = "thenlper/gte-small"
|
||||
encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True}
|
||||
self.encode_kwargs = {"normalize_embeddings": True}
|
||||
model_kwargs = {"device": get_device()}
|
||||
self.embeddings_model = HuggingFaceEmbeddings(
|
||||
model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs
|
||||
)
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **model_kwargs)
|
||||
|
||||
def embed_query(self, query):
|
||||
return self.embeddings_model.embed_query(query)
|
||||
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
||||
|
||||
def embed_documents(self, docs):
|
||||
return self.embeddings_model.embed_documents(docs)
|
||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||
|
||||
|
||||
class CrossEncoderModel:
|
||||
|
||||
Reference in New Issue
Block a user