Make Search Models More Configurable. Upgrade Default Cross-Encoder (#722)

- Upgrade default cross-encoder to mixedbread ai's mxbai-rerank-xsmall
- Support more embedding models by making query, docs encoding configurable
This commit is contained in:
Debanjum
2024-04-25 13:55:49 +05:30
committed by GitHub
6 changed files with 67 additions and 10 deletions

View File

@@ -7,9 +7,11 @@ sidebar_position: 3
## Search across Different Languages (Self-Hosting) ## Search across Different Languages (Self-Hosting)
To search for notes in multiple, different languages, you can use a [multi-lingual model](https://www.sbert.net/docs/pretrained_models.html#multi-lingual-models).<br /> To search for notes in multiple, different languages, you can use a [multi-lingual model](https://www.sbert.net/docs/pretrained_models.html#multi-lingual-models).<br />
For example, the [paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) supports [50+ languages](https://www.sbert.net/docs/pretrained_models.html#:~:text=we%20used%20the%20following%2050%2B%20languages), has good search quality and speed. To use it: For example, the [paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) supports [50+ languages](https://www.sbert.net/docs/pretrained_models.html#:~:text=we%20used%20the%20following%2050%2B%20languages), has good search quality and speed. To use it:
1. Manually update the search config in server's admin settings page. Go to [the search config](http://localhost:42110/server/admin/database/searchmodelconfig/). Either create a new one, if none exists, or update the existing one. Set the bi_encoder to `sentence-transformers/multi-qa-MiniLM-L6-cos-v1` and the cross_encoder to `cross-encoder/ms-marco-MiniLM-L-6-v2`. 1. Manually update the search config in server's admin settings page. Go to [the search config](http://localhost:42110/server/admin/database/searchmodelconfig/). Either create a new one, if none exists, or update the existing one. Set the bi_encoder to `sentence-transformers/multi-qa-MiniLM-L6-cos-v1` and the cross_encoder to `mixedbread-ai/mxbai-rerank-xsmall-v1`.
2. Regenerate your content index from all the relevant clients. This step is very important, as you'll need to re-encode all your content with the new model. 2. Regenerate your content index from all the relevant clients. This step is very important, as you'll need to re-encode all your content with the new model.
Note: If you use a search model that expects a prefix (e.g [mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)) to the query (or docs) string before encoding. Update the `bi_encoder_query_encode_config` field with `{prompt: <prefix-prompt>}`. Eg. `{prompt: "Represent this query for searching documents"}`. You can pass a valid JSON object that the SentenceTransformer `encode` function accepts
## Query Filters ## Query Filters
Use structured query syntax to filter entries from your knowledge based used by search results or chat responses. Use structured query syntax to filter entries from your knowledge based used by search results or chat responses.

View File

@@ -216,6 +216,9 @@ def configure_server(
model.bi_encoder, model.bi_encoder,
model.embeddings_inference_endpoint, model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key, model.embeddings_inference_endpoint_api_key,
query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config,
) )
} }
) )

View File

@@ -0,0 +1,32 @@
# Generated by Django 4.2.10 on 2024-04-24 04:19
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0036_delete_offlinechatprocessorconversationconfig"),
]
operations = [
migrations.AddField(
model_name="searchmodelconfig",
name="bi_encoder_docs_encode_config",
field=models.JSONField(default=dict),
),
migrations.AddField(
model_name="searchmodelconfig",
name="bi_encoder_model_config",
field=models.JSONField(default=dict),
),
migrations.AddField(
model_name="searchmodelconfig",
name="bi_encoder_query_encode_config",
field=models.JSONField(default=dict),
),
migrations.AlterField(
model_name="searchmodelconfig",
name="cross_encoder",
field=models.CharField(default="mixedbread-ai/mxbai-rerank-xsmall-v1", max_length=200),
),
]

View File

@@ -179,13 +179,27 @@ class SearchModelConfig(BaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
TEXT = "text" TEXT = "text"
# This is the model name exposed to users on their settings page
name = models.CharField(max_length=200, default="default") name = models.CharField(max_length=200, default="default")
# Type of content the model can generate embeddings for
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
# Bi-encoder model of sentence-transformer type to load from HuggingFace
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") # Config passed to the sentence-transformer model constructor. E.g device="cuda:0", trust_remote_server=True etc.
bi_encoder_model_config = models.JSONField(default=dict)
# Query encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
bi_encoder_query_encode_config = models.JSONField(default=dict)
# Docs encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
bi_encoder_docs_encode_config = models.JSONField(default=dict)
# Cross-encoder model of sentence-transformer type to load from HuggingFace
cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1")
# Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True) embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True) cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)

View File

@@ -13,7 +13,7 @@ from tenacity import (
) )
from torch import nn from torch import nn
from khoj.utils.helpers import get_device from khoj.utils.helpers import get_device, merge_dicts
from khoj.utils.rawconfig import SearchResponse from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,9 +25,15 @@ class EmbeddingsModel:
model_name: str = "thenlper/gte-small", model_name: str = "thenlper/gte-small",
embeddings_inference_endpoint: str = None, embeddings_inference_endpoint: str = None,
embeddings_inference_endpoint_api_key: str = None, embeddings_inference_endpoint_api_key: str = None,
query_encode_kwargs: dict = {},
docs_encode_kwargs: dict = {},
model_kwargs: dict = {},
): ):
self.encode_kwargs = {"normalize_embeddings": True} default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
self.model_kwargs = {"device": get_device()} default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
self.model_name = model_name self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key self.api_key = embeddings_inference_endpoint_api_key
@@ -39,7 +45,7 @@ class EmbeddingsModel:
def embed_query(self, query): def embed_query(self, query):
if self.inference_server_enabled(): if self.inference_server_enabled():
return self.embed_with_api([query])[0] return self.embed_with_api([query])[0]
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
@retry( @retry(
retry=retry_if_exception_type(requests.exceptions.HTTPError), retry=retry_if_exception_type(requests.exceptions.HTTPError),
@@ -70,7 +76,7 @@ class EmbeddingsModel:
logger.warning( logger.warning(
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead." f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
) )
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits # break up the docs payload in chunks of 1000 to avoid hitting rate limits
embeddings = [] embeddings = []
with tqdm.tqdm(total=len(docs)) as pbar: with tqdm.tqdm(total=len(docs)) as pbar:
@@ -80,13 +86,13 @@ class EmbeddingsModel:
embeddings += generated_embeddings embeddings += generated_embeddings
pbar.update(1000) pbar.update(1000)
return embeddings return embeddings
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
class CrossEncoderModel: class CrossEncoderModel:
def __init__( def __init__(
self, self,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", model_name: str = "mixedbread-ai/mxbai-rerank-xsmall-v1",
cross_encoder_inference_endpoint: str = None, cross_encoder_inference_endpoint: str = None,
cross_encoder_inference_endpoint_api_key: str = None, cross_encoder_inference_endpoint_api_key: str = None,
): ):

View File

@@ -75,7 +75,7 @@ class SearchModelFactory(factory.django.DjangoModelFactory):
name = "default" name = "default"
model_type = "text" model_type = "text"
bi_encoder = "thenlper/gte-small" bi_encoder = "thenlper/gte-small"
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2" cross_encoder = "mixedbread-ai/mxbai-rerank-xsmall-v1"
class SubscriptionFactory(factory.django.DjangoModelFactory): class SubscriptionFactory(factory.django.DjangoModelFactory):