diff --git a/.github/workflows/dockerize_dev.yml b/.github/workflows/dockerize_dev.yml
new file mode 100644
index 00000000..1d037ce7
--- /dev/null
+++ b/.github/workflows/dockerize_dev.yml
@@ -0,0 +1,43 @@
+name: dockerize-dev
+
+on:
+ pull_request:
+ paths:
+ - src/khoj/**
+ - config/**
+ - pyproject.toml
+ - prod.Dockerfile
+ - .github/workflows/dockerize_dev.yml
+ workflow_dispatch:
+
+env:
+ DOCKER_IMAGE_TAG: 'dev'
+
+jobs:
+ build:
+ name: Build Production Docker Image, Push to Container Registry
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v3
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v2
+
+ - name: Login to GitHub Container Registry
+ uses: docker/login-action@v2
+ with:
+ registry: ghcr.io
+ username: ${{ github.repository_owner }}
+ password: ${{ secrets.PAT }}
+
+ - name: 📦 Build and Push Docker Image
+ uses: docker/build-push-action@v2
+ with:
+ context: .
+ file: prod.Dockerfile
+ platforms: linux/amd64
+ push: true
+ tags: ghcr.io/${{ github.repository }}:${{ env.DOCKER_IMAGE_TAG }}
+ build-args: |
+ PORT=42110
diff --git a/pyproject.toml b/pyproject.toml
index 10e44ac0..15a4c8e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,7 +54,7 @@ dependencies = [
"transformers >= 4.28.0",
"torch == 2.0.1",
"uvicorn == 0.17.6",
- "aiohttp == 3.8.5",
+ "aiohttp == 3.8.6",
"langchain >= 0.0.331",
"requests >= 2.26.0",
"bs4 >= 0.0.1",
diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py
index 4b9b54ef..951bf632 100644
--- a/src/database/adapters/__init__.py
+++ b/src/database/adapters/__init__.py
@@ -1,8 +1,8 @@
import math
-from typing import Optional, Type, TypeVar, List
-from datetime import date, datetime, timedelta
+from typing import Optional, Type, List
+from datetime import date, datetime
import secrets
-from typing import Type, TypeVar, List
+from typing import Type, List
from datetime import date, timezone
from django.db import models
@@ -11,10 +11,6 @@ from pgvector.django import CosineDistance
from django.db.models.manager import BaseManager
from django.db.models import Q
from torch import Tensor
-from pgvector.django import CosineDistance
-from django.db.models.manager import BaseManager
-from django.db.models import Q
-from torch import Tensor
# Import sync_to_async from Django Channels
from asgiref.sync import sync_to_async
@@ -31,6 +27,7 @@ from database.models import (
GithubRepoConfig,
Conversation,
ChatModelOptions,
+ SearchModelConfig,
Subscription,
UserConversationConfig,
OpenAIProcessorConversationConfig,
@@ -41,15 +38,6 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter
-ModelType = TypeVar("ModelType", bound=models.Model)
-
-
-async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType:
- instance = await model_class.objects.filter(id=id).afirst()
- if not instance:
- raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found")
- return instance
-
async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst()
@@ -65,9 +53,7 @@ async def create_khoj_token(user: KhojUser, name=None):
"Create Khoj API key for user"
token = f"kk-{secrets.token_urlsafe(32)}"
name = name or f"{generate_random_name().title()}"
- api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name)
- await api_config.asave()
- return api_config
+ return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
def get_khoj_tokens(user: KhojUser):
@@ -83,13 +69,16 @@ async def delete_khoj_token(user: KhojUser, token: str):
async def get_or_create_user(token: dict) -> KhojUser:
user = await get_user_by_token(token)
if not user:
- user = await create_google_user(token)
+ user = await create_user_by_google_token(token)
return user
-async def create_google_user(token: dict) -> KhojUser:
- user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
+async def create_user_by_google_token(token: dict) -> KhojUser:
+ user, _ = await KhojUser.objects.filter(email=token.get("email")).aupdate_or_create(
+ defaults={"username": token.get("email"), "email": token.get("email")}
+ )
await user.asave()
+
await GoogleUser.objects.acreate(
sub=token.get("sub"),
azp=token.get("azp"),
@@ -220,6 +209,14 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config
+def get_or_create_search_model():
+ search_model = SearchModelConfig.objects.filter().first()
+ if not search_model:
+ search_model = SearchModelConfig.objects.create()
+
+ return search_model
+
+
class ConversationAdapters:
@staticmethod
def get_conversation_by_user(user: KhojUser):
diff --git a/src/database/admin.py b/src/database/admin.py
index 03c2ca42..8d2130ba 100644
--- a/src/database/admin.py
+++ b/src/database/admin.py
@@ -8,6 +8,7 @@ from database.models import (
ChatModelOptions,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
+ SearchModelConfig,
Subscription,
)
@@ -16,4 +17,5 @@ admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions)
admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig)
+admin.site.register(SearchModelConfig)
admin.site.register(Subscription)
diff --git a/src/database/migrations/0017_searchmodel.py b/src/database/migrations/0017_searchmodel.py
new file mode 100644
index 00000000..f150e12b
--- /dev/null
+++ b/src/database/migrations/0017_searchmodel.py
@@ -0,0 +1,32 @@
+# Generated by Django 4.2.5 on 2023-11-14 23:25
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0016_alter_subscription_renewal_date"),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name="SearchModel",
+ fields=[
+ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+ ("created_at", models.DateTimeField(auto_now_add=True)),
+ ("updated_at", models.DateTimeField(auto_now=True)),
+ ("name", models.CharField(default="default", max_length=200)),
+ ("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
+ ("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
+ (
+ "cross_encoder",
+ models.CharField(
+ blank=True, default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200, null=True
+ ),
+ ),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
+ ]
diff --git a/src/database/migrations/0018_searchmodelconfig_delete_searchmodel.py b/src/database/migrations/0018_searchmodelconfig_delete_searchmodel.py
new file mode 100644
index 00000000..a8100370
--- /dev/null
+++ b/src/database/migrations/0018_searchmodelconfig_delete_searchmodel.py
@@ -0,0 +1,30 @@
+# Generated by Django 4.2.5 on 2023-11-16 01:13
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0017_searchmodel"),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name="SearchModelConfig",
+ fields=[
+ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+ ("created_at", models.DateTimeField(auto_now_add=True)),
+ ("updated_at", models.DateTimeField(auto_now=True)),
+ ("name", models.CharField(default="default", max_length=200)),
+ ("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
+ ("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
+ ("cross_encoder", models.CharField(default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200)),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
+ migrations.DeleteModel(
+ name="SearchModel",
+ ),
+ ]
diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py
index 437d86ed..92848e5c 100644
--- a/src/database/models/__init__.py
+++ b/src/database/models/__init__.py
@@ -102,6 +102,16 @@ class LocalPlaintextConfig(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+class SearchModelConfig(BaseModel):
+ class ModelType(models.TextChoices):
+ TEXT = "text"
+
+ name = models.CharField(max_length=200, default="default")
+ model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
+ bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
+ cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
+
+
class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200)
diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index a089939d..ebf93195 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -328,7 +328,15 @@
.then(data => {
if (data.detail) {
// If the server returns a 500 error with detail, render a setup hint.
- renderMessage("Hi 👋🏾, to get started you have two options:
- Use OpenAI:
- Get your OpenAI API key
- Save it in the Khoj chat settings
- Click Configure on the Khoj settings page
- Enable offline chat:
- Go to the Khoj settings page and enable offline chat
", "khoj");
+ first_run_message = `Hi 👋🏾, to get started:
+
+ - Generate an API token in the Khoj Web settings
+ - Paste it into the API Key field in the Khoj Desktop settings
+
`
+ .trim()
+ .replace(/(\r\n|\n|\r)/gm, "");
+
+ renderMessage(first_run_message, "khoj");
// Disable chat input field and update placeholder text
document.getElementById("chat-input").setAttribute("disabled", "disabled");
diff --git a/src/interface/desktop/main.js b/src/interface/desktop/main.js
index 9a42bc5f..eb355a5f 100644
--- a/src/interface/desktop/main.js
+++ b/src/interface/desktop/main.js
@@ -396,6 +396,14 @@ app.whenReady().then(() => {
event.reply('update-state', arg);
});
+ ipcMain.on('navigate', (event, page) => {
+ win.loadFile(page);
+ });
+
+ ipcMain.on('navigateToWebApp', (event, page) => {
+ shell.openExternal(`${store.get('hostURL')}/${page}`);
+ });
+
ipcMain.handle('getFiles', getFiles);
ipcMain.handle('getFolders', getFolders);
diff --git a/src/interface/desktop/package.json b/src/interface/desktop/package.json
index d74e831a..7ee6c7b0 100644
--- a/src/interface/desktop/package.json
+++ b/src/interface/desktop/package.json
@@ -10,14 +10,14 @@
"main": "main.js",
"private": false,
"devDependencies": {
- "electron": "25.8.1"
+ "electron": "25.8.4"
},
"scripts": {
"start": "yarn electron ."
},
"dependencies": {
"@todesktop/runtime": "^1.3.0",
- "axios": "^1.5.0",
+ "axios": "^1.6.0",
"cron": "^2.4.3",
"electron-store": "^8.1.0",
"fs": "^0.0.1-security"
diff --git a/src/interface/desktop/preload.js b/src/interface/desktop/preload.js
index eb5a6cc2..1d4c6ec0 100644
--- a/src/interface/desktop/preload.js
+++ b/src/interface/desktop/preload.js
@@ -57,3 +57,8 @@ contextBridge.exposeInMainWorld('tokenAPI', {
contextBridge.exposeInMainWorld('appInfoAPI', {
getInfo: (callback) => ipcRenderer.on('appInfo', callback)
})
+
+contextBridge.exposeInMainWorld('navigateAPI', {
+ navigateToSettings: () => ipcRenderer.send('navigate', 'config.html'),
+ navigateToWebSettings: () => ipcRenderer.send('navigateToWebApp', 'config'),
+})
diff --git a/src/interface/desktop/yarn.lock b/src/interface/desktop/yarn.lock
index 8591b00d..57583e13 100644
--- a/src/interface/desktop/yarn.lock
+++ b/src/interface/desktop/yarn.lock
@@ -163,10 +163,10 @@ atomically@^1.7.0:
resolved "https://registry.yarnpkg.com/atomically/-/atomically-1.7.0.tgz#c07a0458432ea6dbc9a3506fffa424b48bccaafe"
integrity sha512-Xcz9l0z7y9yQ9rdDaxlmaI4uJHf/T8g9hOEzJcsEqX2SjCj4J20uK7+ldkDHMbpJDK76wF7xEIgxc/vSlsfw5w==
-axios@^1.5.0:
- version "1.5.0"
- resolved "https://registry.yarnpkg.com/axios/-/axios-1.5.0.tgz#f02e4af823e2e46a9768cfc74691fdd0517ea267"
- integrity sha512-D4DdjDo5CY50Qms0qGQTTw6Q44jl7zRwY7bthds06pUGfChBCTcQs+N743eFWGEd6pRTMd6A+I87aWyFV5wiZQ==
+axios@^1.6.0:
+ version "1.6.2"
+ resolved "https://registry.yarnpkg.com/axios/-/axios-1.6.2.tgz#de67d42c755b571d3e698df1b6504cde9b0ee9f2"
+ integrity sha512-7i24Ri4pmDRfJTR7LDBhsOTtcm+9kjX5WiY1X3wIisx6G9So3pfMkEiU7emUBe46oceVImccTEM3k6C5dbVW8A==
dependencies:
follow-redirects "^1.15.0"
form-data "^4.0.0"
@@ -379,10 +379,10 @@ electron-updater@^4.6.1:
lodash.isequal "^4.5.0"
semver "^7.3.5"
-electron@25.8.1:
- version "25.8.1"
- resolved "https://registry.yarnpkg.com/electron/-/electron-25.8.1.tgz#092fab5a833db4d9240d4d6f36218cf7ca954f86"
- integrity sha512-GtcP1nMrROZfFg0+mhyj1hamrHvukfF6of2B/pcWxmWkd5FVY1NJib0tlhiorFZRzQN5Z+APLPr7aMolt7i2AQ==
+electron@25.8.4:
+ version "25.8.4"
+ resolved "https://registry.yarnpkg.com/electron/-/electron-25.8.4.tgz#b50877aac7d96323920437baf309ad86382cb455"
+ integrity sha512-hUYS3RGdaa6E1UWnzeGnsdsBYOggwMMg4WGxNGvAoWtmRrr6J1BsjFW/yRq4WsJHJce2HdzQXtz4OGXV6yUCLg==
dependencies:
"@electron/get" "^2.0.0"
"@types/node" "^18.11.18"
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 9fb1f019..5ed92727 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -3,7 +3,6 @@ import logging
import json
from enum import Enum
from typing import Optional
-from fastapi import Request
import requests
import os
@@ -21,15 +20,16 @@ from starlette.authentication import (
)
# Internal Packages
+from database.models import KhojUser, Subscription
+from database.adapters import get_all_users, get_or_create_search_model
+from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
+from khoj.routers.indexer import configure_content, load_content, configure_search
from khoj.utils import constants, state
from khoj.utils.config import (
SearchType,
)
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig
-from khoj.routers.indexer import configure_content, load_content, configure_search
-from database.models import KhojUser, Subscription
-from database.adapters import get_all_users
logger = logging.getLogger(__name__)
@@ -113,14 +113,13 @@ def configure_server(
# Initialize Search Models from Config and initialize content
try:
- state.config_lock.acquire()
- state.SearchType = configure_search_types(state.config)
+ state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
+ state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
+ state.SearchType = configure_search_types()
state.search_models = configure_search(state.search_models, state.config.search_type)
initialize_content(regenerate, search_type, init, user)
except Exception as e:
raise e
- finally:
- state.config_lock.release()
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
@@ -192,7 +191,7 @@ def update_search_index():
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
-def configure_search_types(config: FullConfig):
+def configure_search_types():
# Extract core search types
core_search_types = {e.name: e.value for e in SearchType}
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index bd4870e4..82e3233d 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -327,7 +327,7 @@ To get started, just start typing below. You can also type / to see a list of co
.then(data => {
if (data.detail) {
// If the server returns a 500 error with detail, render a setup hint.
- renderMessage("Hi 👋🏾, to get started you have two options:- Use OpenAI:
- Get your OpenAI API key
- Save it in the Khoj chat settings
- Click Configure on the Khoj settings page
- Enable offline chat:
- Go to the Khoj settings page and enable offline chat
", "khoj");
+ renderMessage("Hi 👋🏾, to start chatting add available chat models options via the Django Admin panel on the Server", "khoj");
// Disable chat input field and update placeholder text
document.getElementById("chat-input").setAttribute("disabled", "disabled");
diff --git a/src/khoj/migrations/migrate_server_pg.py b/src/khoj/migrations/migrate_server_pg.py
index 9a34e379..434e27d7 100644
--- a/src/khoj/migrations/migrate_server_pg.py
+++ b/src/khoj/migrations/migrate_server_pg.py
@@ -30,7 +30,7 @@ search-type:
encoder: sentence-transformers/all-MiniLM-L6-v2
encoder-type: null
model-directory: ~/.khoj/search/symmetric
-version: 0.12.4
+version: 0.14.0
The new version will looks like this:
@@ -53,11 +53,7 @@ search-type:
asymmetric:
cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1
- image:
- encoder: sentence-transformers/clip-ViT-B-32
- encoder-type: null
- model-directory: /Users/si/.khoj/search/image
-version: 0.12.4
+version: 0.15.0
"""
import logging
@@ -68,6 +64,7 @@ from database.models import (
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
ChatModelOptions,
+ SearchModelConfig,
)
logger = logging.getLogger(__name__)
@@ -87,6 +84,19 @@ def migrate_server_pg(args):
if raw_config is None:
return args
+ if "search-type" in raw_config and raw_config["search-type"]:
+ if "asymmetric" in raw_config["search-type"]:
+ # Delete all existing search models
+ SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete()
+ # Create new search model from existing Khoj YAML config
+ asymmetric_search = raw_config["search-type"]["asymmetric"]
+ SearchModelConfig.objects.create(
+ name="default",
+ model_type=SearchModelConfig.ModelType.TEXT,
+ bi_encoder=asymmetric_search.get("encoder"),
+ cross_encoder=asymmetric_search.get("cross-encoder"),
+ )
+
if "processor" in raw_config and raw_config["processor"] and "conversation" in raw_config["processor"]:
processor_conversation = raw_config["processor"]["conversation"]
diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
index a4daa24f..392d402f 100644
--- a/src/khoj/processor/embeddings.py
+++ b/src/khoj/processor/embeddings.py
@@ -1,16 +1,17 @@
from typing import List
from sentence_transformers import SentenceTransformer, CrossEncoder
+from torch import nn
from khoj.utils.helpers import get_device
from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel:
- def __init__(self):
+ def __init__(self, model_name: str = "thenlper/gte-small"):
self.encode_kwargs = {"normalize_embeddings": True}
self.model_kwargs = {"device": get_device()}
- self.model_name = "thenlper/gte-small"
+ self.model_name = model_name
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def embed_query(self, query):
@@ -21,11 +22,11 @@ class EmbeddingsModel:
class CrossEncoderModel:
- def __init__(self):
- self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+ def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
+ self.model_name = model_name
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
- def predict(self, query, hits: List[SearchResponse]):
- cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
- cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
+ def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
+ cross_inp = [[query, hit.additional[key]] for hit in hits]
+ cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid())
return cross_scores
diff --git a/src/khoj/processor/text_to_entries.py b/src/khoj/processor/text_to_entries.py
index 3d79e02e..ac42105a 100644
--- a/src/khoj/processor/text_to_entries.py
+++ b/src/khoj/processor/text_to_entries.py
@@ -6,12 +6,12 @@ import logging
import uuid
from tqdm import tqdm
from typing import Callable, List, Tuple, Set, Any
+from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, timer, batcher
# Internal Packages
from khoj.utils.rawconfig import Entry
-from khoj.processor.embeddings import EmbeddingsModel
from khoj.search_filter.date_filter import DateFilter
from database.models import KhojUser, Entry as DbEntry, EntryDates
from database.adapters import EntryAdapters
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class TextToEntries(ABC):
def __init__(self, config: Any = None):
- self.embeddings_model = EmbeddingsModel()
+ self.embeddings_model = state.embeddings_model
self.config = config
self.date_filter = DateFilter()
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 190fc260..66d7ea53 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -376,7 +376,7 @@ async def search(
# initialize variables
user_query = q.strip()
results_count = n or 5
- max_distance = max_distance if max_distance is not None else math.inf
+ max_distance = max_distance or math.inf
search_futures: List[concurrent.futures.Future] = []
# return cached results, if available
@@ -581,7 +581,7 @@ async def chat(
request: Request,
q: str,
n: Optional[int] = 5,
- d: Optional[float] = 0.15,
+ d: Optional[float] = 0.18,
client: Optional[str] = None,
stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None),
diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py
index 2c013bc8..a9a88325 100644
--- a/src/khoj/routers/auth.py
+++ b/src/khoj/routers/auth.py
@@ -16,6 +16,7 @@ from google.auth.transport import requests as google_requests
# Internal Packages
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
+from database.models import KhojApiUser
from khoj.routers.helpers import update_telemetry_state
from khoj.utils import state
@@ -51,12 +52,16 @@ async def login(request: Request):
@auth_router.post("/token")
@requires(["authenticated"], redirect="login_page")
-async def generate_token(request: Request, token_name: Optional[str] = None) -> str:
+async def generate_token(request: Request, token_name: Optional[str] = None):
"Generate API token for given user"
if token_name:
- return await create_khoj_token(user=request.user.object, name=token_name)
+ token = await create_khoj_token(user=request.user.object, name=token_name)
else:
- return await create_khoj_token(user=request.user.object)
+ token = await create_khoj_token(user=request.user.object)
+ return {
+ "token": token.token,
+ "name": token.name,
+ }
@auth_router.get("/token")
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index 2b99ed66..f07eb580 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -2,7 +2,7 @@
import logging
import math
from pathlib import Path
-from typing import List, Tuple, Type, Union, Dict
+from typing import List, Tuple, Type, Union
# External Packages
import torch
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index 098ae35e..91f5f0ce 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -5,21 +5,20 @@ from typing import List, Dict
from collections import defaultdict
# External Packages
-import torch
from pathlib import Path
+from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
# Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU, get_device
from khoj.utils.rawconfig import FullConfig
-from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
# Application Global State
config = FullConfig()
search_models = SearchModels()
-embeddings_model = EmbeddingsModel()
-cross_encoder_model = CrossEncoderModel()
+embeddings_model: EmbeddingsModel = None
+cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex()
gpt4all_processor_config: GPT4AllProcessorModel = None
config_file: Path = None
@@ -28,7 +27,6 @@ host: str = None
port: int = None
cli_args: List[str] = None
query_cache: Dict[str, LRU] = defaultdict(LRU)
-config_lock = threading.Lock()
chat_lock = threading.Lock()
SearchType = utils_config.SearchType
telemetry: List[Dict[str, str]] = []
diff --git a/tests/conftest.py b/tests/conftest.py
index 95fa9a99..d90bae95 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,11 +8,13 @@ from fastapi import FastAPI
import os
from fastapi import FastAPI
+
app = FastAPI()
# Internal Packages
from khoj.configure import configure_routes, configure_search_types, configure_middleware
+from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels
@@ -54,6 +56,9 @@ def enable_db_access_for_all_tests(db):
@pytest.fixture(scope="session")
def search_config() -> SearchConfig:
+ state.embeddings_model = EmbeddingsModel()
+ state.cross_encoder_model = CrossEncoderModel()
+
model_dir = resolve_absolute_path("~/.khoj/search")
model_dir.mkdir(parents=True, exist_ok=True)
search_config = SearchConfig()
@@ -222,7 +227,7 @@ def md_content_config():
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
- state.SearchType = configure_search_types(state.config)
+ state.SearchType = configure_search_types()
LocalMarkdownConfig.objects.create(
input_files=None,
@@ -256,7 +261,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
- state.SearchType = configure_search_types(state.config)
+ state.SearchType = configure_search_types()
# Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"):
@@ -291,7 +296,9 @@ def client(
):
state.config.content_type = content_config
state.config.search_type = search_config
- state.SearchType = configure_search_types(state.config)
+ state.SearchType = configure_search_types()
+ state.embeddings_model = EmbeddingsModel()
+ state.cross_encoder_model = CrossEncoderModel()
# These lines help us Mock the Search models for these search types
state.search_models.image_search = image_search.initialize_model(search_config.image)
@@ -323,7 +330,7 @@ def client(
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
- state.SearchType = configure_search_types(state.config)
+ state.SearchType = configure_search_types()
LocalMarkdownConfig.objects.create(
input_files=None,
diff --git a/tests/helpers.py b/tests/helpers.py
index 03f3f9c7..079eb475 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -7,6 +7,7 @@ from database.models import (
ChatModelOptions,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
+ SearchModelConfig,
UserConversationConfig,
Conversation,
Subscription,
@@ -71,6 +72,16 @@ class ConversationFactory(factory.django.DjangoModelFactory):
user = factory.SubFactory(UserFactory)
+class SearchModelFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = SearchModelConfig
+
+ name = "default"
+ model_type = "text"
+ bi_encoder = "thenlper/gte-small"
+ cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+
+
class SubscriptionFactory(factory.django.DjangoModelFactory):
class Meta:
model = Subscription
diff --git a/tests/test_client.py b/tests/test_client.py
index 1894577c..f642a727 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -173,7 +173,6 @@ def test_regenerate_with_github_fails_without_pat(client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
-@pytest.mark.skip(reason="Flaky test on parallel test runs")
def test_get_configured_types_via_api(client, sample_org_data):
# Act
text_search.setup(OrgToEntries, sample_org_data, regenerate=False)
@@ -203,10 +202,10 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
@pytest.mark.django_db(transaction=True)
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
# Arrange
- state.SearchType = configure_search_types(config)
- original_config = state.config.content_type
- state.config.content_type = None
state.anonymous_mode = True
+ if state.config and state.config.content_type:
+ state.config.content_type = None
+ state.search_models = configure_search_types()
configure_routes(fastapi_app)
client = TestClient(fastapi_app)
@@ -218,9 +217,6 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
assert response.status_code == 200
assert response.json() == ["all"]
- # Restore
- state.config.content_type = original_config
-
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@@ -259,13 +255,30 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
user_query = quote("How to git install application?")
# Act
- response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true", headers=headers)
+ response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
# Assert
assert response.status_code == 200
- # assert actual_data contains "Khoj via Emacs" entry
+
+ assert len(response.json()) == 1, "Expected only 1 result"
search_result = response.json()[0]["entry"]
- assert "git clone https://github.com/khoj-ai/khoj" in search_result
+ assert "git clone https://github.com/khoj-ai/khoj" in search_result, "Expected 'git clone' in search result"
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
+def test_notes_search_no_results(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
+ # Arrange
+ headers = {"Authorization": "Bearer kk-secret"}
+ text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
+ user_query = quote("How to find my goat?")
+
+ # Act
+ response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
+
+ # Assert
+ assert response.status_code == 200
+ assert response.json() == [], "Expected no results"
# ----------------------------------------------------------------------------------------------------