mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Merge branch 'master' into short-circuit-api-rate-limiter
This commit is contained in:
@@ -175,6 +175,12 @@ const config = {
|
|||||||
theme: prismThemes.github,
|
theme: prismThemes.github,
|
||||||
darkTheme: prismThemes.dracula,
|
darkTheme: prismThemes.dracula,
|
||||||
},
|
},
|
||||||
|
algolia: {
|
||||||
|
appId: "NBR0FXJNGW",
|
||||||
|
apiKey: "8841b34192a28b2d06f04dd28d768017",
|
||||||
|
indexName: "khoj",
|
||||||
|
contextualSearch: false,
|
||||||
|
}
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -62,8 +62,8 @@ dependencies = [
|
|||||||
"pymupdf >= 1.23.5",
|
"pymupdf >= 1.23.5",
|
||||||
"django == 4.2.7",
|
"django == 4.2.7",
|
||||||
"authlib == 1.2.1",
|
"authlib == 1.2.1",
|
||||||
"gpt4all >= 2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
"gpt4all >= 2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||||
"gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'",
|
"gpt4all >= 2.1.0; platform_system == 'Windows' or platform_system == 'Darwin'",
|
||||||
"itsdangerous == 2.1.2",
|
"itsdangerous == 2.1.2",
|
||||||
"httpx == 0.25.0",
|
"httpx == 0.25.0",
|
||||||
"pgvector == 0.2.4",
|
"pgvector == 0.2.4",
|
||||||
|
|||||||
@@ -208,7 +208,10 @@ function pushDataToKhoj (regenerate = false) {
|
|||||||
})
|
})
|
||||||
.catch(error => {
|
.catch(error => {
|
||||||
console.error(error);
|
console.error(error);
|
||||||
if (error.response.status == 429) {
|
if (error.code == 'ECONNREFUSED') {
|
||||||
|
const win = BrowserWindow.getAllWindows()[0];
|
||||||
|
if (win) win.webContents.send('update-state', state);
|
||||||
|
} else if (error.response.status == 429) {
|
||||||
const win = BrowserWindow.getAllWindows()[0];
|
const win = BrowserWindow.getAllWindows()[0];
|
||||||
if (win) win.webContents.send('needsSubscription', true);
|
if (win) win.webContents.send('needsSubscription', true);
|
||||||
if (win) win.webContents.send('update-state', state);
|
if (win) win.webContents.send('update-state', state);
|
||||||
|
|||||||
@@ -147,7 +147,15 @@ def configure_server(
|
|||||||
state.cross_encoder_model = dict()
|
state.cross_encoder_model = dict()
|
||||||
|
|
||||||
for model in search_models:
|
for model in search_models:
|
||||||
state.embeddings_model.update({model.name: EmbeddingsModel(model.bi_encoder)})
|
state.embeddings_model.update(
|
||||||
|
{
|
||||||
|
model.name: EmbeddingsModel(
|
||||||
|
model.bi_encoder,
|
||||||
|
model.embeddings_inference_endpoint,
|
||||||
|
model.embeddings_inference_endpoint_api_key,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
|
state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
|
||||||
|
|
||||||
state.SearchType = configure_search_types()
|
state.SearchType = configure_search_types()
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
# Generated by Django 4.2.7 on 2024-01-15 18:12
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0024_alter_entry_embeddings"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="searchmodelconfig",
|
||||||
|
name="embeddings_inference_endpoint",
|
||||||
|
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="searchmodelconfig",
|
||||||
|
name="embeddings_inference_endpoint_api_key",
|
||||||
|
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -110,6 +110,8 @@ class SearchModelConfig(BaseModel):
|
|||||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
|
||||||
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
|
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
|
||||||
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||||
|
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class TextToImageModelConfig(BaseModel):
|
class TextToImageModelConfig(BaseModel):
|
||||||
|
|||||||
@@ -123,9 +123,9 @@ def filter_questions(questions: List[str]):
|
|||||||
|
|
||||||
|
|
||||||
def converse_offline(
|
def converse_offline(
|
||||||
references,
|
|
||||||
online_results,
|
|
||||||
user_query,
|
user_query,
|
||||||
|
references=[],
|
||||||
|
online_results=[],
|
||||||
conversation_log={},
|
conversation_log={},
|
||||||
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
|
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
|
||||||
loaded_model: Union[Any, None] = None,
|
loaded_model: Union[Any, None] = None,
|
||||||
|
|||||||
@@ -21,9 +21,11 @@ def download_model(model_name: str):
|
|||||||
# Try load chat model to GPU if:
|
# Try load chat model to GPU if:
|
||||||
# 1. Loading chat model to GPU isn't disabled via CLI and
|
# 1. Loading chat model to GPU isn't disabled via CLI and
|
||||||
# 2. Machine has GPU
|
# 2. Machine has GPU
|
||||||
# 3. GPU has enough free memory to load the chat model
|
# 3. GPU has enough free memory to load the chat model with max context length of 4096
|
||||||
device = (
|
device = (
|
||||||
"gpu" if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"]) else "cpu"
|
"gpu"
|
||||||
|
if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096)
|
||||||
|
else "cpu"
|
||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
@@ -35,7 +37,7 @@ def download_model(model_name: str):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
# Now load the downloaded chat model onto appropriate device
|
# Now load the downloaded chat model onto appropriate device
|
||||||
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
|
chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False)
|
||||||
logger.debug(f"Loaded chat model to {device.upper()}.")
|
logger.debug(f"Loaded chat model to {device.upper()}.")
|
||||||
|
|
||||||
return chat_model
|
return chat_model
|
||||||
|
|||||||
@@ -1,23 +1,69 @@
|
|||||||
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import tqdm
|
||||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from khoj.utils.helpers import get_device
|
from khoj.utils.helpers import get_device
|
||||||
from khoj.utils.rawconfig import SearchResponse
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsModel:
|
class EmbeddingsModel:
|
||||||
def __init__(self, model_name: str = "thenlper/gte-small"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "thenlper/gte-small",
|
||||||
|
embeddings_inference_endpoint: str = None,
|
||||||
|
embeddings_inference_endpoint_api_key: str = None,
|
||||||
|
):
|
||||||
self.encode_kwargs = {"normalize_embeddings": True}
|
self.encode_kwargs = {"normalize_embeddings": True}
|
||||||
self.model_kwargs = {"device": get_device()}
|
self.model_kwargs = {"device": get_device()}
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
self.inference_endpoint = embeddings_inference_endpoint
|
||||||
|
self.api_key = embeddings_inference_endpoint_api_key
|
||||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||||
|
|
||||||
def embed_query(self, query):
|
def embed_query(self, query):
|
||||||
|
if self.api_key is not None and self.inference_endpoint is not None:
|
||||||
|
target_url = f"{self.inference_endpoint}"
|
||||||
|
payload = {"inputs": [query]}
|
||||||
|
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||||
|
response = requests.post(target_url, json=payload, headers=headers)
|
||||||
|
return response.json()["embeddings"][0]
|
||||||
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
||||||
|
|
||||||
def embed_documents(self, docs):
|
def embed_documents(self, docs):
|
||||||
|
if self.api_key is not None and self.inference_endpoint is not None:
|
||||||
|
target_url = f"{self.inference_endpoint}"
|
||||||
|
if "huggingface" not in target_url:
|
||||||
|
logger.warning(
|
||||||
|
f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint."
|
||||||
|
)
|
||||||
|
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||||
|
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
with tqdm.tqdm(total=len(docs)) as pbar:
|
||||||
|
for i in range(0, len(docs), 1000):
|
||||||
|
payload = {"inputs": docs[i : i + 1000]}
|
||||||
|
response = requests.post(target_url, json=payload, headers=headers)
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
print(f"Response: {response.json()}")
|
||||||
|
raise e
|
||||||
|
if i == 0:
|
||||||
|
embeddings = response.json()["embeddings"]
|
||||||
|
else:
|
||||||
|
embeddings += response.json()["embeddings"]
|
||||||
|
pbar.update(1000)
|
||||||
|
return embeddings
|
||||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||||
@@ -362,6 +363,7 @@ async def chat(
|
|||||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
|
q = unquote(q)
|
||||||
|
|
||||||
await is_ready_to_chat(user)
|
await is_ready_to_chat(user)
|
||||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
@@ -53,6 +54,7 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c
|
|||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY")
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_chat_with_online_content(chat_client):
|
def test_chat_with_online_content(chat_client):
|
||||||
|
|||||||
Reference in New Issue
Block a user