diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 419bf950..38b8223f 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -216,6 +216,9 @@ def configure_server( model.bi_encoder, model.embeddings_inference_endpoint, model.embeddings_inference_endpoint_api_key, + query_encode_kwargs=model.bi_encoder_query_encode_config, + docs_encode_kwargs=model.bi_encoder_docs_encode_config, + model_kwargs=model.bi_encoder_model_config, ) } ) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 15f396f1..58b8b729 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -182,6 +182,9 @@ class SearchModelConfig(BaseModel): name = models.CharField(max_length=200, default="default") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") + bi_encoder_model_config = models.JSONField(default=dict) + bi_encoder_query_encode_config = models.JSONField(default=dict) + bi_encoder_docs_encode_config = models.JSONField(default=dict) cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True) diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index ec8e08f0..19e986af 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -13,7 +13,7 @@ from tenacity import ( ) from torch import nn -from khoj.utils.helpers import get_device +from khoj.utils.helpers import get_device, merge_dicts from khoj.utils.rawconfig import SearchResponse logger = logging.getLogger(__name__) @@ -25,9 +25,15 @@ class EmbeddingsModel: model_name: str = "thenlper/gte-small", embeddings_inference_endpoint: str = None, embeddings_inference_endpoint_api_key: str = None, + query_encode_kwargs: dict = {}, + docs_encode_kwargs: dict = {}, + model_kwargs: dict = {}, ): - self.encode_kwargs = {"normalize_embeddings": True} - self.model_kwargs = {"device": get_device()} + default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True} + default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True} + self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs) + self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs) + self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()}) self.model_name = model_name self.inference_endpoint = embeddings_inference_endpoint self.api_key = embeddings_inference_endpoint_api_key @@ -39,7 +45,7 @@ class EmbeddingsModel: def embed_query(self, query): if self.inference_server_enabled(): return self.embed_with_api([query])[0] - return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] + return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0] @retry( retry=retry_if_exception_type(requests.exceptions.HTTPError), @@ -70,7 +76,7 @@ class EmbeddingsModel: logger.warning( f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead." ) - return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() + return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() # break up the docs payload in chunks of 1000 to avoid hitting rate limits embeddings = [] with tqdm.tqdm(total=len(docs)) as pbar: @@ -80,7 +86,7 @@ class EmbeddingsModel: embeddings += generated_embeddings pbar.update(1000) return embeddings - return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() + return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() class CrossEncoderModel: