diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index fbcddb67..fcd88d80 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -1,7 +1,6 @@ from typing import List -from langchain.embeddings import HuggingFaceEmbeddings -from sentence_transformers import CrossEncoder +from sentence_transformers import SentenceTransformer, CrossEncoder from khoj.utils.helpers import get_device from khoj.utils.rawconfig import SearchResponse @@ -10,17 +9,15 @@ from khoj.utils.rawconfig import SearchResponse class EmbeddingsModel: def __init__(self): self.model_name = "thenlper/gte-small" - encode_kwargs = {"normalize_embeddings": True, "show_progress_bar": True} + self.encode_kwargs = {"normalize_embeddings": True} model_kwargs = {"device": get_device()} - self.embeddings_model = HuggingFaceEmbeddings( - model_name=self.model_name, encode_kwargs=encode_kwargs, model_kwargs=model_kwargs - ) + self.embeddings_model = SentenceTransformer(self.model_name, **model_kwargs) def embed_query(self, query): - return self.embeddings_model.embed_query(query) + return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] def embed_documents(self, docs): - return self.embeddings_model.embed_documents(docs) + return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() class CrossEncoderModel: