Make Search Models More Configurable. Upgrade Default Cross-Encoder (#722)

- Upgrade default cross-encoder to mixedbread ai's mxbai-rerank-xsmall
- Support more embedding models by making query, docs encoding configurable
This commit is contained in:
Debanjum
2024-04-25 13:55:49 +05:30
committed by GitHub
6 changed files with 67 additions and 10 deletions

View File

@@ -216,6 +216,9 @@ def configure_server(
model.bi_encoder,
model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key,
query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config,
)
}
)

View File

@@ -0,0 +1,32 @@
# Generated by Django 4.2.10 on 2024-04-24 04:19
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0036_delete_offlinechatprocessorconversationconfig"),
]
operations = [
migrations.AddField(
model_name="searchmodelconfig",
name="bi_encoder_docs_encode_config",
field=models.JSONField(default=dict),
),
migrations.AddField(
model_name="searchmodelconfig",
name="bi_encoder_model_config",
field=models.JSONField(default=dict),
),
migrations.AddField(
model_name="searchmodelconfig",
name="bi_encoder_query_encode_config",
field=models.JSONField(default=dict),
),
migrations.AlterField(
model_name="searchmodelconfig",
name="cross_encoder",
field=models.CharField(default="mixedbread-ai/mxbai-rerank-xsmall-v1", max_length=200),
),
]

View File

@@ -179,13 +179,27 @@ class SearchModelConfig(BaseModel):
class ModelType(models.TextChoices):
TEXT = "text"
# This is the model name exposed to users on their settings page
name = models.CharField(max_length=200, default="default")
# Type of content the model can generate embeddings for
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
# Bi-encoder model of sentence-transformer type to load from HuggingFace
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
# Config passed to the sentence-transformer model constructor. E.g device="cuda:0", trust_remote_server=True etc.
bi_encoder_model_config = models.JSONField(default=dict)
# Query encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
bi_encoder_query_encode_config = models.JSONField(default=dict)
# Docs encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
bi_encoder_docs_encode_config = models.JSONField(default=dict)
# Cross-encoder model of sentence-transformer type to load from HuggingFace
cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1")
# Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)

View File

@@ -13,7 +13,7 @@ from tenacity import (
)
from torch import nn
from khoj.utils.helpers import get_device
from khoj.utils.helpers import get_device, merge_dicts
from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__)
@@ -25,9 +25,15 @@ class EmbeddingsModel:
model_name: str = "thenlper/gte-small",
embeddings_inference_endpoint: str = None,
embeddings_inference_endpoint_api_key: str = None,
query_encode_kwargs: dict = {},
docs_encode_kwargs: dict = {},
model_kwargs: dict = {},
):
self.encode_kwargs = {"normalize_embeddings": True}
self.model_kwargs = {"device": get_device()}
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key
@@ -39,7 +45,7 @@ class EmbeddingsModel:
def embed_query(self, query):
if self.inference_server_enabled():
return self.embed_with_api([query])[0]
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
@retry(
retry=retry_if_exception_type(requests.exceptions.HTTPError),
@@ -70,7 +76,7 @@ class EmbeddingsModel:
logger.warning(
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
)
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
embeddings = []
with tqdm.tqdm(total=len(docs)) as pbar:
@@ -80,13 +86,13 @@ class EmbeddingsModel:
embeddings += generated_embeddings
pbar.update(1000)
return embeddings
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
class CrossEncoderModel:
def __init__(
self,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
model_name: str = "mixedbread-ai/mxbai-rerank-xsmall-v1",
cross_encoder_inference_endpoint: str = None,
cross_encoder_inference_endpoint_api_key: str = None,
):