mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 13:21:18 +00:00
Support using Embeddings Model exposed via OpenAI (compatible) API (#1051)
This change adds the ability to use OpenAI, Azure OpenAI or any embedding model exposed behind an OpenAI-compatible API (like Ollama, LiteLLM, vLLM etc.). Khoj previously only supported HuggingFace embedding models running locally on device or via the HuggingFace inference API endpoint. This allows using commercial embedding models to index your content with Khoj.
This commit is contained in:
@@ -249,6 +249,7 @@ def configure_server(
|
||||
model.bi_encoder,
|
||||
model.embeddings_inference_endpoint,
|
||||
model.embeddings_inference_endpoint_api_key,
|
||||
model.embeddings_inference_endpoint_type,
|
||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||
model_kwargs=model.bi_encoder_model_config,
|
||||
|
||||
@@ -3,11 +3,11 @@ from typing import List
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import transaction
|
||||
from django.db.models import Count, Q
|
||||
from django.db.models import Q
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.database.adapters import get_default_search_model
|
||||
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
|
||||
from khoj.database.models import Entry, SearchModelConfig
|
||||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -74,6 +74,7 @@ class Command(BaseCommand):
|
||||
model.bi_encoder,
|
||||
model.embeddings_inference_endpoint,
|
||||
model.embeddings_inference_endpoint_api_key,
|
||||
model.embeddings_inference_endpoint_type,
|
||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||
model_kwargs=model.bi_encoder_model_config,
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
# Generated by Django 5.0.10 on 2025-01-08 15:09
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
def set_endpoint_type(apps, schema_editor):
    """Backfill the endpoint type for search models that already have an endpoint.

    Every SearchModelConfig row whose embeddings inference endpoint is neither
    NULL nor the empty string is marked as using the "huggingface" endpoint type.
    Other rows are left untouched.
    """
    search_model_config = apps.get_model("database", "SearchModelConfig")

    # Narrow down to rows that have a non-empty inference endpoint configured
    with_endpoint = search_model_config.objects.filter(embeddings_inference_endpoint__isnull=False)
    with_endpoint = with_endpoint.exclude(embeddings_inference_endpoint="")

    with_endpoint.update(embeddings_inference_endpoint_type="huggingface")
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Add `embeddings_inference_endpoint_type` to SearchModelConfig and backfill it."""

    dependencies = [("database", "0078_khojuser_email_verification_code_expiry")]

    operations = [
        # New column defaults to "local"
        migrations.AddField(
            model_name="searchmodelconfig",
            name="embeddings_inference_endpoint_type",
            field=models.CharField(
                max_length=200,
                default="local",
                choices=[
                    ("huggingface", "Huggingface"),
                    ("openai", "Openai"),
                    ("local", "Local"),
                ],
            ),
        ),
        # Data backfill for existing rows; reversing this step is a no-op
        migrations.RunPython(code=set_endpoint_type, reverse_code=migrations.RunPython.noop),
    ]
|
||||
@@ -481,6 +481,11 @@ class SearchModelConfig(DbBaseModel):
|
||||
class ModelType(models.TextChoices):
    """Kinds of content a search model can generate embeddings for."""

    TEXT = "text", "Text"
|
||||
|
||||
class ApiType(models.TextChoices):
    """Kinds of API that can serve embeddings inference for a search model."""

    HUGGINGFACE = "huggingface", "Huggingface"
    OPENAI = "openai", "Openai"
    LOCAL = "local", "Local"
|
||||
|
||||
# This is the model name exposed to users on their settings page
|
||||
name = models.CharField(max_length=200, default="default")
|
||||
# Type of content the model can generate embeddings for
|
||||
@@ -501,6 +506,10 @@ class SearchModelConfig(DbBaseModel):
|
||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
|
||||
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
# Inference server API type to use for embeddings inference.
|
||||
embeddings_inference_endpoint_type = models.CharField(
|
||||
max_length=200, choices=ApiType.choices, default=ApiType.LOCAL
|
||||
)
|
||||
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import openai
|
||||
import requests
|
||||
import tqdm
|
||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||
@@ -13,7 +15,14 @@ from tenacity import (
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
|
||||
from khoj.database.models import SearchModelConfig
|
||||
from khoj.utils.helpers import (
|
||||
fix_json_dict,
|
||||
get_device,
|
||||
get_openai_client,
|
||||
merge_dicts,
|
||||
timer,
|
||||
)
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,6 +34,7 @@ class EmbeddingsModel:
|
||||
model_name: str = "thenlper/gte-small",
|
||||
embeddings_inference_endpoint: str = None,
|
||||
embeddings_inference_endpoint_api_key: str = None,
|
||||
embeddings_inference_endpoint_type=SearchModelConfig.ApiType.LOCAL,
|
||||
query_encode_kwargs: dict = {},
|
||||
docs_encode_kwargs: dict = {},
|
||||
model_kwargs: dict = {},
|
||||
@@ -37,15 +47,16 @@ class EmbeddingsModel:
|
||||
self.model_name = model_name
|
||||
self.inference_endpoint = embeddings_inference_endpoint
|
||||
self.api_key = embeddings_inference_endpoint_api_key
|
||||
with timer(f"Loaded embedding model {self.model_name}", logger):
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
|
||||
def inference_server_enabled(self) -> bool:
|
||||
return self.api_key is not None and self.inference_endpoint is not None
|
||||
self.inference_endpoint_type = embeddings_inference_endpoint_type
|
||||
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
|
||||
with timer(f"Loaded embedding model {self.model_name}", logger):
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
|
||||
def embed_query(self, query):
    """Return the embedding for a single search query.

    Dispatches on the configured inference endpoint type: HuggingFace and
    OpenAI endpoint types embed the query via their respective API helper.
    Any other endpoint type encodes the query with the on-device model using
    the query encode kwargs.
    """
    endpoint_type = self.inference_endpoint_type
    if endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
        return self.embed_with_hf([query])[0]
    if endpoint_type == SearchModelConfig.ApiType.OPENAI:
        return self.embed_with_openai([query])[0]

    # NOTE(review): the on-device model looks like it is only loaded for the LOCAL
    # endpoint type - confirm an unrecognised type cannot reach this line.
    return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
|
||||
|
||||
@retry(
|
||||
@@ -54,7 +65,7 @@ class EmbeddingsModel:
|
||||
stop=stop_after_attempt(5),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
)
|
||||
def embed_with_api(self, docs):
|
||||
def embed_with_hf(self, docs):
|
||||
payload = {"inputs": docs}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
@@ -71,23 +82,38 @@ class EmbeddingsModel:
|
||||
raise e
|
||||
return response.json()["embeddings"]
|
||||
|
||||
@retry(
    # The OpenAI client raises its own exception types, not requests.exceptions.HTTPError.
    # Retry only failures that are plausibly transient: connection problems and timeouts,
    # rate limiting, and server-side (5xx) errors.
    retry=retry_if_exception_type(
        (openai.APIConnectionError, openai.RateLimitError, openai.InternalServerError)
    ),
    wait=wait_random_exponential(multiplier=1, max=10),
    stop=stop_after_attempt(5),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    # Surface the original API error (not a tenacity RetryError) once attempts are exhausted
    reraise=True,
)
def embed_with_openai(self, docs):
    """Embed docs using an OpenAI (compatible) embeddings API.

    Args:
        docs: List of strings to generate embeddings for.

    Returns:
        List of embeddings, one per item in the API response data.
        Returns an empty list, without calling the API, when docs is empty.

    Raises:
        The underlying OpenAI client error if the request keeps failing after
        5 attempts, or immediately for errors that are not retried.
    """
    if not docs:
        return []

    client = get_openai_client(self.api_key, self.inference_endpoint)
    response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float")
    return [item.embedding for item in response.data]
|
||||
|
||||
def embed_documents(self, docs):
    """Return embeddings for a list of documents.

    Dispatches on the configured inference endpoint type:
    - LOCAL: encode all docs with the on-device model using the docs encode kwargs.
    - HUGGINGFACE / OPENAI: embed via the matching API helper, in chunks of 1000 docs.
    - Anything else: log a warning and encode with the on-device model.

    Args:
        docs: List of strings to generate embeddings for.

    Returns:
        List of embeddings, one per doc. Empty list when docs is empty.
    """
    if not docs:
        return []

    endpoint_type = self.inference_endpoint_type
    if endpoint_type == SearchModelConfig.ApiType.LOCAL:
        return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
    elif endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
        embed_with_api = self.embed_with_hf
    elif endpoint_type == SearchModelConfig.ApiType.OPENAI:
        embed_with_api = self.embed_with_openai
    else:
        logger.warning(
            f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead."
        )
        # NOTE(review): the on-device model looks like it is only loaded for the LOCAL
        # endpoint type - confirm this fallback cannot hit a missing attribute.
        return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()

    # break up the docs payload in chunks of 1000 to avoid hitting rate limits
    chunk_size = 1000
    embeddings = []
    with tqdm.tqdm(total=len(docs)) as pbar:
        for start in range(0, len(docs), chunk_size):
            docs_to_embed = docs[start : start + chunk_size]
            embeddings += embed_with_api(docs_to_embed)
            # Advance by the actual chunk size so the bar never overshoots the total
            pbar.update(len(docs_to_embed))
    return embeddings
|
||||
|
||||
|
||||
class CrossEncoderModel:
|
||||
|
||||
Reference in New Issue
Block a user