Support using Embeddings Model exposed via OpenAI (compatible) API

This commit is contained in:
Debanjum
2025-01-09 02:50:52 +07:00
parent 65f1c27963
commit 1b5826d8b6
5 changed files with 93 additions and 27 deletions

View File

@@ -249,6 +249,7 @@ def configure_server(
model.bi_encoder,
model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key,
model.embeddings_inference_endpoint_type,
query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config,

View File

@@ -3,11 +3,11 @@ from typing import List
from django.core.management.base import BaseCommand
from django.db import transaction
from django.db.models import Count, Q
from django.db.models import Q
from tqdm import tqdm
from khoj.database.adapters import get_default_search_model
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
from khoj.database.models import Entry, SearchModelConfig
from khoj.processor.embeddings import EmbeddingsModel
logging.basicConfig(level=logging.INFO)
@@ -74,6 +74,7 @@ class Command(BaseCommand):
model.bi_encoder,
model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key,
model.embeddings_inference_endpoint_type,
query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config,

View File

@@ -0,0 +1,29 @@
# Generated by Django 5.0.10 on 2025-01-08 15:09
from django.db import migrations, models
def set_endpoint_type(apps, schema_editor):
SearchModelConfig = apps.get_model("database", "SearchModelConfig")
SearchModelConfig.objects.filter(embeddings_inference_endpoint__isnull=False).exclude(
embeddings_inference_endpoint=""
).update(embeddings_inference_endpoint_type="huggingface")
class Migration(migrations.Migration):
dependencies = [
("database", "0078_khojuser_email_verification_code_expiry"),
]
operations = [
migrations.AddField(
model_name="searchmodelconfig",
name="embeddings_inference_endpoint_type",
field=models.CharField(
choices=[("huggingface", "Huggingface"), ("openai", "Openai"), ("local", "Local")],
default="local",
max_length=200,
),
),
migrations.RunPython(set_endpoint_type, reverse_code=migrations.RunPython.noop),
]

View File

@@ -481,6 +481,11 @@ class SearchModelConfig(DbBaseModel):
class ModelType(models.TextChoices):
TEXT = "text"
class ApiType(models.TextChoices):
HUGGINGFACE = "huggingface"
OPENAI = "openai"
LOCAL = "local"
# This is the model name exposed to users on their settings page
name = models.CharField(max_length=200, default="default")
# Type of content the model can generate embeddings for
@@ -501,6 +506,10 @@ class SearchModelConfig(DbBaseModel):
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API type to use for embeddings inference.
embeddings_inference_endpoint_type = models.CharField(
max_length=200, choices=ApiType.choices, default=ApiType.LOCAL
)
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server

View File

@@ -1,6 +1,8 @@
import logging
from typing import List
from urllib.parse import urlparse
import openai
import requests
import tqdm
from sentence_transformers import CrossEncoder, SentenceTransformer
@@ -13,7 +15,14 @@ from tenacity import (
)
from torch import nn
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
from khoj.database.models import SearchModelConfig
from khoj.utils.helpers import (
fix_json_dict,
get_device,
get_openai_client,
merge_dicts,
timer,
)
from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__)
@@ -25,6 +34,7 @@ class EmbeddingsModel:
model_name: str = "thenlper/gte-small",
embeddings_inference_endpoint: str = None,
embeddings_inference_endpoint_api_key: str = None,
embeddings_inference_endpoint_type=SearchModelConfig.ApiType.LOCAL,
query_encode_kwargs: dict = {},
docs_encode_kwargs: dict = {},
model_kwargs: dict = {},
@@ -37,15 +47,16 @@ class EmbeddingsModel:
self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key
with timer(f"Loaded embedding model {self.model_name}", logger):
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def inference_server_enabled(self) -> bool:
return self.api_key is not None and self.inference_endpoint is not None
self.inference_endpoint_type = embeddings_inference_endpoint_type
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
with timer(f"Loaded embedding model {self.model_name}", logger):
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def embed_query(self, query):
if self.inference_server_enabled():
return self.embed_with_api([query])[0]
if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
return self.embed_with_hf([query])[0]
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
return self.embed_with_openai([query])[0]
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
@retry(
@@ -54,7 +65,7 @@ class EmbeddingsModel:
stop=stop_after_attempt(5),
before_sleep=before_sleep_log(logger, logging.DEBUG),
)
def embed_with_api(self, docs):
def embed_with_hf(self, docs):
payload = {"inputs": docs}
headers = {
"Authorization": f"Bearer {self.api_key}",
@@ -71,23 +82,38 @@ class EmbeddingsModel:
raise e
return response.json()["embeddings"]
@retry(
retry=retry_if_exception_type(requests.exceptions.HTTPError),
wait=wait_random_exponential(multiplier=1, max=10),
stop=stop_after_attempt(5),
before_sleep=before_sleep_log(logger, logging.DEBUG),
)
def embed_with_openai(self, docs):
client = get_openai_client(self.api_key, self.inference_endpoint)
response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float")
return [item.embedding for item in response.data]
def embed_documents(self, docs):
if self.inference_server_enabled():
if "huggingface" not in self.inference_endpoint:
logger.warning(
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
)
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
embeddings = []
with tqdm.tqdm(total=len(docs)) as pbar:
for i in range(0, len(docs), 1000):
docs_to_embed = docs[i : i + 1000]
generated_embeddings = self.embed_with_api(docs_to_embed)
embeddings += generated_embeddings
pbar.update(1000)
return embeddings
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
embed_with_api = self.embed_with_hf
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
embed_with_api = self.embed_with_openai
else:
logger.warning(
f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead."
)
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
embeddings = []
with tqdm.tqdm(total=len(docs)) as pbar:
for i in range(0, len(docs), 1000):
docs_to_embed = docs[i : i + 1000]
generated_embeddings = embed_with_api(docs_to_embed)
embeddings += generated_embeddings
pbar.update(1000)
return embeddings
class CrossEncoderModel: