From 823f8d58bb2926e02952629e8f4226ee0fae2750 Mon Sep 17 00:00:00 2001 From: Alexander Matyasko Date: Wed, 7 Aug 2024 17:58:21 +0800 Subject: [PATCH] Add model_config for crossencoder model Add model_config for crossencoder model, so the user can use models which require trust_remote_code. --- src/khoj/configure.py | 1 + ...rchmodelconfig_cross_encoder_model_config.py | 17 +++++++++++++++++ src/khoj/database/models/__init__.py | 2 ++ src/khoj/processor/embeddings.py | 4 +++- 4 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 src/khoj/database/migrations/0056_searchmodelconfig_cross_encoder_model_config.py diff --git a/src/khoj/configure.py b/src/khoj/configure.py index a77ade87..42654107 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -263,6 +263,7 @@ def configure_server( model.cross_encoder, model.cross_encoder_inference_endpoint, model.cross_encoder_inference_endpoint_api_key, + model_kwargs=model.cross_encoder_model_config, ) } ) diff --git a/src/khoj/database/migrations/0056_searchmodelconfig_cross_encoder_model_config.py b/src/khoj/database/migrations/0056_searchmodelconfig_cross_encoder_model_config.py new file mode 100644 index 00000000..362c9960 --- /dev/null +++ b/src/khoj/database/migrations/0056_searchmodelconfig_cross_encoder_model_config.py @@ -0,0 +1,17 @@ +# Generated by Django 5.0.7 on 2024-08-07 09:12 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0055_alter_agent_style_icon"), + ] + + operations = [ + migrations.AddField( + model_name="searchmodelconfig", + name="cross_encoder_model_config", + field=models.JSONField(blank=True, default=dict), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 13340649..bf16b781 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -259,6 +259,8 @@ class SearchModelConfig(BaseModel): 
bi_encoder_docs_encode_config = models.JSONField(default=dict, blank=True) # Cross-encoder model of sentence-transformer type to load from HuggingFace cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1") + # Config passed to the cross-encoder model constructor. E.g. device="cuda:0", trust_remote_code=True etc. + cross_encoder_model_config = models.JSONField(default=dict, blank=True) # Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) # Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index ce5d6bf1..15d03f7f 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -95,11 +95,13 @@ class CrossEncoderModel: model_name: str = "mixedbread-ai/mxbai-rerank-xsmall-v1", cross_encoder_inference_endpoint: str = None, cross_encoder_inference_endpoint_api_key: str = None, + model_kwargs: dict = {}, ): self.model_name = model_name - self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device()) self.inference_endpoint = cross_encoder_inference_endpoint self.api_key = cross_encoder_inference_endpoint_api_key + self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()}) + self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs) def inference_server_enabled(self) -> bool: return self.api_key is not None and self.inference_endpoint is not None