From 7eaf9367fec5c0b8adebdf66fd5247508973d001 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 24 Apr 2024 09:02:20 +0530 Subject: [PATCH] Support more embedding models by making query, docs encoding configurable Most newer, better embedding models add a query or docs prefix when encoding. Previously, Khoj admins couldn't configure these prefixes, so it wasn't possible to use these newer models. This change allows configuring the kwargs passed to the query and docs encoders by updating the search config in the database. --- src/khoj/configure.py | 3 +++ src/khoj/database/models/__init__.py | 3 +++ src/khoj/processor/embeddings.py | 18 ++++++++++++------ 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 419bf950..38b8223f 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -216,6 +216,9 @@ def configure_server( model.bi_encoder, model.embeddings_inference_endpoint, model.embeddings_inference_endpoint_api_key, + query_encode_kwargs=model.bi_encoder_query_encode_config, + docs_encode_kwargs=model.bi_encoder_docs_encode_config, + model_kwargs=model.bi_encoder_model_config, ) } ) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 15f396f1..58b8b729 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -182,6 +182,9 @@ class SearchModelConfig(BaseModel): name = models.CharField(max_length=200, default="default") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") + bi_encoder_model_config = models.JSONField(default=dict) + bi_encoder_query_encode_config = models.JSONField(default=dict) + bi_encoder_docs_encode_config = models.JSONField(default=dict) cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") embeddings_inference_endpoint = 
models.CharField(max_length=200, default=None, null=True, blank=True) embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True) diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index ec8e08f0..19e986af 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -13,7 +13,7 @@ from tenacity import ( ) from torch import nn -from khoj.utils.helpers import get_device +from khoj.utils.helpers import get_device, merge_dicts from khoj.utils.rawconfig import SearchResponse logger = logging.getLogger(__name__) @@ -25,9 +25,15 @@ class EmbeddingsModel: model_name: str = "thenlper/gte-small", embeddings_inference_endpoint: str = None, embeddings_inference_endpoint_api_key: str = None, + query_encode_kwargs: dict = {}, + docs_encode_kwargs: dict = {}, + model_kwargs: dict = {}, ): - self.encode_kwargs = {"normalize_embeddings": True} - self.model_kwargs = {"device": get_device()} + default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True} + default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True} + self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs) + self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs) + self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()}) self.model_name = model_name self.inference_endpoint = embeddings_inference_endpoint self.api_key = embeddings_inference_endpoint_api_key @@ -39,7 +45,7 @@ class EmbeddingsModel: def embed_query(self, query): if self.inference_server_enabled(): return self.embed_with_api([query])[0] - return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] + return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0] @retry( retry=retry_if_exception_type(requests.exceptions.HTTPError), @@ -70,7 +76,7 @@ class EmbeddingsModel: 
logger.warning( f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead." ) - return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() + return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() # break up the docs payload in chunks of 1000 to avoid hitting rate limits embeddings = [] with tqdm.tqdm(total=len(docs)) as pbar: @@ -80,7 +86,7 @@ class EmbeddingsModel: embeddings += generated_embeddings pbar.update(1000) return embeddings - return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() + return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() class CrossEncoderModel: