mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-08 05:39:13 +00:00
Support using Embeddings Model exposed via OpenAI (compatible) API
This commit is contained in:
@@ -249,6 +249,7 @@ def configure_server(
|
|||||||
model.bi_encoder,
|
model.bi_encoder,
|
||||||
model.embeddings_inference_endpoint,
|
model.embeddings_inference_endpoint,
|
||||||
model.embeddings_inference_endpoint_api_key,
|
model.embeddings_inference_endpoint_api_key,
|
||||||
|
model.embeddings_inference_endpoint_type,
|
||||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||||
model_kwargs=model.bi_encoder_model_config,
|
model_kwargs=model.bi_encoder_model_config,
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ from typing import List
|
|||||||
|
|
||||||
from django.core.management.base import BaseCommand
|
from django.core.management.base import BaseCommand
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import Count, Q
|
from django.db.models import Q
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from khoj.database.adapters import get_default_search_model
|
from khoj.database.adapters import get_default_search_model
|
||||||
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
|
from khoj.database.models import Entry, SearchModelConfig
|
||||||
from khoj.processor.embeddings import EmbeddingsModel
|
from khoj.processor.embeddings import EmbeddingsModel
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -74,6 +74,7 @@ class Command(BaseCommand):
|
|||||||
model.bi_encoder,
|
model.bi_encoder,
|
||||||
model.embeddings_inference_endpoint,
|
model.embeddings_inference_endpoint,
|
||||||
model.embeddings_inference_endpoint_api_key,
|
model.embeddings_inference_endpoint_api_key,
|
||||||
|
model.embeddings_inference_endpoint_type,
|
||||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||||
model_kwargs=model.bi_encoder_model_config,
|
model_kwargs=model.bi_encoder_model_config,
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
# Generated by Django 5.0.10 on 2025-01-08 15:09
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
def set_endpoint_type(apps, schema_editor):
|
||||||
|
SearchModelConfig = apps.get_model("database", "SearchModelConfig")
|
||||||
|
SearchModelConfig.objects.filter(embeddings_inference_endpoint__isnull=False).exclude(
|
||||||
|
embeddings_inference_endpoint=""
|
||||||
|
).update(embeddings_inference_endpoint_type="huggingface")
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0078_khojuser_email_verification_code_expiry"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="searchmodelconfig",
|
||||||
|
name="embeddings_inference_endpoint_type",
|
||||||
|
field=models.CharField(
|
||||||
|
choices=[("huggingface", "Huggingface"), ("openai", "Openai"), ("local", "Local")],
|
||||||
|
default="local",
|
||||||
|
max_length=200,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.RunPython(set_endpoint_type, reverse_code=migrations.RunPython.noop),
|
||||||
|
]
|
||||||
@@ -481,6 +481,11 @@ class SearchModelConfig(DbBaseModel):
|
|||||||
class ModelType(models.TextChoices):
|
class ModelType(models.TextChoices):
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
|
|
||||||
|
class ApiType(models.TextChoices):
|
||||||
|
HUGGINGFACE = "huggingface"
|
||||||
|
OPENAI = "openai"
|
||||||
|
LOCAL = "local"
|
||||||
|
|
||||||
# This is the model name exposed to users on their settings page
|
# This is the model name exposed to users on their settings page
|
||||||
name = models.CharField(max_length=200, default="default")
|
name = models.CharField(max_length=200, default="default")
|
||||||
# Type of content the model can generate embeddings for
|
# Type of content the model can generate embeddings for
|
||||||
@@ -501,6 +506,10 @@ class SearchModelConfig(DbBaseModel):
|
|||||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
|
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
|
||||||
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
# Inference server API type to use for embeddings inference.
|
||||||
|
embeddings_inference_endpoint_type = models.CharField(
|
||||||
|
max_length=200, choices=ApiType.choices, default=ApiType.LOCAL
|
||||||
|
)
|
||||||
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
|
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||||
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
|
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import openai
|
||||||
import requests
|
import requests
|
||||||
import tqdm
|
import tqdm
|
||||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||||
@@ -13,7 +15,14 @@ from tenacity import (
|
|||||||
)
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
|
from khoj.database.models import SearchModelConfig
|
||||||
|
from khoj.utils.helpers import (
|
||||||
|
fix_json_dict,
|
||||||
|
get_device,
|
||||||
|
get_openai_client,
|
||||||
|
merge_dicts,
|
||||||
|
timer,
|
||||||
|
)
|
||||||
from khoj.utils.rawconfig import SearchResponse
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -25,6 +34,7 @@ class EmbeddingsModel:
|
|||||||
model_name: str = "thenlper/gte-small",
|
model_name: str = "thenlper/gte-small",
|
||||||
embeddings_inference_endpoint: str = None,
|
embeddings_inference_endpoint: str = None,
|
||||||
embeddings_inference_endpoint_api_key: str = None,
|
embeddings_inference_endpoint_api_key: str = None,
|
||||||
|
embeddings_inference_endpoint_type=SearchModelConfig.ApiType.LOCAL,
|
||||||
query_encode_kwargs: dict = {},
|
query_encode_kwargs: dict = {},
|
||||||
docs_encode_kwargs: dict = {},
|
docs_encode_kwargs: dict = {},
|
||||||
model_kwargs: dict = {},
|
model_kwargs: dict = {},
|
||||||
@@ -37,15 +47,16 @@ class EmbeddingsModel:
|
|||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.inference_endpoint = embeddings_inference_endpoint
|
self.inference_endpoint = embeddings_inference_endpoint
|
||||||
self.api_key = embeddings_inference_endpoint_api_key
|
self.api_key = embeddings_inference_endpoint_api_key
|
||||||
with timer(f"Loaded embedding model {self.model_name}", logger):
|
self.inference_endpoint_type = embeddings_inference_endpoint_type
|
||||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
|
||||||
|
with timer(f"Loaded embedding model {self.model_name}", logger):
|
||||||
def inference_server_enabled(self) -> bool:
|
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||||
return self.api_key is not None and self.inference_endpoint is not None
|
|
||||||
|
|
||||||
def embed_query(self, query):
|
def embed_query(self, query):
|
||||||
if self.inference_server_enabled():
|
if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
|
||||||
return self.embed_with_api([query])[0]
|
return self.embed_with_hf([query])[0]
|
||||||
|
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
|
||||||
|
return self.embed_with_openai([query])[0]
|
||||||
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
|
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -54,7 +65,7 @@ class EmbeddingsModel:
|
|||||||
stop=stop_after_attempt(5),
|
stop=stop_after_attempt(5),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
)
|
)
|
||||||
def embed_with_api(self, docs):
|
def embed_with_hf(self, docs):
|
||||||
payload = {"inputs": docs}
|
payload = {"inputs": docs}
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
@@ -71,23 +82,38 @@ class EmbeddingsModel:
|
|||||||
raise e
|
raise e
|
||||||
return response.json()["embeddings"]
|
return response.json()["embeddings"]
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
retry=retry_if_exception_type(requests.exceptions.HTTPError),
|
||||||
|
wait=wait_random_exponential(multiplier=1, max=10),
|
||||||
|
stop=stop_after_attempt(5),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
)
|
||||||
|
def embed_with_openai(self, docs):
|
||||||
|
client = get_openai_client(self.api_key, self.inference_endpoint)
|
||||||
|
response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float")
|
||||||
|
return [item.embedding for item in response.data]
|
||||||
|
|
||||||
def embed_documents(self, docs):
|
def embed_documents(self, docs):
|
||||||
if self.inference_server_enabled():
|
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
|
||||||
if "huggingface" not in self.inference_endpoint:
|
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
|
||||||
logger.warning(
|
elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
|
||||||
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
|
embed_with_api = self.embed_with_hf
|
||||||
)
|
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
|
||||||
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
|
embed_with_api = self.embed_with_openai
|
||||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
else:
|
||||||
embeddings = []
|
logger.warning(
|
||||||
with tqdm.tqdm(total=len(docs)) as pbar:
|
f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead."
|
||||||
for i in range(0, len(docs), 1000):
|
)
|
||||||
docs_to_embed = docs[i : i + 1000]
|
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
|
||||||
generated_embeddings = self.embed_with_api(docs_to_embed)
|
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||||
embeddings += generated_embeddings
|
embeddings = []
|
||||||
pbar.update(1000)
|
with tqdm.tqdm(total=len(docs)) as pbar:
|
||||||
return embeddings
|
for i in range(0, len(docs), 1000):
|
||||||
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
|
docs_to_embed = docs[i : i + 1000]
|
||||||
|
generated_embeddings = embed_with_api(docs_to_embed)
|
||||||
|
embeddings += generated_embeddings
|
||||||
|
pbar.update(1000)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class CrossEncoderModel:
|
class CrossEncoderModel:
|
||||||
|
|||||||
Reference in New Issue
Block a user