mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Support using Embeddings Model exposed via OpenAI (compatible) API
This commit is contained in:
@@ -249,6 +249,7 @@ def configure_server(
|
||||
model.bi_encoder,
|
||||
model.embeddings_inference_endpoint,
|
||||
model.embeddings_inference_endpoint_api_key,
|
||||
model.embeddings_inference_endpoint_type,
|
||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||
model_kwargs=model.bi_encoder_model_config,
|
||||
|
||||
@@ -3,11 +3,11 @@ from typing import List
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import transaction
|
||||
from django.db.models import Count, Q
|
||||
from django.db.models import Q
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.database.adapters import get_default_search_model
|
||||
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
|
||||
from khoj.database.models import Entry, SearchModelConfig
|
||||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -74,6 +74,7 @@ class Command(BaseCommand):
|
||||
model.bi_encoder,
|
||||
model.embeddings_inference_endpoint,
|
||||
model.embeddings_inference_endpoint_api_key,
|
||||
model.embeddings_inference_endpoint_type,
|
||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||
model_kwargs=model.bi_encoder_model_config,
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
# Generated by Django 5.0.10 on 2025-01-08 15:09
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
def set_endpoint_type(apps, schema_editor):
    """Backfill the new endpoint-type field on existing search model configs.

    Any config that already has a non-empty inference endpoint predates this
    migration, when HuggingFace was the only supported remote backend — so
    mark those rows as "huggingface". All other rows keep the field default.
    """
    search_model_config = apps.get_model("database", "SearchModelConfig")
    configs_with_endpoint = search_model_config.objects.filter(
        embeddings_inference_endpoint__isnull=False
    ).exclude(embeddings_inference_endpoint="")
    configs_with_endpoint.update(embeddings_inference_endpoint_type="huggingface")
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    # Must be applied after the email verification expiry migration.
    dependencies = [
        ("database", "0078_khojuser_email_verification_code_expiry"),
    ]

    operations = [
        migrations.AddField(
            model_name="searchmodelconfig",
            name="embeddings_inference_endpoint_type",
            # Which API flavor the embeddings inference endpoint speaks;
            # new rows default to running the model locally.
            field=models.CharField(
                choices=[("huggingface", "Huggingface"), ("openai", "Openai"), ("local", "Local")],
                default="local",
                max_length=200,
            ),
        ),
        # Backfill: configs that already have an inference endpoint were
        # HuggingFace-backed before this change. Reverse is a no-op so the
        # migration remains reversible.
        migrations.RunPython(set_endpoint_type, reverse_code=migrations.RunPython.noop),
    ]
|
||||
@@ -481,6 +481,11 @@ class SearchModelConfig(DbBaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
TEXT = "text"
|
||||
|
||||
    class ApiType(models.TextChoices):
        # Backend used to generate embeddings for this search model config.
        HUGGINGFACE = "huggingface"  # remote HuggingFace-style inference endpoint
        OPENAI = "openai"  # remote OpenAI-compatible embeddings API
        LOCAL = "local"  # run the bi-encoder model locally
|
||||
|
||||
# This is the model name exposed to users on their settings page
|
||||
name = models.CharField(max_length=200, default="default")
|
||||
# Type of content the model can generate embeddings for
|
||||
@@ -501,6 +506,10 @@ class SearchModelConfig(DbBaseModel):
|
||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
# Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
|
||||
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
# Inference server API type to use for embeddings inference.
|
||||
embeddings_inference_endpoint_type = models.CharField(
|
||||
max_length=200, choices=ApiType.choices, default=ApiType.LOCAL
|
||||
)
|
||||
# Inference server API endpoint to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||
cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
# Inference server API Key to use for embeddings inference. Cross-encoder model should be hosted on this server
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import openai
|
||||
import requests
|
||||
import tqdm
|
||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||
@@ -13,7 +15,14 @@ from tenacity import (
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
|
||||
from khoj.database.models import SearchModelConfig
|
||||
from khoj.utils.helpers import (
|
||||
fix_json_dict,
|
||||
get_device,
|
||||
get_openai_client,
|
||||
merge_dicts,
|
||||
timer,
|
||||
)
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,6 +34,7 @@ class EmbeddingsModel:
|
||||
model_name: str = "thenlper/gte-small",
|
||||
embeddings_inference_endpoint: str = None,
|
||||
embeddings_inference_endpoint_api_key: str = None,
|
||||
embeddings_inference_endpoint_type=SearchModelConfig.ApiType.LOCAL,
|
||||
query_encode_kwargs: dict = {},
|
||||
docs_encode_kwargs: dict = {},
|
||||
model_kwargs: dict = {},
|
||||
@@ -37,15 +47,16 @@ class EmbeddingsModel:
|
||||
self.model_name = model_name
|
||||
self.inference_endpoint = embeddings_inference_endpoint
|
||||
self.api_key = embeddings_inference_endpoint_api_key
|
||||
with timer(f"Loaded embedding model {self.model_name}", logger):
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
|
||||
def inference_server_enabled(self) -> bool:
|
||||
return self.api_key is not None and self.inference_endpoint is not None
|
||||
self.inference_endpoint_type = embeddings_inference_endpoint_type
|
||||
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
|
||||
with timer(f"Loaded embedding model {self.model_name}", logger):
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
|
||||
def embed_query(self, query):
|
||||
if self.inference_server_enabled():
|
||||
return self.embed_with_api([query])[0]
|
||||
if self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
|
||||
return self.embed_with_hf([query])[0]
|
||||
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
|
||||
return self.embed_with_openai([query])[0]
|
||||
return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
|
||||
|
||||
@retry(
|
||||
@@ -54,7 +65,7 @@ class EmbeddingsModel:
|
||||
stop=stop_after_attempt(5),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
)
|
||||
def embed_with_api(self, docs):
|
||||
def embed_with_hf(self, docs):
|
||||
payload = {"inputs": docs}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
@@ -71,23 +82,38 @@ class EmbeddingsModel:
|
||||
raise e
|
||||
return response.json()["embeddings"]
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(requests.exceptions.HTTPError),
|
||||
wait=wait_random_exponential(multiplier=1, max=10),
|
||||
stop=stop_after_attempt(5),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
)
|
||||
def embed_with_openai(self, docs):
|
||||
client = get_openai_client(self.api_key, self.inference_endpoint)
|
||||
response = client.embeddings.create(input=docs, model=self.model_name, encoding_format="float")
|
||||
return [item.embedding for item in response.data]
|
||||
|
||||
def embed_documents(self, docs):
|
||||
if self.inference_server_enabled():
|
||||
if "huggingface" not in self.inference_endpoint:
|
||||
logger.warning(
|
||||
f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
|
||||
)
|
||||
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
|
||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||
embeddings = []
|
||||
with tqdm.tqdm(total=len(docs)) as pbar:
|
||||
for i in range(0, len(docs), 1000):
|
||||
docs_to_embed = docs[i : i + 1000]
|
||||
generated_embeddings = self.embed_with_api(docs_to_embed)
|
||||
embeddings += generated_embeddings
|
||||
pbar.update(1000)
|
||||
return embeddings
|
||||
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
|
||||
if self.inference_endpoint_type == SearchModelConfig.ApiType.LOCAL:
|
||||
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist() if docs else []
|
||||
elif self.inference_endpoint_type == SearchModelConfig.ApiType.HUGGINGFACE:
|
||||
embed_with_api = self.embed_with_hf
|
||||
elif self.inference_endpoint_type == SearchModelConfig.ApiType.OPENAI:
|
||||
embed_with_api = self.embed_with_openai
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unsupported inference endpoint: {self.inference_endpoint_type}. Generating embeddings locally instead."
|
||||
)
|
||||
return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
|
||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||
embeddings = []
|
||||
with tqdm.tqdm(total=len(docs)) as pbar:
|
||||
for i in range(0, len(docs), 1000):
|
||||
docs_to_embed = docs[i : i + 1000]
|
||||
generated_embeddings = embed_with_api(docs_to_embed)
|
||||
embeddings += generated_embeddings
|
||||
pbar.update(1000)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CrossEncoderModel:
|
||||
|
||||
Reference in New Issue
Block a user