Support using Embeddings Model exposed via OpenAI (compatible) API

This commit is contained in:
Debanjum
2025-01-09 02:50:52 +07:00
parent 65f1c27963
commit 1b5826d8b6
5 changed files with 93 additions and 27 deletions

View File

@@ -249,6 +249,7 @@ def configure_server(
model.bi_encoder, model.bi_encoder,
model.embeddings_inference_endpoint, model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key, model.embeddings_inference_endpoint_api_key,
model.embeddings_inference_endpoint_type,
query_encode_kwargs=model.bi_encoder_query_encode_config, query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config, docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config, model_kwargs=model.bi_encoder_model_config,

View File

@@ -3,11 +3,11 @@ from typing import List
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.db import transaction from django.db import transaction
from django.db.models import Count, Q from django.db.models import Q
from tqdm import tqdm from tqdm import tqdm
from khoj.database.adapters import get_default_search_model from khoj.database.adapters import get_default_search_model
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig from khoj.database.models import Entry, SearchModelConfig
from khoj.processor.embeddings import EmbeddingsModel from khoj.processor.embeddings import EmbeddingsModel
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@@ -74,6 +74,7 @@ class Command(BaseCommand):
model.bi_encoder, model.bi_encoder,
model.embeddings_inference_endpoint, model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key, model.embeddings_inference_endpoint_api_key,
model.embeddings_inference_endpoint_type,
query_encode_kwargs=model.bi_encoder_query_encode_config, query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config, docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config, model_kwargs=model.bi_encoder_model_config,

View File

@@ -0,0 +1,29 @@
# Generated by Django 5.0.10 on 2025-01-08 15:09
from django.db import migrations, models
def set_endpoint_type(apps, schema_editor):
SearchModelConfig = apps.get_model("database", "SearchModelConfig")
SearchModelConfig.objects.filter(embeddings_inference_endpoint__isnull=False).exclude(
embeddings_inference_endpoint=""
).update(embeddings_inference_endpoint_type="huggingface")
class Migration(migrations.Migration):
dependencies = [
("database", "0078_khojuser_email_verification_code_expiry"),
]
operations = [
migrations.AddField(
model_name="searchmodelconfig",
name="embeddings_inference_endpoint_type",
field=models.CharField(
choices=[("huggingface", "Huggingface"), ("openai", "Openai"), ("local", "Local")],
default="local",
max_length=200,
),
),
migrations.RunPython(set_endpoint_type, reverse_code=migrations.RunPython.noop),
]

View File

@@ -481,6 +481,11 @@ class SearchModelConfig(DbBaseModel):
class ModelType(models.TextChoices): class ModelType(models.TextChoices):
TEXT = "text" TEXT = "text"
class ApiType(models.TextChoices):
HUGGINGFACE = "huggingface"
OPENAI = "openai"
LOCAL = "local"
# This is the model name exposed to users on their settings page # This is the model name exposed to users on their settings page
name = models.CharField(max_length=200, default="default") name = models.CharField(max_length=200, default="default")
# Type of content the model can generate embeddings for # Type of content the model can generate embeddings for
@@ -501,6 +506,10 @@ class SearchModelConfig(DbBaseModel):
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server # Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True) embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API type to use for embeddings inference.
embeddings_inference_endpoint_type = models.CharField(
max_length=200, choices=ApiType.choices, default=ApiType.LOCAL
)
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server # Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server # Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server

View File

@@ -1,6 +1,8 @@
import logging import logging
from typing import List from typing import List
from urllib.parse import urlparse
import openai
import requests import requests
import tqdm import tqdm
from sentence_transformers import CrossEncoder, SentenceTransformer from sentence_transformers import CrossEncoder, SentenceTransformer
@@ -13,7 +15,14 @@ from tenacity import (
) )
from torch import nn from torch import nn
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer from khoj.database.models import SearchModelConfig
from khoj.utils.helpers import (
fix_json_dict,
get_device,
get_openai_client,
merge_dicts,
timer,
)
from khoj.utils.rawconfig import SearchResponse from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,6 +34,7 @@ class EmbeddingsModel:
model_name: str = "thenlper/gte-small", model_name: str = "thenlper/gte-small",
embeddings_inference_endpoint: str = None, embeddings_inference_endpoint: str = None,
embeddings_inference_endpoint_api_key: str = None, embeddings_inference_endpoint_api_key: str = None,
embeddings_inference_endpoint_type=SearchModelConfig.ApiType.LOCAL,
query_encode_kwargs: dict = {}, query_encode_kwargs: dict = {},
docs_encode_kwargs: dict = {}, docs_encode_kwargs: dict = {},
model_kwargs: dict = {}, model_kwargs: dict = {},
@@ -37,15 +47,16 @@ class EmbeddingsModel:
self.model_name = model_name self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key self.api_key = embeddings_inference_endpoint_api_key
with timer(f"Loaded embedding model {self.model_name}", logger): self.inference_endpoint_type = embeddings_inference_endpoint_type
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
with timer(f"Loaded embedding model {self.model_name}", logger):
def inference_server_enabled(self) -> bool: self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
return self.api_key is not None and self.inference_endpoint is not None
def embed_query(self, query): def embed_query(self, query):
if self.inference_server_enabled(): if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
return self.embed_with_api([query])[0] return self.embed_with_hf([query])[0]
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
return self.embed_with_openai([query])[0]
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0] return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
@retry( @retry(
@@ -54,7 +65,7 @@ class EmbeddingsModel:
stop=stop_after_attempt(5), stop=stop_after_attempt(5),
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
) )
def embed_with_api(self, docs): def embed_with_hf(self, docs):
payload = {"inputs": docs} payload = {"inputs": docs}
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
@@ -71,23 +82,38 @@ class EmbeddingsModel:
raise e raise e
return response.json()["embeddings"] return response.json()["embeddings"]
@retry(
retry=retry_if_exception_type(requests.exceptions.HTTPError),
wait=wait_random_exponential(multiplier=1, max=10),
stop=stop_after_attempt(5),
before_sleep=before_sleep_log(logger, logging.DEBUG),
)
def embed_with_openai(self, docs):
client = get_openai_client(self.api_key, self.inference_endpoint)
response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float")
return [item.embedding for item in response.data]
def embed_documents(self, docs): def embed_documents(self, docs):
if self.inference_server_enabled(): if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
if "huggingface" not in self.inference_endpoint: return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
logger.warning( elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead." embed_with_api = self.embed_with_hf
) elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() embed_with_api = self.embed_with_openai
# break up the docs payload in chunks of 1000 to avoid hitting rate limits else:
embeddings = [] logger.warning(
with tqdm.tqdm(total=len(docs)) as pbar: f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead."
for i in range(0, len(docs), 1000): )
docs_to_embed = docs[i : i + 1000] return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
generated_embeddings = self.embed_with_api(docs_to_embed) # break up the docs payload in chunks of 1000 to avoid hitting rate limits
embeddings += generated_embeddings embeddings = []
pbar.update(1000) with tqdm.tqdm(total=len(docs)) as pbar:
return embeddings for i in range(0, len(docs), 1000):
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else [] docs_to_embed = docs[i : i + 1000]
generated_embeddings = embed_with_api(docs_to_embed)
embeddings += generated_embeddings
pbar.update(1000)
return embeddings
class CrossEncoderModel: class CrossEncoderModel: