mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 13:26:13 +00:00
Make Search Models More Configurable. Upgrade Default Cross-Encoder (#722)
- Upgrade default cross-encoder to mixedbread ai's mxbai-rerank-xsmall - Support more embedding models by making query, docs encoding configurable
This commit is contained in:
@@ -7,9 +7,11 @@ sidebar_position: 3
|
|||||||
## Search across Different Languages (Self-Hosting)
|
## Search across Different Languages (Self-Hosting)
|
||||||
To search for notes in multiple, different languages, you can use a [multi-lingual model](https://www.sbert.net/docs/pretrained_models.html#multi-lingual-models).<br />
|
To search for notes in multiple, different languages, you can use a [multi-lingual model](https://www.sbert.net/docs/pretrained_models.html#multi-lingual-models).<br />
|
||||||
For example, the [paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) supports [50+ languages](https://www.sbert.net/docs/pretrained_models.html#:~:text=we%20used%20the%20following%2050%2B%20languages), has good search quality and speed. To use it:
|
For example, the [paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) supports [50+ languages](https://www.sbert.net/docs/pretrained_models.html#:~:text=we%20used%20the%20following%2050%2B%20languages), has good search quality and speed. To use it:
|
||||||
1. Manually update the search config in server's admin settings page. Go to [the search config](http://localhost:42110/server/admin/database/searchmodelconfig/). Either create a new one, if none exists, or update the existing one. Set the bi_encoder to `sentence-transformers/multi-qa-MiniLM-L6-cos-v1` and the cross_encoder to `cross-encoder/ms-marco-MiniLM-L-6-v2`.
|
1. Manually update the search config in server's admin settings page. Go to [the search config](http://localhost:42110/server/admin/database/searchmodelconfig/). Either create a new one, if none exists, or update the existing one. Set the bi_encoder to `sentence-transformers/multi-qa-MiniLM-L6-cos-v1` and the cross_encoder to `mixedbread-ai/mxbai-rerank-xsmall-v1`.
|
||||||
2. Regenerate your content index from all the relevant clients. This step is very important, as you'll need to re-encode all your content with the new model.
|
2. Regenerate your content index from all the relevant clients. This step is very important, as you'll need to re-encode all your content with the new model.
|
||||||
|
|
||||||
|
Note: If you use a search model that expects a prefix (e.g. [mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)) to be added to the query (or docs) string before encoding, update the `bi_encoder_query_encode_config` field (or the `bi_encoder_docs_encode_config` field for docs) with `{"prompt": "<prefix-prompt>"}`. E.g. `{"prompt": "Represent this query for searching documents"}`. You can pass any valid JSON object whose keys are arguments that the SentenceTransformer `encode` function accepts.
|
||||||
|
|
||||||
## Query Filters
|
## Query Filters
|
||||||
|
|
||||||
Use structured query syntax to filter entries from your knowledge base used by search results or chat responses.
|
Use structured query syntax to filter entries from your knowledge base used by search results or chat responses.
|
||||||
|
|||||||
@@ -216,6 +216,9 @@ def configure_server(
|
|||||||
model.bi_encoder,
|
model.bi_encoder,
|
||||||
model.embeddings_inference_endpoint,
|
model.embeddings_inference_endpoint,
|
||||||
model.embeddings_inference_endpoint_api_key,
|
model.embeddings_inference_endpoint_api_key,
|
||||||
|
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||||
|
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||||
|
model_kwargs=model.bi_encoder_model_config,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
# Generated by Django 4.2.10 on 2024-04-24 04:19
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0036_delete_offlinechatprocessorconversationconfig"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="searchmodelconfig",
|
||||||
|
name="bi_encoder_docs_encode_config",
|
||||||
|
field=models.JSONField(default=dict),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="searchmodelconfig",
|
||||||
|
name="bi_encoder_model_config",
|
||||||
|
field=models.JSONField(default=dict),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="searchmodelconfig",
|
||||||
|
name="bi_encoder_query_encode_config",
|
||||||
|
field=models.JSONField(default=dict),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="searchmodelconfig",
|
||||||
|
name="cross_encoder",
|
||||||
|
field=models.CharField(default="mixedbread-ai/mxbai-rerank-xsmall-v1", max_length=200),
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -179,13 +179,27 @@ class SearchModelConfig(BaseModel):
|
|||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
|
|
||||||
|
# This is the model name exposed to users on their settings page
|
||||||
name = models.CharField(max_length=200, default="default")
|
name = models.CharField(max_length=200, default="default")
|
||||||
|
# Type of content the model can generate embeddings for
|
||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
|
||||||
|
# Bi-encoder model of sentence-transformer type to load from HuggingFace
|
||||||
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
|
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
|
||||||
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
# Config passed to the sentence-transformer model constructor. E.g. device="cuda:0", trust_remote_code=True etc.
|
||||||
|
bi_encoder_model_config = models.JSONField(default=dict)
|
||||||
|
# Query encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
|
||||||
|
bi_encoder_query_encode_config = models.JSONField(default=dict)
|
||||||
|
# Docs encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
|
||||||
|
bi_encoder_docs_encode_config = models.JSONField(default=dict)
|
||||||
|
# Cross-encoder model of sentence-transformer type to load from HuggingFace
|
||||||
|
cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1")
|
||||||
|
# Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server
|
||||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
|
||||||
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||||
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||||
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from tenacity import (
|
|||||||
)
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from khoj.utils.helpers import get_device
|
from khoj.utils.helpers import get_device, merge_dicts
|
||||||
from khoj.utils.rawconfig import SearchResponse
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -25,9 +25,15 @@ class EmbeddingsModel:
|
|||||||
model_name: str = "thenlper/gte-small",
|
model_name: str = "thenlper/gte-small",
|
||||||
embeddings_inference_endpoint: str = None,
|
embeddings_inference_endpoint: str = None,
|
||||||
embeddings_inference_endpoint_api_key: str = None,
|
embeddings_inference_endpoint_api_key: str = None,
|
||||||
|
query_encode_kwargs: dict = {},
|
||||||
|
docs_encode_kwargs: dict = {},
|
||||||
|
model_kwargs: dict = {},
|
||||||
):
|
):
|
||||||
self.encode_kwargs = {"normalize_embeddings": True}
|
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
|
||||||
self.model_kwargs = {"device": get_device()}
|
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
|
||||||
|
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
|
||||||
|
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
|
||||||
|
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.inference_endpoint = embeddings_inference_endpoint
|
self.inference_endpoint = embeddings_inference_endpoint
|
||||||
self.api_key = embeddings_inference_endpoint_api_key
|
self.api_key = embeddings_inference_endpoint_api_key
|
||||||
@@ -39,7 +45,7 @@ class EmbeddingsModel:
|
|||||||
def embed_query(self, query):
|
def embed_query(self, query):
|
||||||
if self.inference_server_enabled():
|
if self.inference_server_enabled():
|
||||||
return self.embed_with_api([query])[0]
|
return self.embed_with_api([query])[0]
|
||||||
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=retry_if_exception_type(requests.exceptions.HTTPError),
|
retry=retry_if_exception_type(requests.exceptions.HTTPError),
|
||||||
@@ -70,7 +76,7 @@ class EmbeddingsModel:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
|
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
|
||||||
)
|
)
|
||||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
|
||||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||||
embeddings = []
|
embeddings = []
|
||||||
with tqdm.tqdm(total=len(docs)) as pbar:
|
with tqdm.tqdm(total=len(docs)) as pbar:
|
||||||
@@ -80,13 +86,13 @@ class EmbeddingsModel:
|
|||||||
embeddings += generated_embeddings
|
embeddings += generated_embeddings
|
||||||
pbar.update(1000)
|
pbar.update(1000)
|
||||||
return embeddings
|
return embeddings
|
||||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
|
||||||
|
|
||||||
|
|
||||||
class CrossEncoderModel:
|
class CrossEncoderModel:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
model_name: str = "mixedbread-ai/mxbai-rerank-xsmall-v1",
|
||||||
cross_encoder_inference_endpoint: str = None,
|
cross_encoder_inference_endpoint: str = None,
|
||||||
cross_encoder_inference_endpoint_api_key: str = None,
|
cross_encoder_inference_endpoint_api_key: str = None,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class SearchModelFactory(factory.django.DjangoModelFactory):
|
|||||||
name = "default"
|
name = "default"
|
||||||
model_type = "text"
|
model_type = "text"
|
||||||
bi_encoder = "thenlper/gte-small"
|
bi_encoder = "thenlper/gte-small"
|
||||||
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
cross_encoder = "mixedbread-ai/mxbai-rerank-xsmall-v1"
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
||||||
|
|||||||
Reference in New Issue
Block a user