diff --git a/documentation/docusaurus.config.js b/documentation/docusaurus.config.js index b808c332..3a5afb6b 100644 --- a/documentation/docusaurus.config.js +++ b/documentation/docusaurus.config.js @@ -175,6 +175,12 @@ const config = { theme: prismThemes.github, darkTheme: prismThemes.dracula, }, + algolia: { + appId: "NBR0FXJNGW", + apiKey: "8841b34192a28b2d06f04dd28d768017", + indexName: "khoj", + contextualSearch: false, + } }), }; diff --git a/pyproject.toml b/pyproject.toml index caf3e410..e77731ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,8 +62,8 @@ dependencies = [ "pymupdf >= 1.23.5", "django == 4.2.7", "authlib == 1.2.1", - "gpt4all >= 2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'", - "gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'", + "gpt4all >= 2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'", + "gpt4all >= 2.1.0; platform_system == 'Windows' or platform_system == 'Darwin'", "itsdangerous == 2.1.2", "httpx == 0.25.0", "pgvector == 0.2.4", diff --git a/src/interface/desktop/main.js b/src/interface/desktop/main.js index 4bb087d9..ec3e6fa4 100644 --- a/src/interface/desktop/main.js +++ b/src/interface/desktop/main.js @@ -208,7 +208,10 @@ function pushDataToKhoj (regenerate = false) { }) .catch(error => { console.error(error); - if (error.response.status == 429) { + if (error.code == 'ECONNREFUSED') { + const win = BrowserWindow.getAllWindows()[0]; + if (win) win.webContents.send('update-state', state); + } else if (error.response.status == 429) { const win = BrowserWindow.getAllWindows()[0]; if (win) win.webContents.send('needsSubscription', true); if (win) win.webContents.send('update-state', state); diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 786eccb3..dfc0fe4f 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -147,7 +147,15 @@ def configure_server( state.cross_encoder_model = dict() for model in search_models: - 
state.embeddings_model.update({model.name: EmbeddingsModel(model.bi_encoder)}) + state.embeddings_model.update( + { + model.name: EmbeddingsModel( + model.bi_encoder, + model.embeddings_inference_endpoint, + model.embeddings_inference_endpoint_api_key, + ) + } + ) state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)}) state.SearchType = configure_search_types() diff --git a/src/khoj/database/migrations/0025_searchmodelconfig_embeddings_inference_endpoint_and_more.py b/src/khoj/database/migrations/0025_searchmodelconfig_embeddings_inference_endpoint_and_more.py new file mode 100644 index 00000000..ef79e223 --- /dev/null +++ b/src/khoj/database/migrations/0025_searchmodelconfig_embeddings_inference_endpoint_and_more.py @@ -0,0 +1,22 @@ +# Generated by Django 4.2.7 on 2024-01-15 18:12 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0024_alter_entry_embeddings"), + ] + + operations = [ + migrations.AddField( + model_name="searchmodelconfig", + name="embeddings_inference_endpoint", + field=models.CharField(blank=True, default=None, max_length=200, null=True), + ), + migrations.AddField( + model_name="searchmodelconfig", + name="embeddings_inference_endpoint_api_key", + field=models.CharField(blank=True, default=None, max_length=200, null=True), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 873e9628..2b8887f7 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -110,6 +110,8 @@ class SearchModelConfig(BaseModel): model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") + embeddings_inference_endpoint = models.CharField(max_length=200, default=None, 
null=True, blank=True) + embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True) class TextToImageModelConfig(BaseModel): diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 23a77bb2..3e0f5380 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -123,9 +123,9 @@ def filter_questions(questions: List[str]): def converse_offline( - references, - online_results, user_query, + references=[], + online_results=[], conversation_log={}, model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", loaded_model: Union[Any, None] = None, diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py index 3a1862f7..9a2223c6 100644 --- a/src/khoj/processor/conversation/offline/utils.py +++ b/src/khoj/processor/conversation/offline/utils.py @@ -21,9 +21,11 @@ def download_model(model_name: str): # Try load chat model to GPU if: # 1. Loading chat model to GPU isn't disabled via CLI and # 2. Machine has GPU - # 3. GPU has enough free memory to load the chat model + # 3. 
GPU has enough free memory to load the chat model with max context length of 4096 device = ( - "gpu" if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"]) else "cpu" + "gpu" + if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096) + else "cpu" ) except ValueError: device = "cpu" @@ -35,7 +37,7 @@ def download_model(model_name: str): raise e # Now load the downloaded chat model onto appropriate device - chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False) + chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False) logger.debug(f"Loaded chat model to {device.upper()}.") return chat_model diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index 4cb01823..c0e91ce4 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -1,23 +1,69 @@ +import logging from typing import List +import requests +import tqdm from sentence_transformers import CrossEncoder, SentenceTransformer from torch import nn from khoj.utils.helpers import get_device from khoj.utils.rawconfig import SearchResponse +logger = logging.getLogger(__name__) + class EmbeddingsModel: - def __init__(self, model_name: str = "thenlper/gte-small"): + def __init__( + self, + model_name: str = "thenlper/gte-small", + embeddings_inference_endpoint: str = None, + embeddings_inference_endpoint_api_key: str = None, + ): self.encode_kwargs = {"normalize_embeddings": True} self.model_kwargs = {"device": get_device()} self.model_name = model_name + self.inference_endpoint = embeddings_inference_endpoint + self.api_key = embeddings_inference_endpoint_api_key self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) def embed_query(self, query): + if self.api_key is not None and self.inference_endpoint is not None: + target_url = f"{self.inference_endpoint}" + payload = {"inputs": [query]} + 
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} + response = requests.post(target_url, json=payload, headers=headers) + return response.json()["embeddings"][0] return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] def embed_documents(self, docs): + if self.api_key is not None and self.inference_endpoint is not None: + target_url = f"{self.inference_endpoint}" + if "huggingface" not in target_url: + logger.warning( + f"Using custom inference endpoint {target_url} is not yet supported. Please use a HuggingFace inference endpoint." + ) + return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() + # break up the docs payload in chunks of 1000 to avoid hitting rate limits + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + with tqdm.tqdm(total=len(docs)) as pbar: + for i in range(0, len(docs), 1000): + payload = {"inputs": docs[i : i + 1000]} + response = requests.post(target_url, json=payload, headers=headers) + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + print(f"Error: {e}") + print(f"Response: {response.json()}") + raise e + if i == 0: + embeddings = response.json()["embeddings"] + else: + embeddings += response.json()["embeddings"] + pbar.update(1000) + return embeddings return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index fb125f1d..a6748d4f 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -6,6 +6,7 @@ import os import time import uuid from typing import Any, Dict, List, Optional, Union +from urllib.parse import unquote from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile @@ -362,6 +363,7 @@ async def chat(
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), ) -> Response: user: KhojUser = request.user.object + q = unquote(q) await is_ready_to_chat(user) conversation_command = get_conversation_command(query=q, any_references=True) diff --git a/tests/test_gpt4all_chat_director.py b/tests/test_gpt4all_chat_director.py index 28bc3a8f..0173ff7b 100644 --- a/tests/test_gpt4all_chat_director.py +++ b/tests/test_gpt4all_chat_director.py @@ -1,3 +1,4 @@ +import os import urllib.parse from urllib.parse import quote @@ -53,6 +54,7 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c # ---------------------------------------------------------------------------------------------------- +@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY") @pytest.mark.chatquality @pytest.mark.django_db(transaction=True) def test_chat_with_online_content(chat_client):