Merge branch 'master' into short-circuit-api-rate-limiter

This commit is contained in:
Debanjum Singh Solanky
2024-01-16 18:18:34 +05:30
11 changed files with 103 additions and 10 deletions

View File

@@ -175,6 +175,12 @@ const config = {
theme: prismThemes.github, theme: prismThemes.github,
darkTheme: prismThemes.dracula, darkTheme: prismThemes.dracula,
}, },
algolia: {
appId: "NBR0FXJNGW",
apiKey: "8841b34192a28b2d06f04dd28d768017",
indexName: "khoj",
contextualSearch: false,
}
}), }),
}; };

View File

@@ -62,8 +62,8 @@ dependencies = [
"pymupdf >= 1.23.5", "pymupdf >= 1.23.5",
"django == 4.2.7", "django == 4.2.7",
"authlib == 1.2.1", "authlib == 1.2.1",
"gpt4all >= 2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'", "gpt4all >= 2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
"gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'", "gpt4all >= 2.1.0; platform_system == 'Windows' or platform_system == 'Darwin'",
"itsdangerous == 2.1.2", "itsdangerous == 2.1.2",
"httpx == 0.25.0", "httpx == 0.25.0",
"pgvector == 0.2.4", "pgvector == 0.2.4",

View File

@@ -208,7 +208,10 @@ function pushDataToKhoj (regenerate = false) {
}) })
.catch(error => { .catch(error => {
console.error(error); console.error(error);
if (error.response.status == 429) { if (error.code == 'ECONNREFUSED') {
const win = BrowserWindow.getAllWindows()[0];
if (win) win.webContents.send('update-state', state);
} else if (error.response.status == 429) {
const win = BrowserWindow.getAllWindows()[0]; const win = BrowserWindow.getAllWindows()[0];
if (win) win.webContents.send('needsSubscription', true); if (win) win.webContents.send('needsSubscription', true);
if (win) win.webContents.send('update-state', state); if (win) win.webContents.send('update-state', state);

View File

@@ -147,7 +147,15 @@ def configure_server(
state.cross_encoder_model = dict() state.cross_encoder_model = dict()
for model in search_models: for model in search_models:
state.embeddings_model.update({model.name: EmbeddingsModel(model.bi_encoder)}) state.embeddings_model.update(
{
model.name: EmbeddingsModel(
model.bi_encoder,
model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key,
)
}
)
state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)}) state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
state.SearchType = configure_search_types() state.SearchType = configure_search_types()

View File

@@ -0,0 +1,22 @@
# Generated by Django 4.2.7 on 2024-01-15 18:12
from django.db import migrations, models


class Migration(migrations.Migration):
    """Add optional remote-inference endpoint fields to SearchModelConfig.

    Adds two nullable CharFields so a search model can be configured to
    compute embeddings via an external inference endpoint (with API key)
    instead of a locally loaded model.
    """

    # Must be applied after the embeddings column alteration in 0024.
    dependencies = [
        ("database", "0024_alter_entry_embeddings"),
    ]

    operations = [
        migrations.AddField(
            model_name="searchmodelconfig",
            name="embeddings_inference_endpoint",
            # blank=True + null=True: the field is optional both in forms and in the DB.
            field=models.CharField(blank=True, default=None, max_length=200, null=True),
        ),
        migrations.AddField(
            model_name="searchmodelconfig",
            name="embeddings_inference_endpoint_api_key",
            # NOTE(review): stored as a plain CharField — the API key is kept
            # unencrypted in the database; confirm this is acceptable.
            field=models.CharField(blank=True, default=None, max_length=200, null=True),
        ),
    ]

View File

@@ -110,6 +110,8 @@ class SearchModelConfig(BaseModel):
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
class TextToImageModelConfig(BaseModel): class TextToImageModelConfig(BaseModel):

View File

@@ -123,9 +123,9 @@ def filter_questions(questions: List[str]):
def converse_offline( def converse_offline(
references,
online_results,
user_query, user_query,
references=[],
online_results=[],
conversation_log={}, conversation_log={},
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
loaded_model: Union[Any, None] = None, loaded_model: Union[Any, None] = None,

View File

@@ -21,9 +21,11 @@ def download_model(model_name: str):
# Try load chat model to GPU if: # Try load chat model to GPU if:
# 1. Loading chat model to GPU isn't disabled via CLI and # 1. Loading chat model to GPU isn't disabled via CLI and
# 2. Machine has GPU # 2. Machine has GPU
# 3. GPU has enough free memory to load the chat model # 3. GPU has enough free memory to load the chat model with max context length of 4096
device = ( device = (
"gpu" if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"]) else "cpu" "gpu"
if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096)
else "cpu"
) )
except ValueError: except ValueError:
device = "cpu" device = "cpu"
@@ -35,7 +37,7 @@ def download_model(model_name: str):
raise e raise e
# Now load the downloaded chat model onto appropriate device # Now load the downloaded chat model onto appropriate device
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False) chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False)
logger.debug(f"Loaded chat model to {device.upper()}.") logger.debug(f"Loaded chat model to {device.upper()}.")
return chat_model return chat_model

View File

@@ -1,23 +1,69 @@
import logging
from typing import List from typing import List
import requests
import tqdm
from sentence_transformers import CrossEncoder, SentenceTransformer from sentence_transformers import CrossEncoder, SentenceTransformer
from torch import nn from torch import nn
from khoj.utils.helpers import get_device from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__)
class EmbeddingsModel: class EmbeddingsModel:
def __init__(self, model_name: str = "thenlper/gte-small"): def __init__(
self,
model_name: str = "thenlper/gte-small",
embeddings_inference_endpoint: str = None,
embeddings_inference_endpoint_api_key: str = None,
):
self.encode_kwargs = {"normalize_embeddings": True} self.encode_kwargs = {"normalize_embeddings": True}
self.model_kwargs = {"device": get_device()} self.model_kwargs = {"device": get_device()}
self.model_name = model_name self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def embed_query(self, query): def embed_query(self, query):
if self.api_key is not None and self.inference_endpoint is not None:
target_url = f"{self.inference_endpoint}"
payload = {"inputs": [query]}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
response = requests.post(target_url, json=payload, headers=headers)
return response.json()["embeddings"][0]
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
def embed_documents(self, docs): def embed_documents(self, docs):
if self.api_key is not None and self.inference_endpoint is not None:
target_url = f"{self.inference_endpoint}"
if "huggingface" not in target_url:
logger.warning(
f"Using custom inference endpoint {target_url} is not yet supported. Please use a HuggingFace inference endpoint."
)
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
with tqdm.tqdm(total=len(docs)) as pbar:
for i in range(0, len(docs), 1000):
payload = {"inputs": docs[i : i + 1000]}
response = requests.post(target_url, json=payload, headers=headers)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
print(f"Error: {e}")
print(f"Response: {response.json()}")
raise e
if i == 0:
embeddings = response.json()["embeddings"]
else:
embeddings += response.json()["embeddings"]
pbar.update(1000)
return embeddings
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()

View File

@@ -6,6 +6,7 @@ import os
import time import time
import uuid import uuid
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from urllib.parse import unquote
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
@@ -362,6 +363,7 @@ async def chat(
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response: ) -> Response:
user: KhojUser = request.user.object user: KhojUser = request.user.object
q = unquote(q)
await is_ready_to_chat(user) await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True) conversation_command = get_conversation_command(query=q, any_references=True)

View File

@@ -1,3 +1,4 @@
import os
import urllib.parse import urllib.parse
from urllib.parse import quote from urllib.parse import quote
@@ -53,6 +54,7 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY")
@pytest.mark.chatquality @pytest.mark.chatquality
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_online_content(chat_client): def test_chat_with_online_content(chat_client):