mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Support using Embeddings Model exposed via OpenAI (compatible) API (#1051)
This change adds the ability to use OpenAI, Azure OpenAI or any embedding model exposed behind an OpenAI compatible API (like Ollama, LiteLLM, vLLM etc.). Khoj previously only supported HuggingFace embedding models running locally on device or via HuggingFaceW inference API endpoint. This allows using commercial embedding models to index your content with Khoj.
This commit is contained in:
@@ -15,3 +15,37 @@ Take advantage of super fast search to find relevant notes and documents from yo
|
|||||||
|
|
||||||
### Demo
|
### Demo
|
||||||

|

|
||||||
|
|
||||||
|
|
||||||
|
### Implementation Overview
|
||||||
|
A bi-encoder models is used to create meaning vectors (aka vector embeddings) of your documents and search queries.
|
||||||
|
1. When you sync you documents with Khoj, it uses the bi-encoder model to create and store meaning vectors of (chunks of) your documents
|
||||||
|
2. When you initiate a natural language search the bi-encoder model converts your query into a meaning vector and finds the most relevant document chunks for that query by comparing their meaning vectors.
|
||||||
|
3. The slower but higher-quality cross-encoder model is than used to re-rank these documents for your given query.
|
||||||
|
|
||||||
|
### Setup (Self-Hosting)
|
||||||
|
You are **not required** to configure the search model config when self-hosting. Khoj sets up decent default local search model config for general use.
|
||||||
|
|
||||||
|
You may want to configure this if you need better multi-lingual search, want to experiment with different, newer models or the default models do not work for your use-case.
|
||||||
|
|
||||||
|
You can use bi-encoder models downloaded locally [from Huggingface](https://huggingface.co/models?library=sentence-transformers), served via the [HuggingFace Inference API](https://endpoints.huggingface.co/), OpenAI API, Azure OpenAI API or any OpenAI Compatible API like Ollama, LiteLLM etc. Follow the steps below to configure your search model:
|
||||||
|
|
||||||
|
1. Open the [SearchModelConfig](http://localhost:42110/server/admin/database/searchmodelconfig/) page on your Khoj admin panel.
|
||||||
|
2. Hit the Plus button to add a new model config or click the id of an existing model config to edit it.
|
||||||
|
3. Set the `biencoder` field to the name of the bi-encoder model supported [locally](https://huggingface.co/models?library=sentence-transformers) or via the API you configure.
|
||||||
|
4. Set the `Embeddings inference endpoint api key` to your OpenAI API key and `Embeddings inference endpoint type` to `OpenAI` to use an OpenAI embedding model.
|
||||||
|
5. Also set the `Embeddings inference endpoint` to your Azure OpenAI or OpenAI compatible API URL to use the model via those APIs.
|
||||||
|
6. Ensure the search model config you want to use is the **only one** that has `name` field set to `default`[^1].
|
||||||
|
7. Save the search model configs and restart your Khoj server to start using your new, updated search config.
|
||||||
|
|
||||||
|
:::info
|
||||||
|
You will need to re-index all your documents if you want to use a different bi-encoder model.
|
||||||
|
:::
|
||||||
|
|
||||||
|
:::info
|
||||||
|
You may need to tune the `Bi encoder confidence threshold` field for each bi-encoder to get appropriate number of documents for chat with your Knowledge base.
|
||||||
|
|
||||||
|
Confidence here is a normalized measure of semantic distance between your query and documents. The confidence threshold limits the documents returned to chat that fall within the distance specified in this field. It can take values between 0.0 (exact overlap) and 1.0 (no meaning overlap).
|
||||||
|
:::
|
||||||
|
|
||||||
|
[^1]: Khoj uses the first search model config named `default` it finds on startup as the search model config for that session
|
||||||
|
|||||||
@@ -249,6 +249,7 @@ def configure_server(
|
|||||||
model.bi_encoder,
|
model.bi_encoder,
|
||||||
model.embeddings_inference_endpoint,
|
model.embeddings_inference_endpoint,
|
||||||
model.embeddings_inference_endpoint_api_key,
|
model.embeddings_inference_endpoint_api_key,
|
||||||
|
model.embeddings_inference_endpoint_type,
|
||||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||||
model_kwargs=model.bi_encoder_model_config,
|
model_kwargs=model.bi_encoder_model_config,
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ from typing import List
|
|||||||
|
|
||||||
from django.core.management.base import BaseCommand
|
from django.core.management.base import BaseCommand
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import Count, Q
|
from django.db.models import Q
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from khoj.database.adapters import get_default_search_model
|
from khoj.database.adapters import get_default_search_model
|
||||||
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
|
from khoj.database.models import Entry, SearchModelConfig
|
||||||
from khoj.processor.embeddings import EmbeddingsModel
|
from khoj.processor.embeddings import EmbeddingsModel
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -74,6 +74,7 @@ class Command(BaseCommand):
|
|||||||
model.bi_encoder,
|
model.bi_encoder,
|
||||||
model.embeddings_inference_endpoint,
|
model.embeddings_inference_endpoint,
|
||||||
model.embeddings_inference_endpoint_api_key,
|
model.embeddings_inference_endpoint_api_key,
|
||||||
|
model.embeddings_inference_endpoint_type,
|
||||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||||
model_kwargs=model.bi_encoder_model_config,
|
model_kwargs=model.bi_encoder_model_config,
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
# Generated by Django 5.0.10 on 2025-01-08 15:09
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
def set_endpoint_type(apps, schema_editor):
|
||||||
|
SearchModelConfig = apps.get_model("database", "SearchModelConfig")
|
||||||
|
SearchModelConfig.objects.filter(embeddings_inference_endpoint__isnull=False).exclude(
|
||||||
|
embeddings_inference_endpoint=""
|
||||||
|
).update(embeddings_inference_endpoint_type="huggingface")
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0078_khojuser_email_verification_code_expiry"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="searchmodelconfig",
|
||||||
|
name="embeddings_inference_endpoint_type",
|
||||||
|
field=models.CharField(
|
||||||
|
choices=[("huggingface", "Huggingface"), ("openai", "Openai"), ("local", "Local")],
|
||||||
|
default="local",
|
||||||
|
max_length=200,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.RunPython(set_endpoint_type, reverse_code=migrations.RunPython.noop),
|
||||||
|
]
|
||||||
@@ -481,6 +481,11 @@ class SearchModelConfig(DbBaseModel):
|
|||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
|
|
||||||
|
class ApiType(models.TextChoices):
|
||||||
|
HUGGINGFACE = "huggingface"
|
||||||
|
OPENAI = "openai"
|
||||||
|
LOCAL = "local"
|
||||||
|
|
||||||
# This is the model name exposed to users on their settings page
|
# This is the model name exposed to users on their settings page
|
||||||
name = models.CharField(max_length=200, default="default")
|
name = models.CharField(max_length=200, default="default")
|
||||||
# Type of content the model can generate embeddings for
|
# Type of content the model can generate embeddings for
|
||||||
@@ -501,6 +506,10 @@ class SearchModelConfig(DbBaseModel):
|
|||||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
|
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
|
||||||
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
# Inference server API type to use for embeddings inference.
|
||||||
|
embeddings_inference_endpoint_type = models.CharField(
|
||||||
|
max_length=200, choices=ApiType.choices, default=ApiType.LOCAL
|
||||||
|
)
|
||||||
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
|
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||||
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
|
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import openai
|
||||||
import requests
|
import requests
|
||||||
import tqdm
|
import tqdm
|
||||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||||
@@ -13,7 +15,14 @@ from tenacity import (
|
|||||||
)
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
|
from khoj.database.models import SearchModelConfig
|
||||||
|
from khoj.utils.helpers import (
|
||||||
|
fix_json_dict,
|
||||||
|
get_device,
|
||||||
|
get_openai_client,
|
||||||
|
merge_dicts,
|
||||||
|
timer,
|
||||||
|
)
|
||||||
from khoj.utils.rawconfig import SearchResponse
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -25,6 +34,7 @@ class EmbeddingsModel:
|
|||||||
model_name: str = "thenlper/gte-small",
|
model_name: str = "thenlper/gte-small",
|
||||||
embeddings_inference_endpoint: str = None,
|
embeddings_inference_endpoint: str = None,
|
||||||
embeddings_inference_endpoint_api_key: str = None,
|
embeddings_inference_endpoint_api_key: str = None,
|
||||||
|
embeddings_inference_endpoint_type=SearchModelConfig.ApiType.LOCAL,
|
||||||
query_encode_kwargs: dict = {},
|
query_encode_kwargs: dict = {},
|
||||||
docs_encode_kwargs: dict = {},
|
docs_encode_kwargs: dict = {},
|
||||||
model_kwargs: dict = {},
|
model_kwargs: dict = {},
|
||||||
@@ -37,15 +47,16 @@ class EmbeddingsModel:
|
|||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.inference_endpoint = embeddings_inference_endpoint
|
self.inference_endpoint = embeddings_inference_endpoint
|
||||||
self.api_key = embeddings_inference_endpoint_api_key
|
self.api_key = embeddings_inference_endpoint_api_key
|
||||||
with timer(f"Loaded embedding model {self.model_name}", logger):
|
self.inference_endpoint_type = embeddings_inference_endpoint_type
|
||||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
|
||||||
|
with timer(f"Loaded embedding model {self.model_name}", logger):
|
||||||
def inference_server_enabled(self) -> bool:
|
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||||
return self.api_key is not None and self.inference_endpoint is not None
|
|
||||||
|
|
||||||
def embed_query(self, query):
|
def embed_query(self, query):
|
||||||
if self.inference_server_enabled():
|
if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
|
||||||
return self.embed_with_api([query])[0]
|
return self.embed_with_hf([query])[0]
|
||||||
|
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
|
||||||
|
return self.embed_with_openai([query])[0]
|
||||||
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
|
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -54,7 +65,7 @@ class EmbeddingsModel:
|
|||||||
stop=stop_after_attempt(5),
|
stop=stop_after_attempt(5),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
)
|
)
|
||||||
def embed_with_api(self, docs):
|
def embed_with_hf(self, docs):
|
||||||
payload = {"inputs": docs}
|
payload = {"inputs": docs}
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
@@ -71,23 +82,38 @@ class EmbeddingsModel:
|
|||||||
raise e
|
raise e
|
||||||
return response.json()["embeddings"]
|
return response.json()["embeddings"]
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
retry=retry_if_exception_type(requests.exceptions.HTTPError),
|
||||||
|
wait=wait_random_exponential(multiplier=1, max=10),
|
||||||
|
stop=stop_after_attempt(5),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
)
|
||||||
|
def embed_with_openai(self, docs):
|
||||||
|
client = get_openai_client(self.api_key, self.inference_endpoint)
|
||||||
|
response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float")
|
||||||
|
return [item.embedding for item in response.data]
|
||||||
|
|
||||||
def embed_documents(self, docs):
|
def embed_documents(self, docs):
|
||||||
if self.inference_server_enabled():
|
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
|
||||||
if "huggingface" not in self.inference_endpoint:
|
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
|
||||||
logger.warning(
|
elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
|
||||||
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
|
embed_with_api = self.embed_with_hf
|
||||||
)
|
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
|
||||||
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
|
embed_with_api = self.embed_with_openai
|
||||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
else:
|
||||||
embeddings = []
|
logger.warning(
|
||||||
with tqdm.tqdm(total=len(docs)) as pbar:
|
f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead."
|
||||||
for i in range(0, len(docs), 1000):
|
)
|
||||||
docs_to_embed = docs[i : i + 1000]
|
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
|
||||||
generated_embeddings = self.embed_with_api(docs_to_embed)
|
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||||
embeddings += generated_embeddings
|
embeddings = []
|
||||||
pbar.update(1000)
|
with tqdm.tqdm(total=len(docs)) as pbar:
|
||||||
return embeddings
|
for i in range(0, len(docs), 1000):
|
||||||
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
|
docs_to_embed = docs[i : i + 1000]
|
||||||
|
generated_embeddings = embed_with_api(docs_to_embed)
|
||||||
|
embeddings += generated_embeddings
|
||||||
|
pbar.update(1000)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class CrossEncoderModel:
|
class CrossEncoderModel:
|
||||||
|
|||||||
Reference in New Issue
Block a user