mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Support more embedding models by making query, docs encoding configurable
Most newer, better embedding models add a prefix to queries and docs when encoding them. Previously, Khoj admins couldn't configure these prefixes, so it wasn't possible to use these newer models. This change allows configuring the kwargs passed to the query and docs encoders by updating the search config in the database.
This commit is contained in:
@@ -216,6 +216,9 @@ def configure_server(
|
||||
model.bi_encoder,
|
||||
model.embeddings_inference_endpoint,
|
||||
model.embeddings_inference_endpoint_api_key,
|
||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||
model_kwargs=model.bi_encoder_model_config,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -182,6 +182,9 @@ class SearchModelConfig(BaseModel):
|
||||
name = models.CharField(max_length=200, default="default")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
|
||||
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
|
||||
bi_encoder_model_config = models.JSONField(default=dict)
|
||||
bi_encoder_query_encode_config = models.JSONField(default=dict)
|
||||
bi_encoder_docs_encode_config = models.JSONField(default=dict)
|
||||
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
|
||||
@@ -13,7 +13,7 @@ from tenacity import (
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from khoj.utils.helpers import get_device
|
||||
from khoj.utils.helpers import get_device, merge_dicts
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,9 +25,15 @@ class EmbeddingsModel:
|
||||
model_name: str = "thenlper/gte-small",
|
||||
embeddings_inference_endpoint: str = None,
|
||||
embeddings_inference_endpoint_api_key: str = None,
|
||||
query_encode_kwargs: dict = {},
|
||||
docs_encode_kwargs: dict = {},
|
||||
model_kwargs: dict = {},
|
||||
):
|
||||
self.encode_kwargs = {"normalize_embeddings": True}
|
||||
self.model_kwargs = {"device": get_device()}
|
||||
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
|
||||
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
|
||||
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
|
||||
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
|
||||
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
|
||||
self.model_name = model_name
|
||||
self.inference_endpoint = embeddings_inference_endpoint
|
||||
self.api_key = embeddings_inference_endpoint_api_key
|
||||
@@ -39,7 +45,7 @@ class EmbeddingsModel:
|
||||
def embed_query(self, query):
|
||||
if self.inference_server_enabled():
|
||||
return self.embed_with_api([query])[0]
|
||||
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
||||
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(requests.exceptions.HTTPError),
|
||||
@@ -70,7 +76,7 @@ class EmbeddingsModel:
|
||||
logger.warning(
|
||||
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
|
||||
)
|
||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
|
||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||
embeddings = []
|
||||
with tqdm.tqdm(total=len(docs)) as pbar:
|
||||
@@ -80,7 +86,7 @@ class EmbeddingsModel:
|
||||
embeddings += generated_embeddings
|
||||
pbar.update(1000)
|
||||
return embeddings
|
||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
|
||||
|
||||
|
||||
class CrossEncoderModel:
|
||||
|
||||
Reference in New Issue
Block a user