mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Add model_config for cross-encoder model (#885) from aam-at/feature/crossencoder_model_config
Add `model_config` for the cross-encoder model, so the server admin can use models which require the `trust_remote_code` argument to run locally.
This commit is contained in:
@@ -263,6 +263,7 @@ def configure_server(
|
|||||||
model.cross_encoder,
|
model.cross_encoder,
|
||||||
model.cross_encoder_inference_endpoint,
|
model.cross_encoder_inference_endpoint,
|
||||||
model.cross_encoder_inference_endpoint_api_key,
|
model.cross_encoder_inference_endpoint_api_key,
|
||||||
|
model_kwargs=model.cross_encoder_model_config,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
# Generated by Django 5.0.7 on 2024-08-07 09:12
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Add the ``cross_encoder_model_config`` JSON field to the ``searchmodelconfig`` model."""

    # Depends on migration 0055_alter_agent_style_icon of the "database" app.
    dependencies = [("database", "0055_alter_agent_style_icon")]

    operations = [
        migrations.AddField(
            model_name="searchmodelconfig",
            name="cross_encoder_model_config",
            # The field defaults to an empty dict and may be left blank.
            field=models.JSONField(default=dict, blank=True),
        ),
    ]
|
||||||
@@ -259,6 +259,8 @@ class SearchModelConfig(BaseModel):
|
|||||||
bi_encoder_docs_encode_config = models.JSONField(default=dict, blank=True)
|
bi_encoder_docs_encode_config = models.JSONField(default=dict, blank=True)
|
||||||
# Cross-encoder model of sentence-transformer type to load from HuggingFace
|
# Cross-encoder model of sentence-transformer type to load from HuggingFace
|
||||||
cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1")
|
cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1")
|
||||||
|
# Config passed to the cross-encoder model constructor. E.g. device="cuda:0", trust_remote_code=True etc.
|
||||||
|
cross_encoder_model_config = models.JSONField(default=dict, blank=True)
|
||||||
# Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server
|
# Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server
|
||||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
|
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
|
||||||
|
|||||||
@@ -95,11 +95,13 @@ class CrossEncoderModel:
|
|||||||
model_name: str = "mixedbread-ai/mxbai-rerank-xsmall-v1",
|
model_name: str = "mixedbread-ai/mxbai-rerank-xsmall-v1",
|
||||||
cross_encoder_inference_endpoint: str = None,
|
cross_encoder_inference_endpoint: str = None,
|
||||||
cross_encoder_inference_endpoint_api_key: str = None,
|
cross_encoder_inference_endpoint_api_key: str = None,
|
||||||
|
model_kwargs: dict = {},
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
|
||||||
self.inference_endpoint = cross_encoder_inference_endpoint
|
self.inference_endpoint = cross_encoder_inference_endpoint
|
||||||
self.api_key = cross_encoder_inference_endpoint_api_key
|
self.api_key = cross_encoder_inference_endpoint_api_key
|
||||||
|
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
|
||||||
|
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, **self.model_kwargs)
|
||||||
|
|
||||||
def inference_server_enabled(self) -> bool:
    """Return True only when both an API key and an inference endpoint are set.

    A value counts as "set" when it is not ``None``; empty strings still count.
    """
    # De Morgan form of "key is not None and endpoint is not None";
    # the API key is still checked first.
    return not (self.api_key is None or self.inference_endpoint is None)
|
||||||
|
|||||||
Reference in New Issue
Block a user