diff --git a/.github/workflows/dockerize_dev.yml b/.github/workflows/dockerize_dev.yml new file mode 100644 index 00000000..1d037ce7 --- /dev/null +++ b/.github/workflows/dockerize_dev.yml @@ -0,0 +1,43 @@ +name: dockerize-dev + +on: + pull_request: + paths: + - src/khoj/** + - config/** + - pyproject.toml + - prod.Dockerfile + - .github/workflows/dockerize_dev.yml + workflow_dispatch: + +env: + DOCKER_IMAGE_TAG: 'dev' + +jobs: + build: + name: Build Production Docker Image, Push to Container Registry + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.PAT }} + + - name: 📦 Build and Push Docker Image + uses: docker/build-push-action@v2 + with: + context: . + file: prod.Dockerfile + platforms: linux/amd64 + push: true + tags: ghcr.io/${{ github.repository }}:${{ env.DOCKER_IMAGE_TAG }} + build-args: | + PORT=42110 diff --git a/pyproject.toml b/pyproject.toml index 10e44ac0..15a4c8e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ dependencies = [ "transformers >= 4.28.0", "torch == 2.0.1", "uvicorn == 0.17.6", - "aiohttp == 3.8.5", + "aiohttp == 3.8.6", "langchain >= 0.0.331", "requests >= 2.26.0", "bs4 >= 0.0.1", diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 4b9b54ef..951bf632 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -1,8 +1,8 @@ import math -from typing import Optional, Type, TypeVar, List -from datetime import date, datetime, timedelta +from typing import Optional, Type, List +from datetime import date, datetime import secrets -from typing import Type, TypeVar, List +from typing import Type, List from datetime import date, timezone from django.db import models @@ -11,10 +11,6 @@ from pgvector.django import CosineDistance from django.db.models.manager import BaseManager from django.db.models import Q from torch import Tensor -from pgvector.django import CosineDistance -from django.db.models.manager import BaseManager -from django.db.models import Q -from torch import Tensor # Import sync_to_async from Django Channels from asgiref.sync import sync_to_async @@ -31,6 +27,7 @@ from database.models import ( GithubRepoConfig, Conversation, ChatModelOptions, + SearchModelConfig, Subscription, UserConversationConfig, OpenAIProcessorConversationConfig, @@ -41,15 +38,6 @@ from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.date_filter import DateFilter -ModelType = TypeVar("ModelType", bound=models.Model) - - -async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType: - instance = await model_class.objects.filter(id=id).afirst() - if not instance: - raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found") - return instance - async def set_notion_config(token: str, user: KhojUser): notion_config = await NotionConfig.objects.filter(user=user).afirst() @@ -65,9 +53,7 @@ async def create_khoj_token(user: KhojUser, name=None): "Create Khoj API key for user" token = f"kk-{secrets.token_urlsafe(32)}" name = name or f"{generate_random_name().title()}" - api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name) - await api_config.asave() - return api_config + return await KhojApiUser.objects.acreate(token=token, user=user, name=name) def get_khoj_tokens(user: KhojUser): @@ -83,13 +69,16 @@ async def delete_khoj_token(user: KhojUser, token: str): async def get_or_create_user(token: dict) -> KhojUser: user = await get_user_by_token(token) if not user: - user = await create_google_user(token) + user = await create_user_by_google_token(token) return user -async def create_google_user(token: dict) -> KhojUser: - user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email")) +async def create_user_by_google_token(token: dict) -> KhojUser: + user, _ = await KhojUser.objects.filter(email=token.get("email")).aupdate_or_create( + defaults={"username": token.get("email"), "email": token.get("email")} + ) await user.asave() + await GoogleUser.objects.acreate( sub=token.get("sub"), azp=token.get("azp"), @@ -220,6 +209,14 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): return config +def get_or_create_search_model(): + search_model = SearchModelConfig.objects.filter().first() + if not search_model: + search_model = SearchModelConfig.objects.create() + + return search_model + + class ConversationAdapters: @staticmethod def get_conversation_by_user(user: KhojUser): diff --git a/src/database/admin.py b/src/database/admin.py index 03c2ca42..8d2130ba 100644 --- a/src/database/admin.py +++ b/src/database/admin.py @@ -8,6 +8,7 @@ from database.models import ( ChatModelOptions, OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, + SearchModelConfig, Subscription, ) @@ -16,4 +17,5 @@ admin.site.register(KhojUser, UserAdmin) admin.site.register(ChatModelOptions) admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig) +admin.site.register(SearchModelConfig) admin.site.register(Subscription) diff --git a/src/database/migrations/0017_searchmodel.py b/src/database/migrations/0017_searchmodel.py new file mode 100644 index 00000000..f150e12b --- /dev/null +++ b/src/database/migrations/0017_searchmodel.py @@ -0,0 +1,32 @@ +# Generated by Django 4.2.5 on 2023-11-14 23:25 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0016_alter_subscription_renewal_date"), + ] + + operations = [ + migrations.CreateModel( + name="SearchModel", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("name", models.CharField(default="default", max_length=200)), + ("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)), + ("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)), + ( + "cross_encoder", + models.CharField( + blank=True, default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200, null=True + ), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/database/migrations/0018_searchmodelconfig_delete_searchmodel.py b/src/database/migrations/0018_searchmodelconfig_delete_searchmodel.py new file mode 100644 index 00000000..a8100370 --- /dev/null +++ b/src/database/migrations/0018_searchmodelconfig_delete_searchmodel.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.5 on 2023-11-16 01:13 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0017_searchmodel"), + ] + + operations = [ + migrations.CreateModel( + name="SearchModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("name", models.CharField(default="default", max_length=200)), + ("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)), + ("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)), + ("cross_encoder", models.CharField(default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200)), + ], + options={ + "abstract": False, + }, + ), + migrations.DeleteModel( + name="SearchModel", + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 437d86ed..92848e5c 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -102,6 +102,16 @@ class LocalPlaintextConfig(BaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) +class SearchModelConfig(BaseModel): + class ModelType(models.TextChoices): + TEXT = "text" + + name = models.CharField(max_length=200, default="default") + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) + bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") + cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") + + class OpenAIProcessorConversationConfig(BaseModel): api_key = models.CharField(max_length=200) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index a089939d..ebf93195 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -328,7 +328,15 @@ .then(data => { if (data.detail) { // If the server returns a 500 error with detail, render a setup hint. - renderMessage("Hi 👋🏾, to get started you have two options:
  1. Use OpenAI:
    1. Get your OpenAI API key
    2. Save it in the Khoj chat settings
    3. Click Configure on the Khoj settings page
  2. Enable offline chat:
    1. Go to the Khoj settings page and enable offline chat
", "khoj"); + first_run_message = `Hi 👋🏾, to get started: +
    +
  1. Generate an API token in the Khoj Web settings
  2. +
  3. Paste it into the API Key field in the Khoj Desktop settings
  4. +
` + .trim() + .replace(/(\r\n|\n|\r)/gm, ""); + + renderMessage(first_run_message, "khoj"); // Disable chat input field and update placeholder text document.getElementById("chat-input").setAttribute("disabled", "disabled"); diff --git a/src/interface/desktop/main.js b/src/interface/desktop/main.js index 9a42bc5f..eb355a5f 100644 --- a/src/interface/desktop/main.js +++ b/src/interface/desktop/main.js @@ -396,6 +396,14 @@ app.whenReady().then(() => { event.reply('update-state', arg); }); + ipcMain.on('navigate', (event, page) => { + win.loadFile(page); + }); + + ipcMain.on('navigateToWebApp', (event, page) => { + shell.openExternal(`${store.get('hostURL')}/${page}`); + }); + ipcMain.handle('getFiles', getFiles); ipcMain.handle('getFolders', getFolders); diff --git a/src/interface/desktop/package.json b/src/interface/desktop/package.json index d74e831a..7ee6c7b0 100644 --- a/src/interface/desktop/package.json +++ b/src/interface/desktop/package.json @@ -10,14 +10,14 @@ "main": "main.js", "private": false, "devDependencies": { - "electron": "25.8.1" + "electron": "25.8.4" }, "scripts": { "start": "yarn electron ." }, "dependencies": { "@todesktop/runtime": "^1.3.0", - "axios": "^1.5.0", + "axios": "^1.6.0", "cron": "^2.4.3", "electron-store": "^8.1.0", "fs": "^0.0.1-security" diff --git a/src/interface/desktop/preload.js b/src/interface/desktop/preload.js index eb5a6cc2..1d4c6ec0 100644 --- a/src/interface/desktop/preload.js +++ b/src/interface/desktop/preload.js @@ -57,3 +57,8 @@ contextBridge.exposeInMainWorld('tokenAPI', { contextBridge.exposeInMainWorld('appInfoAPI', { getInfo: (callback) => ipcRenderer.on('appInfo', callback) }) + +contextBridge.exposeInMainWorld('navigateAPI', { + navigateToSettings: () => ipcRenderer.send('navigate', 'config.html'), + navigateToWebSettings: () => ipcRenderer.send('navigateToWebApp', 'config'), +}) diff --git a/src/interface/desktop/yarn.lock b/src/interface/desktop/yarn.lock index 8591b00d..57583e13 100644 --- a/src/interface/desktop/yarn.lock +++ b/src/interface/desktop/yarn.lock @@ -163,10 +163,10 @@ atomically@^1.7.0: resolved "https://registry.yarnpkg.com/atomically/-/atomically-1.7.0.tgz#c07a0458432ea6dbc9a3506fffa424b48bccaafe" integrity sha512-Xcz9l0z7y9yQ9rdDaxlmaI4uJHf/T8g9hOEzJcsEqX2SjCj4J20uK7+ldkDHMbpJDK76wF7xEIgxc/vSlsfw5w== -axios@^1.5.0: - version "1.5.0" - resolved "https://registry.yarnpkg.com/axios/-/axios-1.5.0.tgz#f02e4af823e2e46a9768cfc74691fdd0517ea267" - integrity sha512-D4DdjDo5CY50Qms0qGQTTw6Q44jl7zRwY7bthds06pUGfChBCTcQs+N743eFWGEd6pRTMd6A+I87aWyFV5wiZQ== +axios@^1.6.0: + version "1.6.2" + resolved "https://registry.yarnpkg.com/axios/-/axios-1.6.2.tgz#de67d42c755b571d3e698df1b6504cde9b0ee9f2" + integrity sha512-7i24Ri4pmDRfJTR7LDBhsOTtcm+9kjX5WiY1X3wIisx6G9So3pfMkEiU7emUBe46oceVImccTEM3k6C5dbVW8A== dependencies: follow-redirects "^1.15.0" form-data "^4.0.0" @@ -379,10 +379,10 @@ electron-updater@^4.6.1: lodash.isequal "^4.5.0" semver "^7.3.5" -electron@25.8.1: - version "25.8.1" - resolved "https://registry.yarnpkg.com/electron/-/electron-25.8.1.tgz#092fab5a833db4d9240d4d6f36218cf7ca954f86" - integrity sha512-GtcP1nMrROZfFg0+mhyj1hamrHvukfF6of2B/pcWxmWkd5FVY1NJib0tlhiorFZRzQN5Z+APLPr7aMolt7i2AQ== +electron@25.8.4: + version "25.8.4" + resolved "https://registry.yarnpkg.com/electron/-/electron-25.8.4.tgz#b50877aac7d96323920437baf309ad86382cb455" + integrity sha512-hUYS3RGdaa6E1UWnzeGnsdsBYOggwMMg4WGxNGvAoWtmRrr6J1BsjFW/yRq4WsJHJce2HdzQXtz4OGXV6yUCLg== dependencies: "@electron/get" "^2.0.0" "@types/node" "^18.11.18" diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 9fb1f019..5ed92727 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -3,7 +3,6 @@ import logging import json from enum import Enum from typing import Optional -from fastapi import Request import requests import os @@ -21,15 +20,16 @@ from starlette.authentication import ( ) # Internal Packages +from database.models import KhojUser, Subscription +from database.adapters import get_all_users, get_or_create_search_model +from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel +from khoj.routers.indexer import configure_content, load_content, configure_search from khoj.utils import constants, state from khoj.utils.config import ( SearchType, ) from khoj.utils.fs_syncer import collect_files from khoj.utils.rawconfig import FullConfig -from khoj.routers.indexer import configure_content, load_content, configure_search -from database.models import KhojUser, Subscription -from database.adapters import get_all_users logger = logging.getLogger(__name__) @@ -113,14 +113,13 @@ def configure_server( # Initialize Search Models from Config and initialize content try: - state.config_lock.acquire() - state.SearchType = configure_search_types(state.config) + state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder) + state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder) + state.SearchType = configure_search_types() state.search_models = configure_search(state.search_models, state.config.search_type) initialize_content(regenerate, search_type, init, user) except Exception as e: raise e - finally: - state.config_lock.release() def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None): @@ -192,7 +191,7 @@ def update_search_index(): logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True) -def configure_search_types(config: FullConfig): +def configure_search_types(): # Extract core search types core_search_types = {e.name: e.value for e in SearchType} diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index bd4870e4..82e3233d 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -327,7 +327,7 @@ To get started, just start typing below. You can also type / to see a list of co .then(data => { if (data.detail) { // If the server returns a 500 error with detail, render a setup hint. - renderMessage("Hi 👋🏾, to get started you have two options:
  1. Use OpenAI:
    1. Get your OpenAI API key
    2. Save it in the Khoj chat settings
    3. Click Configure on the Khoj settings page
  2. Enable offline chat:
    1. Go to the Khoj settings page and enable offline chat
", "khoj"); + renderMessage("Hi 👋🏾, to start chatting add available chat models options via the Django Admin panel on the Server", "khoj"); // Disable chat input field and update placeholder text document.getElementById("chat-input").setAttribute("disabled", "disabled"); diff --git a/src/khoj/migrations/migrate_server_pg.py b/src/khoj/migrations/migrate_server_pg.py index 9a34e379..434e27d7 100644 --- a/src/khoj/migrations/migrate_server_pg.py +++ b/src/khoj/migrations/migrate_server_pg.py @@ -30,7 +30,7 @@ search-type: encoder: sentence-transformers/all-MiniLM-L6-v2 encoder-type: null model-directory: ~/.khoj/search/symmetric -version: 0.12.4 +version: 0.14.0 The new version will looks like this: @@ -53,11 +53,7 @@ search-type: asymmetric: cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2 encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1 - image: - encoder: sentence-transformers/clip-ViT-B-32 - encoder-type: null - model-directory: /Users/si/.khoj/search/image -version: 0.12.4 +version: 0.15.0 """ import logging @@ -68,6 +64,7 @@ from database.models import ( OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, ChatModelOptions, + SearchModelConfig, ) logger = logging.getLogger(__name__) @@ -87,6 +84,19 @@ def migrate_server_pg(args): if raw_config is None: return args + if "search-type" in raw_config and raw_config["search-type"]: + if "asymmetric" in raw_config["search-type"]: + # Delete all existing search models + SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete() + # Create new search model from existing Khoj YAML config + asymmetric_search = raw_config["search-type"]["asymmetric"] + SearchModelConfig.objects.create( + name="default", + model_type=SearchModelConfig.ModelType.TEXT, + bi_encoder=asymmetric_search.get("encoder"), + cross_encoder=asymmetric_search.get("cross-encoder"), + ) + if "processor" in raw_config and raw_config["processor"] and "conversation" in raw_config["processor"]: processor_conversation = raw_config["processor"]["conversation"] diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index a4daa24f..392d402f 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -1,16 +1,17 @@ from typing import List from sentence_transformers import SentenceTransformer, CrossEncoder +from torch import nn from khoj.utils.helpers import get_device from khoj.utils.rawconfig import SearchResponse class EmbeddingsModel: - def __init__(self): + def __init__(self, model_name: str = "thenlper/gte-small"): self.encode_kwargs = {"normalize_embeddings": True} self.model_kwargs = {"device": get_device()} - self.model_name = "thenlper/gte-small" + self.model_name = model_name self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) def embed_query(self, query): @@ -21,11 +22,11 @@ class EmbeddingsModel: class CrossEncoderModel: - def __init__(self): - self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2" + def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"): + self.model_name = model_name self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device()) - def predict(self, query, hits: List[SearchResponse]): - cross__inp = [[query, hit.additional["compiled"]] for hit in hits] - cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True) + def predict(self, query, hits: List[SearchResponse], key: str = "compiled"): + cross_inp = [[query, hit.additional[key]] for hit in hits] + cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid()) return cross_scores diff --git a/src/khoj/processor/text_to_entries.py b/src/khoj/processor/text_to_entries.py index 3d79e02e..ac42105a 100644 --- a/src/khoj/processor/text_to_entries.py +++ b/src/khoj/processor/text_to_entries.py @@ -6,12 +6,12 @@ import logging import uuid from tqdm import tqdm from typing import Callable, List, Tuple, Set, Any +from khoj.utils import state from khoj.utils.helpers import is_none_or_empty, timer, batcher # Internal Packages from khoj.utils.rawconfig import Entry -from khoj.processor.embeddings import EmbeddingsModel from khoj.search_filter.date_filter import DateFilter from database.models import KhojUser, Entry as DbEntry, EntryDates from database.adapters import EntryAdapters @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class TextToEntries(ABC): def __init__(self, config: Any = None): - self.embeddings_model = EmbeddingsModel() + self.embeddings_model = state.embeddings_model self.config = config self.date_filter = DateFilter() diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 190fc260..66d7ea53 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -376,7 +376,7 @@ async def search( # initialize variables user_query = q.strip() results_count = n or 5 - max_distance = max_distance if max_distance is not None else math.inf + max_distance = max_distance or math.inf search_futures: List[concurrent.futures.Future] = [] # return cached results, if available @@ -581,7 +581,7 @@ async def chat( request: Request, q: str, n: Optional[int] = 5, - d: Optional[float] = 0.15, + d: Optional[float] = 0.18, client: Optional[str] = None, stream: Optional[bool] = False, user_agent: Optional[str] = Header(None), diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index 2c013bc8..a9a88325 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -16,6 +16,7 @@ from google.auth.transport import requests as google_requests # Internal Packages from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token +from database.models import KhojApiUser from khoj.routers.helpers import update_telemetry_state from khoj.utils import state @@ -51,12 +52,16 @@ async def login(request: Request): @auth_router.post("/token") @requires(["authenticated"], redirect="login_page") -async def generate_token(request: Request, token_name: Optional[str] = None) -> str: +async def generate_token(request: Request, token_name: Optional[str] = None): "Generate API token for given user" if token_name: - return await create_khoj_token(user=request.user.object, name=token_name) + token = await create_khoj_token(user=request.user.object, name=token_name) else: - return await create_khoj_token(user=request.user.object) + token = await create_khoj_token(user=request.user.object) + return { + "token": token.token, + "name": token.name, + } @auth_router.get("/token") diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 2b99ed66..f07eb580 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -2,7 +2,7 @@ import logging import math from pathlib import Path -from typing import List, Tuple, Type, Union, Dict +from typing import List, Tuple, Type, Union # External Packages import torch diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 098ae35e..91f5f0ce 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -5,21 +5,20 @@ from typing import List, Dict from collections import defaultdict # External Packages -import torch from pathlib import Path +from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel # Internal Packages from khoj.utils import config as utils_config from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel from khoj.utils.helpers import LRU, get_device from khoj.utils.rawconfig import FullConfig -from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel # Application Global State config = FullConfig() search_models = SearchModels() -embeddings_model = EmbeddingsModel() -cross_encoder_model = CrossEncoderModel() +embeddings_model: EmbeddingsModel = None +cross_encoder_model: CrossEncoderModel = None content_index = ContentIndex() gpt4all_processor_config: GPT4AllProcessorModel = None config_file: Path = None @@ -28,7 +27,6 @@ host: str = None port: int = None cli_args: List[str] = None query_cache: Dict[str, LRU] = defaultdict(LRU) -config_lock = threading.Lock() chat_lock = threading.Lock() SearchType = utils_config.SearchType telemetry: List[Dict[str, str]] = [] diff --git a/tests/conftest.py b/tests/conftest.py index 95fa9a99..d90bae95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,11 +8,13 @@ from fastapi import FastAPI import os from fastapi import FastAPI + app = FastAPI() # Internal Packages from khoj.configure import configure_routes, configure_search_types, configure_middleware +from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.search_type import image_search, text_search from khoj.utils.config import SearchModels @@ -54,6 +56,9 @@ def enable_db_access_for_all_tests(db): @pytest.fixture(scope="session") def search_config() -> SearchConfig: + state.embeddings_model = EmbeddingsModel() + state.cross_encoder_model = CrossEncoderModel() + model_dir = resolve_absolute_path("~/.khoj/search") model_dir.mkdir(parents=True, exist_ok=True) search_config = SearchConfig() @@ -222,7 +227,7 @@ def md_content_config(): def chat_client(search_config: SearchConfig, default_user2: KhojUser): # Initialize app state state.config.search_type = search_config - state.SearchType = configure_search_types(state.config) + state.SearchType = configure_search_types() LocalMarkdownConfig.objects.create( input_files=None, @@ -256,7 +261,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser): def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser): # Initialize app state state.config.search_type = search_config - state.SearchType = configure_search_types(state.config) + state.SearchType = configure_search_types() # Initialize Processor from Config if os.getenv("OPENAI_API_KEY"): @@ -291,7 +296,9 @@ def client( ): state.config.content_type = content_config state.config.search_type = search_config - state.SearchType = configure_search_types(state.config) + state.SearchType = configure_search_types() + state.embeddings_model = EmbeddingsModel() + state.cross_encoder_model = CrossEncoderModel() # These lines help us Mock the Search models for these search types state.search_models.image_search = image_search.initialize_model(search_config.image) @@ -323,7 +330,7 @@ def client( def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser): # Initialize app state state.config.search_type = search_config - state.SearchType = configure_search_types(state.config) + state.SearchType = configure_search_types() LocalMarkdownConfig.objects.create( input_files=None, diff --git a/tests/helpers.py b/tests/helpers.py index 03f3f9c7..079eb475 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,6 +7,7 @@ from database.models import ( ChatModelOptions, OfflineChatProcessorConversationConfig, OpenAIProcessorConversationConfig, + SearchModelConfig, UserConversationConfig, Conversation, Subscription, @@ -71,6 +72,16 @@ class ConversationFactory(factory.django.DjangoModelFactory): user = factory.SubFactory(UserFactory) +class SearchModelFactory(factory.django.DjangoModelFactory): + class Meta: + model = SearchModelConfig + + name = "default" + model_type = "text" + bi_encoder = "thenlper/gte-small" + cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2" + + class SubscriptionFactory(factory.django.DjangoModelFactory): class Meta: model = Subscription diff --git a/tests/test_client.py b/tests/test_client.py index 1894577c..f642a727 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -173,7 +173,6 @@ def test_regenerate_with_github_fails_without_pat(client): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db -@pytest.mark.skip(reason="Flaky test on parallel test runs") def test_get_configured_types_via_api(client, sample_org_data): # Act text_search.setup(OrgToEntries, sample_org_data, regenerate=False) @@ -203,10 +202,10 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser): @pytest.mark.django_db(transaction=True) def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI): # Arrange - state.SearchType = configure_search_types(config) - original_config = state.config.content_type - state.config.content_type = None state.anonymous_mode = True + if state.config and state.config.content_type: + state.config.content_type = None + state.search_models = configure_search_types() configure_routes(fastapi_app) client = TestClient(fastapi_app) @@ -218,9 +217,6 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI): assert response.status_code == 200 assert response.json() == ["all"] - # Restore - state.config.content_type = original_config - # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) @@ -259,13 +255,30 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa user_query = quote("How to git install application?") # Act - response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true", headers=headers) + response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers) # Assert assert response.status_code == 200 - # assert actual_data contains "Khoj via Emacs" entry + + assert len(response.json()) == 1, "Expected only 1 result" search_result = response.json()[0]["entry"] - assert "git clone https://github.com/khoj-ai/khoj" in search_result + assert "git clone https://github.com/khoj-ai/khoj" in search_result, "Expected 'git clone' in search result" + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_notes_search_no_results(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser): + # Arrange + headers = {"Authorization": "Bearer kk-secret"} + text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user) + user_query = quote("How to find my goat?") + + # Act + response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers) + + # Assert + assert response.status_code == 200 + assert response.json() == [], "Expected no results" # ----------------------------------------------------------------------------------------------------