Support more embedding models by making query, docs encoding configurable

Most newer, better embedding models expect a query or docs prefix when
encoding. Previously, Khoj admins couldn't configure these prefixes, so
these newer models couldn't be used.

This change allows configuring the kwargs passed to the query and docs
encoders by updating the search model config in the database.
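
For example, e5-style bi-encoders expect a "query: " prefix on queries and a
"passage: " prefix on documents. With this change an admin could store encode
configs along these lines (illustrative values, not taken from this commit;
they assume the installed sentence-transformers release supports the prompt
encode kwarg):

# Illustrative only: the prefixes follow the intfloat/e5 convention.
bi_encoder_query_encode_config = {"prompt": "query: "}
bi_encoder_docs_encode_config = {"prompt": "passage: "}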
Author: Debanjum Singh Solanky
Date:   2024-04-24 09:02:20 +05:30
Parent: 8e77b3dc82
Commit: 7eaf9367fe
3 changed files with 18 additions and 6 deletions


@@ -216,6 +216,9 @@ def configure_server(
                     model.bi_encoder,
                     model.embeddings_inference_endpoint,
                     model.embeddings_inference_endpoint_api_key,
+                    query_encode_kwargs=model.bi_encoder_query_encode_config,
+                    docs_encode_kwargs=model.bi_encoder_docs_encode_config,
+                    model_kwargs=model.bi_encoder_model_config,
                 )
             }
         )


@@ -182,6 +182,9 @@ class SearchModelConfig(BaseModel):
     name = models.CharField(max_length=200, default="default")
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
     bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
+    bi_encoder_model_config = models.JSONField(default=dict)
+    bi_encoder_query_encode_config = models.JSONField(default=dict)
+    bi_encoder_docs_encode_config = models.JSONField(default=dict)
     cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
     embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
     embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
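
Because the new JSONFields default to an empty dict, existing deployments keep
the previous behaviour untouched. A hypothetical way to opt in from a Django
shell (the import path and field values are assumptions, not shown in this
diff):

# Hypothetical admin workflow; import path assumed to be khoj.database.models.
from khoj.database.models import SearchModelConfig

config = SearchModelConfig.objects.get(name="default")
config.bi_encoder = "intfloat/e5-small-v2"  # example model that expects prefixes
config.bi_encoder_query_encode_config = {"prompt": "query: "}
config.bi_encoder_docs_encode_config = {"prompt": "passage: "}
config.save()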


@@ -13,7 +13,7 @@ from tenacity import (
 )
 from torch import nn

-from khoj.utils.helpers import get_device
+from khoj.utils.helpers import get_device, merge_dicts
 from khoj.utils.rawconfig import SearchResponse

 logger = logging.getLogger(__name__)
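
merge_dicts is imported from khoj.utils.helpers, but its body is not part of
this diff. From its use below, it presumably gives keys in its first argument
precedence over the second; a minimal sketch of that assumed behaviour:

# Sketch of the assumed merge semantics: caller-supplied kwargs (first argument)
# override the built-in defaults (second argument).
def merge_dicts_sketch(priority: dict, default: dict) -> dict:
    merged = dict(default)
    merged.update(priority)
    return merged

assert merge_dicts_sketch(
    {"show_progress_bar": True},
    {"show_progress_bar": False, "normalize_embeddings": True},
) == {"show_progress_bar": True, "normalize_embeddings": True}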
@@ -25,9 +25,15 @@ class EmbeddingsModel:
         model_name: str = "thenlper/gte-small",
         embeddings_inference_endpoint: str = None,
         embeddings_inference_endpoint_api_key: str = None,
+        query_encode_kwargs: dict = {},
+        docs_encode_kwargs: dict = {},
+        model_kwargs: dict = {},
     ):
-        self.encode_kwargs = {"normalize_embeddings": True}
-        self.model_kwargs = {"device": get_device()}
+        default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
+        default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
+        self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
+        self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
+        self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
         self.model_name = model_name
         self.inference_endpoint = embeddings_inference_endpoint
         self.api_key = embeddings_inference_endpoint_api_key
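
Assuming that merge order, a standalone re-trace of what the query encoder ends
up receiving when an admin stores a prompt prefix (the stored value here is
hypothetical):

# Plain-dict re-trace of the merge above; admin-supplied keys override defaults.
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
admin_query_encode_config = {"prompt": "query: "}  # hypothetical DB value
query_encode_kwargs = {**default_query_encode_kwargs, **admin_query_encode_config}
print(query_encode_kwargs)
# {'show_progress_bar': False, 'normalize_embeddings': True, 'prompt': 'query: '}

Leaving the JSONFields at their empty-dict defaults therefore reproduces the
previous hard-coded kwargs.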
@@ -39,7 +45,7 @@ class EmbeddingsModel:
     def embed_query(self, query):
         if self.inference_server_enabled():
             return self.embed_with_api([query])[0]
-        return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
+        return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]

     @retry(
         retry=retry_if_exception_type(requests.exceptions.HTTPError),
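
The effect of the query encode kwargs is easiest to see directly against
sentence-transformers. A standalone illustration (assumes a release recent
enough to support the prompt kwarg on encode, which older 2.2.x versions lack;
the model choice is an example):

# Standalone sketch: the prompt string is prepended to each input before tokenization.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/e5-small-v2")
query_embedding = model.encode(
    ["what is khoj?"], prompt="query: ", normalize_embeddings=True
)[0]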
@@ -70,7 +76,7 @@ class EmbeddingsModel:
                 logger.warning(
                     f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
                 )
-                return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
+                return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
             # break up the docs payload in chunks of 1000 to avoid hitting rate limits
             embeddings = []
             with tqdm.tqdm(total=len(docs)) as pbar:
@@ -80,7 +86,7 @@ class EmbeddingsModel:
                     embeddings += generated_embeddings
                     pbar.update(1000)
             return embeddings
-        return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
+        return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()


 class CrossEncoderModel:
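
Putting it together, a hypothetical use of the updated class (the module path
and model choice are assumptions, not taken from this diff):

# Hypothetical usage sketch; EmbeddingsModel is assumed importable from
# khoj.processor.embeddings.
from khoj.processor.embeddings import EmbeddingsModel

embedder = EmbeddingsModel(
    model_name="intfloat/e5-small-v2",
    query_encode_kwargs={"prompt": "query: "},
    docs_encode_kwargs={"prompt": "passage: "},
)
query_vector = embedder.embed_query("how do I configure khoj?")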