mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-10 13:26:13 +00:00
Make Search Model Configurable on Server (#544)
- Make search model configurable on server - Update migration script to get search model from `khoj.yml` to Postgres - Update first run message on Khoj Desktop and Web app landing page - Other miscellaneous bug fixes
This commit is contained in:
43
.github/workflows/dockerize_dev.yml
vendored
Normal file
43
.github/workflows/dockerize_dev.yml
vendored
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
name: dockerize-dev
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- src/khoj/**
|
||||||
|
- config/**
|
||||||
|
- pyproject.toml
|
||||||
|
- prod.Dockerfile
|
||||||
|
- .github/workflows/dockerize_dev.yml
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
env:
|
||||||
|
DOCKER_IMAGE_TAG: 'dev'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: Build Production Docker Image, Push to Container Registry
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout Code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v2
|
||||||
|
|
||||||
|
- name: Login to GitHub Container Registry
|
||||||
|
uses: docker/login-action@v2
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.repository_owner }}
|
||||||
|
password: ${{ secrets.PAT }}
|
||||||
|
|
||||||
|
- name: 📦 Build and Push Docker Image
|
||||||
|
uses: docker/build-push-action@v2
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: prod.Dockerfile
|
||||||
|
platforms: linux/amd64
|
||||||
|
push: true
|
||||||
|
tags: ghcr.io/${{ github.repository }}:${{ env.DOCKER_IMAGE_TAG }}
|
||||||
|
build-args: |
|
||||||
|
PORT=42110
|
||||||
@@ -54,7 +54,7 @@ dependencies = [
|
|||||||
"transformers >= 4.28.0",
|
"transformers >= 4.28.0",
|
||||||
"torch == 2.0.1",
|
"torch == 2.0.1",
|
||||||
"uvicorn == 0.17.6",
|
"uvicorn == 0.17.6",
|
||||||
"aiohttp == 3.8.5",
|
"aiohttp == 3.8.6",
|
||||||
"langchain >= 0.0.331",
|
"langchain >= 0.0.331",
|
||||||
"requests >= 2.26.0",
|
"requests >= 2.26.0",
|
||||||
"bs4 >= 0.0.1",
|
"bs4 >= 0.0.1",
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import math
|
import math
|
||||||
from typing import Optional, Type, TypeVar, List
|
from typing import Optional, Type, List
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Type, TypeVar, List
|
from typing import Type, List
|
||||||
from datetime import date, timezone
|
from datetime import date, timezone
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
@@ -11,10 +11,6 @@ from pgvector.django import CosineDistance
|
|||||||
from django.db.models.manager import BaseManager
|
from django.db.models.manager import BaseManager
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from pgvector.django import CosineDistance
|
|
||||||
from django.db.models.manager import BaseManager
|
|
||||||
from django.db.models import Q
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
# Import sync_to_async from Django Channels
|
# Import sync_to_async from Django Channels
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
@@ -31,6 +27,7 @@ from database.models import (
|
|||||||
GithubRepoConfig,
|
GithubRepoConfig,
|
||||||
Conversation,
|
Conversation,
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
|
SearchModelConfig,
|
||||||
Subscription,
|
Subscription,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
@@ -41,15 +38,6 @@ from khoj.search_filter.word_filter import WordFilter
|
|||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=models.Model)
|
|
||||||
|
|
||||||
|
|
||||||
async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType:
|
|
||||||
instance = await model_class.objects.filter(id=id).afirst()
|
|
||||||
if not instance:
|
|
||||||
raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found")
|
|
||||||
return instance
|
|
||||||
|
|
||||||
|
|
||||||
async def set_notion_config(token: str, user: KhojUser):
|
async def set_notion_config(token: str, user: KhojUser):
|
||||||
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
||||||
@@ -65,9 +53,7 @@ async def create_khoj_token(user: KhojUser, name=None):
|
|||||||
"Create Khoj API key for user"
|
"Create Khoj API key for user"
|
||||||
token = f"kk-{secrets.token_urlsafe(32)}"
|
token = f"kk-{secrets.token_urlsafe(32)}"
|
||||||
name = name or f"{generate_random_name().title()}"
|
name = name or f"{generate_random_name().title()}"
|
||||||
api_config = await KhojApiUser.objects.acreate(token=token, user=user, name=name)
|
return await KhojApiUser.objects.acreate(token=token, user=user, name=name)
|
||||||
await api_config.asave()
|
|
||||||
return api_config
|
|
||||||
|
|
||||||
|
|
||||||
def get_khoj_tokens(user: KhojUser):
|
def get_khoj_tokens(user: KhojUser):
|
||||||
@@ -83,13 +69,16 @@ async def delete_khoj_token(user: KhojUser, token: str):
|
|||||||
async def get_or_create_user(token: dict) -> KhojUser:
|
async def get_or_create_user(token: dict) -> KhojUser:
|
||||||
user = await get_user_by_token(token)
|
user = await get_user_by_token(token)
|
||||||
if not user:
|
if not user:
|
||||||
user = await create_google_user(token)
|
user = await create_user_by_google_token(token)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def create_google_user(token: dict) -> KhojUser:
|
async def create_user_by_google_token(token: dict) -> KhojUser:
|
||||||
user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
|
user, _ = await KhojUser.objects.filter(email=token.get("email")).aupdate_or_create(
|
||||||
|
defaults={"username": token.get("email"), "email": token.get("email")}
|
||||||
|
)
|
||||||
await user.asave()
|
await user.asave()
|
||||||
|
|
||||||
await GoogleUser.objects.acreate(
|
await GoogleUser.objects.acreate(
|
||||||
sub=token.get("sub"),
|
sub=token.get("sub"),
|
||||||
azp=token.get("azp"),
|
azp=token.get("azp"),
|
||||||
@@ -220,6 +209,14 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create_search_model():
|
||||||
|
search_model = SearchModelConfig.objects.filter().first()
|
||||||
|
if not search_model:
|
||||||
|
search_model = SearchModelConfig.objects.create()
|
||||||
|
|
||||||
|
return search_model
|
||||||
|
|
||||||
|
|
||||||
class ConversationAdapters:
|
class ConversationAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_conversation_by_user(user: KhojUser):
|
def get_conversation_by_user(user: KhojUser):
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from database.models import (
|
|||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
OfflineChatProcessorConversationConfig,
|
OfflineChatProcessorConversationConfig,
|
||||||
|
SearchModelConfig,
|
||||||
Subscription,
|
Subscription,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,4 +17,5 @@ admin.site.register(KhojUser, UserAdmin)
|
|||||||
admin.site.register(ChatModelOptions)
|
admin.site.register(ChatModelOptions)
|
||||||
admin.site.register(OpenAIProcessorConversationConfig)
|
admin.site.register(OpenAIProcessorConversationConfig)
|
||||||
admin.site.register(OfflineChatProcessorConversationConfig)
|
admin.site.register(OfflineChatProcessorConversationConfig)
|
||||||
|
admin.site.register(SearchModelConfig)
|
||||||
admin.site.register(Subscription)
|
admin.site.register(Subscription)
|
||||||
|
|||||||
32
src/database/migrations/0017_searchmodel.py
Normal file
32
src/database/migrations/0017_searchmodel.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# Generated by Django 4.2.5 on 2023-11-14 23:25
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0016_alter_subscription_renewal_date"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="SearchModel",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("name", models.CharField(default="default", max_length=200)),
|
||||||
|
("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
|
||||||
|
("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
|
||||||
|
(
|
||||||
|
"cross_encoder",
|
||||||
|
models.CharField(
|
||||||
|
blank=True, default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200, null=True
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
# Generated by Django 4.2.5 on 2023-11-16 01:13
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0017_searchmodel"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="SearchModelConfig",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
("name", models.CharField(default="default", max_length=200)),
|
||||||
|
("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
|
||||||
|
("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
|
||||||
|
("cross_encoder", models.CharField(default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.DeleteModel(
|
||||||
|
name="SearchModel",
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -102,6 +102,16 @@ class LocalPlaintextConfig(BaseModel):
|
|||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchModelConfig(BaseModel):
|
||||||
|
class ModelType(models.TextChoices):
|
||||||
|
TEXT = "text"
|
||||||
|
|
||||||
|
name = models.CharField(max_length=200, default="default")
|
||||||
|
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
|
||||||
|
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
|
||||||
|
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||||
|
|
||||||
|
|
||||||
class OpenAIProcessorConversationConfig(BaseModel):
|
class OpenAIProcessorConversationConfig(BaseModel):
|
||||||
api_key = models.CharField(max_length=200)
|
api_key = models.CharField(max_length=200)
|
||||||
|
|
||||||
|
|||||||
@@ -328,7 +328,15 @@
|
|||||||
.then(data => {
|
.then(data => {
|
||||||
if (data.detail) {
|
if (data.detail) {
|
||||||
// If the server returns a 500 error with detail, render a setup hint.
|
// If the server returns a 500 error with detail, render a setup hint.
|
||||||
renderMessage("Hi 👋🏾, to get started you have two options:<ol><li><b>Use OpenAI</b>: <ol><li>Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a></li><li>Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a></li><li>Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a></li></ol></li><li><b>Enable offline chat</b>: <ol><li>Go to the Khoj <a class='inline-chat-link' href='/config'>settings page</a> and enable offline chat</li></ol></li></ol>", "khoj");
|
first_run_message = `Hi 👋🏾, to get started:
|
||||||
|
<ol>
|
||||||
|
<li>Generate an API token in the <a class='inline-chat-link' href="#" onclick="window.navigateAPI.navigateToWebSettings()">Khoj Web settings</a></li>
|
||||||
|
<li>Paste it into the API Key field in the <a class='inline-chat-link' href="#" onclick="window.navigateAPI.navigateToSettings()">Khoj Desktop settings</a></li>
|
||||||
|
</ol>`
|
||||||
|
.trim()
|
||||||
|
.replace(/(\r\n|\n|\r)/gm, "");
|
||||||
|
|
||||||
|
renderMessage(first_run_message, "khoj");
|
||||||
|
|
||||||
// Disable chat input field and update placeholder text
|
// Disable chat input field and update placeholder text
|
||||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||||
|
|||||||
@@ -396,6 +396,14 @@ app.whenReady().then(() => {
|
|||||||
event.reply('update-state', arg);
|
event.reply('update-state', arg);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
ipcMain.on('navigate', (event, page) => {
|
||||||
|
win.loadFile(page);
|
||||||
|
});
|
||||||
|
|
||||||
|
ipcMain.on('navigateToWebApp', (event, page) => {
|
||||||
|
shell.openExternal(`${store.get('hostURL')}/${page}`);
|
||||||
|
});
|
||||||
|
|
||||||
ipcMain.handle('getFiles', getFiles);
|
ipcMain.handle('getFiles', getFiles);
|
||||||
ipcMain.handle('getFolders', getFolders);
|
ipcMain.handle('getFolders', getFolders);
|
||||||
|
|
||||||
|
|||||||
@@ -10,14 +10,14 @@
|
|||||||
"main": "main.js",
|
"main": "main.js",
|
||||||
"private": false,
|
"private": false,
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"electron": "25.8.1"
|
"electron": "25.8.4"
|
||||||
},
|
},
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"start": "yarn electron ."
|
"start": "yarn electron ."
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@todesktop/runtime": "^1.3.0",
|
"@todesktop/runtime": "^1.3.0",
|
||||||
"axios": "^1.5.0",
|
"axios": "^1.6.0",
|
||||||
"cron": "^2.4.3",
|
"cron": "^2.4.3",
|
||||||
"electron-store": "^8.1.0",
|
"electron-store": "^8.1.0",
|
||||||
"fs": "^0.0.1-security"
|
"fs": "^0.0.1-security"
|
||||||
|
|||||||
@@ -57,3 +57,8 @@ contextBridge.exposeInMainWorld('tokenAPI', {
|
|||||||
contextBridge.exposeInMainWorld('appInfoAPI', {
|
contextBridge.exposeInMainWorld('appInfoAPI', {
|
||||||
getInfo: (callback) => ipcRenderer.on('appInfo', callback)
|
getInfo: (callback) => ipcRenderer.on('appInfo', callback)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
contextBridge.exposeInMainWorld('navigateAPI', {
|
||||||
|
navigateToSettings: () => ipcRenderer.send('navigate', 'config.html'),
|
||||||
|
navigateToWebSettings: () => ipcRenderer.send('navigateToWebApp', 'config'),
|
||||||
|
})
|
||||||
|
|||||||
@@ -163,10 +163,10 @@ atomically@^1.7.0:
|
|||||||
resolved "https://registry.yarnpkg.com/atomically/-/atomically-1.7.0.tgz#c07a0458432ea6dbc9a3506fffa424b48bccaafe"
|
resolved "https://registry.yarnpkg.com/atomically/-/atomically-1.7.0.tgz#c07a0458432ea6dbc9a3506fffa424b48bccaafe"
|
||||||
integrity sha512-Xcz9l0z7y9yQ9rdDaxlmaI4uJHf/T8g9hOEzJcsEqX2SjCj4J20uK7+ldkDHMbpJDK76wF7xEIgxc/vSlsfw5w==
|
integrity sha512-Xcz9l0z7y9yQ9rdDaxlmaI4uJHf/T8g9hOEzJcsEqX2SjCj4J20uK7+ldkDHMbpJDK76wF7xEIgxc/vSlsfw5w==
|
||||||
|
|
||||||
axios@^1.5.0:
|
axios@^1.6.0:
|
||||||
version "1.5.0"
|
version "1.6.2"
|
||||||
resolved "https://registry.yarnpkg.com/axios/-/axios-1.5.0.tgz#f02e4af823e2e46a9768cfc74691fdd0517ea267"
|
resolved "https://registry.yarnpkg.com/axios/-/axios-1.6.2.tgz#de67d42c755b571d3e698df1b6504cde9b0ee9f2"
|
||||||
integrity sha512-D4DdjDo5CY50Qms0qGQTTw6Q44jl7zRwY7bthds06pUGfChBCTcQs+N743eFWGEd6pRTMd6A+I87aWyFV5wiZQ==
|
integrity sha512-7i24Ri4pmDRfJTR7LDBhsOTtcm+9kjX5WiY1X3wIisx6G9So3pfMkEiU7emUBe46oceVImccTEM3k6C5dbVW8A==
|
||||||
dependencies:
|
dependencies:
|
||||||
follow-redirects "^1.15.0"
|
follow-redirects "^1.15.0"
|
||||||
form-data "^4.0.0"
|
form-data "^4.0.0"
|
||||||
@@ -379,10 +379,10 @@ electron-updater@^4.6.1:
|
|||||||
lodash.isequal "^4.5.0"
|
lodash.isequal "^4.5.0"
|
||||||
semver "^7.3.5"
|
semver "^7.3.5"
|
||||||
|
|
||||||
electron@25.8.1:
|
electron@25.8.4:
|
||||||
version "25.8.1"
|
version "25.8.4"
|
||||||
resolved "https://registry.yarnpkg.com/electron/-/electron-25.8.1.tgz#092fab5a833db4d9240d4d6f36218cf7ca954f86"
|
resolved "https://registry.yarnpkg.com/electron/-/electron-25.8.4.tgz#b50877aac7d96323920437baf309ad86382cb455"
|
||||||
integrity sha512-GtcP1nMrROZfFg0+mhyj1hamrHvukfF6of2B/pcWxmWkd5FVY1NJib0tlhiorFZRzQN5Z+APLPr7aMolt7i2AQ==
|
integrity sha512-hUYS3RGdaa6E1UWnzeGnsdsBYOggwMMg4WGxNGvAoWtmRrr6J1BsjFW/yRq4WsJHJce2HdzQXtz4OGXV6yUCLg==
|
||||||
dependencies:
|
dependencies:
|
||||||
"@electron/get" "^2.0.0"
|
"@electron/get" "^2.0.0"
|
||||||
"@types/node" "^18.11.18"
|
"@types/node" "^18.11.18"
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import logging
|
|||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import Request
|
|
||||||
import requests
|
import requests
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -21,15 +20,16 @@ from starlette.authentication import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
|
from database.models import KhojUser, Subscription
|
||||||
|
from database.adapters import get_all_users, get_or_create_search_model
|
||||||
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
|
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from khoj.utils.config import (
|
from khoj.utils.config import (
|
||||||
SearchType,
|
SearchType,
|
||||||
)
|
)
|
||||||
from khoj.utils.fs_syncer import collect_files
|
from khoj.utils.fs_syncer import collect_files
|
||||||
from khoj.utils.rawconfig import FullConfig
|
from khoj.utils.rawconfig import FullConfig
|
||||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
|
||||||
from database.models import KhojUser, Subscription
|
|
||||||
from database.adapters import get_all_users
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -113,14 +113,13 @@ def configure_server(
|
|||||||
|
|
||||||
# Initialize Search Models from Config and initialize content
|
# Initialize Search Models from Config and initialize content
|
||||||
try:
|
try:
|
||||||
state.config_lock.acquire()
|
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
|
||||||
|
state.SearchType = configure_search_types()
|
||||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
initialize_content(regenerate, search_type, init, user)
|
initialize_content(regenerate, search_type, init, user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
finally:
|
|
||||||
state.config_lock.release()
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
|
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
|
||||||
@@ -192,7 +191,7 @@ def update_search_index():
|
|||||||
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
|
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
def configure_search_types(config: FullConfig):
|
def configure_search_types():
|
||||||
# Extract core search types
|
# Extract core search types
|
||||||
core_search_types = {e.name: e.value for e in SearchType}
|
core_search_types = {e.name: e.value for e in SearchType}
|
||||||
|
|
||||||
|
|||||||
@@ -327,7 +327,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||||||
.then(data => {
|
.then(data => {
|
||||||
if (data.detail) {
|
if (data.detail) {
|
||||||
// If the server returns a 500 error with detail, render a setup hint.
|
// If the server returns a 500 error with detail, render a setup hint.
|
||||||
renderMessage("Hi 👋🏾, to get started you have two options:<ol><li><b>Use OpenAI</b>: <ol><li>Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a></li><li>Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a></li><li>Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a></li></ol></li><li><b>Enable offline chat</b>: <ol><li>Go to the Khoj <a class='inline-chat-link' href='/config'>settings page</a> and enable offline chat</li></ol></li></ol>", "khoj");
|
renderMessage("Hi 👋🏾, to start chatting add available chat models options via <a class='inline-chat-link' href='/server/admin'>the Django Admin panel</a> on the Server", "khoj");
|
||||||
|
|
||||||
// Disable chat input field and update placeholder text
|
// Disable chat input field and update placeholder text
|
||||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ search-type:
|
|||||||
encoder: sentence-transformers/all-MiniLM-L6-v2
|
encoder: sentence-transformers/all-MiniLM-L6-v2
|
||||||
encoder-type: null
|
encoder-type: null
|
||||||
model-directory: ~/.khoj/search/symmetric
|
model-directory: ~/.khoj/search/symmetric
|
||||||
version: 0.12.4
|
version: 0.14.0
|
||||||
|
|
||||||
|
|
||||||
The new version will looks like this:
|
The new version will looks like this:
|
||||||
@@ -53,11 +53,7 @@ search-type:
|
|||||||
asymmetric:
|
asymmetric:
|
||||||
cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
|
cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
|
||||||
encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1
|
encoder: sentence-transformers/multi-qa-MiniLM-L6-cos-v1
|
||||||
image:
|
version: 0.15.0
|
||||||
encoder: sentence-transformers/clip-ViT-B-32
|
|
||||||
encoder-type: null
|
|
||||||
model-directory: /Users/si/.khoj/search/image
|
|
||||||
version: 0.12.4
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -68,6 +64,7 @@ from database.models import (
|
|||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
OfflineChatProcessorConversationConfig,
|
OfflineChatProcessorConversationConfig,
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
|
SearchModelConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -87,6 +84,19 @@ def migrate_server_pg(args):
|
|||||||
if raw_config is None:
|
if raw_config is None:
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
if "search-type" in raw_config and raw_config["search-type"]:
|
||||||
|
if "asymmetric" in raw_config["search-type"]:
|
||||||
|
# Delete all existing search models
|
||||||
|
SearchModelConfig.objects.filter(model_type=SearchModelConfig.ModelType.TEXT).delete()
|
||||||
|
# Create new search model from existing Khoj YAML config
|
||||||
|
asymmetric_search = raw_config["search-type"]["asymmetric"]
|
||||||
|
SearchModelConfig.objects.create(
|
||||||
|
name="default",
|
||||||
|
model_type=SearchModelConfig.ModelType.TEXT,
|
||||||
|
bi_encoder=asymmetric_search.get("encoder"),
|
||||||
|
cross_encoder=asymmetric_search.get("cross-encoder"),
|
||||||
|
)
|
||||||
|
|
||||||
if "processor" in raw_config and raw_config["processor"] and "conversation" in raw_config["processor"]:
|
if "processor" in raw_config and raw_config["processor"] and "conversation" in raw_config["processor"]:
|
||||||
processor_conversation = raw_config["processor"]["conversation"]
|
processor_conversation = raw_config["processor"]["conversation"]
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from khoj.utils.helpers import get_device
|
from khoj.utils.helpers import get_device
|
||||||
from khoj.utils.rawconfig import SearchResponse
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsModel:
|
class EmbeddingsModel:
|
||||||
def __init__(self):
|
def __init__(self, model_name: str = "thenlper/gte-small"):
|
||||||
self.encode_kwargs = {"normalize_embeddings": True}
|
self.encode_kwargs = {"normalize_embeddings": True}
|
||||||
self.model_kwargs = {"device": get_device()}
|
self.model_kwargs = {"device": get_device()}
|
||||||
self.model_name = "thenlper/gte-small"
|
self.model_name = model_name
|
||||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||||
|
|
||||||
def embed_query(self, query):
|
def embed_query(self, query):
|
||||||
@@ -21,11 +22,11 @@ class EmbeddingsModel:
|
|||||||
|
|
||||||
|
|
||||||
class CrossEncoderModel:
|
class CrossEncoderModel:
|
||||||
def __init__(self):
|
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
|
||||||
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
self.model_name = model_name
|
||||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
||||||
|
|
||||||
def predict(self, query, hits: List[SearchResponse]):
|
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
||||||
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
|
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
||||||
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
|
cross_scores = self.cross_encoder_model.predict(cross_inp, activation_fct=nn.Sigmoid())
|
||||||
return cross_scores
|
return cross_scores
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing import Callable, List, Tuple, Set, Any
|
from typing import Callable, List, Tuple, Set, Any
|
||||||
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import is_none_or_empty, timer, batcher
|
from khoj.utils.helpers import is_none_or_empty, timer, batcher
|
||||||
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.rawconfig import Entry
|
from khoj.utils.rawconfig import Entry
|
||||||
from khoj.processor.embeddings import EmbeddingsModel
|
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from database.models import KhojUser, Entry as DbEntry, EntryDates
|
from database.models import KhojUser, Entry as DbEntry, EntryDates
|
||||||
from database.adapters import EntryAdapters
|
from database.adapters import EntryAdapters
|
||||||
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class TextToEntries(ABC):
|
class TextToEntries(ABC):
|
||||||
def __init__(self, config: Any = None):
|
def __init__(self, config: Any = None):
|
||||||
self.embeddings_model = EmbeddingsModel()
|
self.embeddings_model = state.embeddings_model
|
||||||
self.config = config
|
self.config = config
|
||||||
self.date_filter = DateFilter()
|
self.date_filter = DateFilter()
|
||||||
|
|
||||||
|
|||||||
@@ -376,7 +376,7 @@ async def search(
|
|||||||
# initialize variables
|
# initialize variables
|
||||||
user_query = q.strip()
|
user_query = q.strip()
|
||||||
results_count = n or 5
|
results_count = n or 5
|
||||||
max_distance = max_distance if max_distance is not None else math.inf
|
max_distance = max_distance or math.inf
|
||||||
search_futures: List[concurrent.futures.Future] = []
|
search_futures: List[concurrent.futures.Future] = []
|
||||||
|
|
||||||
# return cached results, if available
|
# return cached results, if available
|
||||||
@@ -581,7 +581,7 @@ async def chat(
|
|||||||
request: Request,
|
request: Request,
|
||||||
q: str,
|
q: str,
|
||||||
n: Optional[int] = 5,
|
n: Optional[int] = 5,
|
||||||
d: Optional[float] = 0.15,
|
d: Optional[float] = 0.18,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
user_agent: Optional[str] = Header(None),
|
user_agent: Optional[str] = Header(None),
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from google.auth.transport import requests as google_requests
|
|||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
|
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
|
||||||
|
from database.models import KhojApiUser
|
||||||
from khoj.routers.helpers import update_telemetry_state
|
from khoj.routers.helpers import update_telemetry_state
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
|
||||||
@@ -51,12 +52,16 @@ async def login(request: Request):
|
|||||||
|
|
||||||
@auth_router.post("/token")
|
@auth_router.post("/token")
|
||||||
@requires(["authenticated"], redirect="login_page")
|
@requires(["authenticated"], redirect="login_page")
|
||||||
async def generate_token(request: Request, token_name: Optional[str] = None) -> str:
|
async def generate_token(request: Request, token_name: Optional[str] = None):
|
||||||
"Generate API token for given user"
|
"Generate API token for given user"
|
||||||
if token_name:
|
if token_name:
|
||||||
return await create_khoj_token(user=request.user.object, name=token_name)
|
token = await create_khoj_token(user=request.user.object, name=token_name)
|
||||||
else:
|
else:
|
||||||
return await create_khoj_token(user=request.user.object)
|
token = await create_khoj_token(user=request.user.object)
|
||||||
|
return {
|
||||||
|
"token": token.token,
|
||||||
|
"name": token.name,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@auth_router.get("/token")
|
@auth_router.get("/token")
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple, Type, Union, Dict
|
from typing import List, Tuple, Type, Union
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -5,21 +5,20 @@ from typing import List, Dict
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
import torch
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import config as utils_config
|
from khoj.utils import config as utils_config
|
||||||
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import LRU, get_device
|
from khoj.utils.helpers import LRU, get_device
|
||||||
from khoj.utils.rawconfig import FullConfig
|
from khoj.utils.rawconfig import FullConfig
|
||||||
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
|
|
||||||
|
|
||||||
# Application Global State
|
# Application Global State
|
||||||
config = FullConfig()
|
config = FullConfig()
|
||||||
search_models = SearchModels()
|
search_models = SearchModels()
|
||||||
embeddings_model = EmbeddingsModel()
|
embeddings_model: EmbeddingsModel = None
|
||||||
cross_encoder_model = CrossEncoderModel()
|
cross_encoder_model: CrossEncoderModel = None
|
||||||
content_index = ContentIndex()
|
content_index = ContentIndex()
|
||||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||||
config_file: Path = None
|
config_file: Path = None
|
||||||
@@ -28,7 +27,6 @@ host: str = None
|
|||||||
port: int = None
|
port: int = None
|
||||||
cli_args: List[str] = None
|
cli_args: List[str] = None
|
||||||
query_cache: Dict[str, LRU] = defaultdict(LRU)
|
query_cache: Dict[str, LRU] = defaultdict(LRU)
|
||||||
config_lock = threading.Lock()
|
|
||||||
chat_lock = threading.Lock()
|
chat_lock = threading.Lock()
|
||||||
SearchType = utils_config.SearchType
|
SearchType = utils_config.SearchType
|
||||||
telemetry: List[Dict[str, str]] = []
|
telemetry: List[Dict[str, str]] = []
|
||||||
|
|||||||
@@ -8,11 +8,13 @@ from fastapi import FastAPI
|
|||||||
import os
|
import os
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_routes, configure_search_types, configure_middleware
|
from khoj.configure import configure_routes, configure_search_types, configure_middleware
|
||||||
|
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||||
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
|
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.utils.config import SearchModels
|
from khoj.utils.config import SearchModels
|
||||||
@@ -54,6 +56,9 @@ def enable_db_access_for_all_tests(db):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def search_config() -> SearchConfig:
|
def search_config() -> SearchConfig:
|
||||||
|
state.embeddings_model = EmbeddingsModel()
|
||||||
|
state.cross_encoder_model = CrossEncoderModel()
|
||||||
|
|
||||||
model_dir = resolve_absolute_path("~/.khoj/search")
|
model_dir = resolve_absolute_path("~/.khoj/search")
|
||||||
model_dir.mkdir(parents=True, exist_ok=True)
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
search_config = SearchConfig()
|
search_config = SearchConfig()
|
||||||
@@ -222,7 +227,7 @@ def md_content_config():
|
|||||||
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
||||||
# Initialize app state
|
# Initialize app state
|
||||||
state.config.search_type = search_config
|
state.config.search_type = search_config
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types()
|
||||||
|
|
||||||
LocalMarkdownConfig.objects.create(
|
LocalMarkdownConfig.objects.create(
|
||||||
input_files=None,
|
input_files=None,
|
||||||
@@ -256,7 +261,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
|||||||
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
|
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
|
||||||
# Initialize app state
|
# Initialize app state
|
||||||
state.config.search_type = search_config
|
state.config.search_type = search_config
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types()
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
if os.getenv("OPENAI_API_KEY"):
|
if os.getenv("OPENAI_API_KEY"):
|
||||||
@@ -291,7 +296,9 @@ def client(
|
|||||||
):
|
):
|
||||||
state.config.content_type = content_config
|
state.config.content_type = content_config
|
||||||
state.config.search_type = search_config
|
state.config.search_type = search_config
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types()
|
||||||
|
state.embeddings_model = EmbeddingsModel()
|
||||||
|
state.cross_encoder_model = CrossEncoderModel()
|
||||||
|
|
||||||
# These lines help us Mock the Search models for these search types
|
# These lines help us Mock the Search models for these search types
|
||||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||||
@@ -323,7 +330,7 @@ def client(
|
|||||||
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
|
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
|
||||||
# Initialize app state
|
# Initialize app state
|
||||||
state.config.search_type = search_config
|
state.config.search_type = search_config
|
||||||
state.SearchType = configure_search_types(state.config)
|
state.SearchType = configure_search_types()
|
||||||
|
|
||||||
LocalMarkdownConfig.objects.create(
|
LocalMarkdownConfig.objects.create(
|
||||||
input_files=None,
|
input_files=None,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from database.models import (
|
|||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
OfflineChatProcessorConversationConfig,
|
OfflineChatProcessorConversationConfig,
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
|
SearchModelConfig,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
Conversation,
|
Conversation,
|
||||||
Subscription,
|
Subscription,
|
||||||
@@ -71,6 +72,16 @@ class ConversationFactory(factory.django.DjangoModelFactory):
|
|||||||
user = factory.SubFactory(UserFactory)
|
user = factory.SubFactory(UserFactory)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchModelFactory(factory.django.DjangoModelFactory):
|
||||||
|
class Meta:
|
||||||
|
model = SearchModelConfig
|
||||||
|
|
||||||
|
name = "default"
|
||||||
|
model_type = "text"
|
||||||
|
bi_encoder = "thenlper/gte-small"
|
||||||
|
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Subscription
|
model = Subscription
|
||||||
|
|||||||
@@ -173,7 +173,6 @@ def test_regenerate_with_github_fails_without_pat(client):
|
|||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@pytest.mark.skip(reason="Flaky test on parallel test runs")
|
|
||||||
def test_get_configured_types_via_api(client, sample_org_data):
|
def test_get_configured_types_via_api(client, sample_org_data):
|
||||||
# Act
|
# Act
|
||||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False)
|
text_search.setup(OrgToEntries, sample_org_data, regenerate=False)
|
||||||
@@ -203,10 +202,10 @@ def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
|
|||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
||||||
# Arrange
|
# Arrange
|
||||||
state.SearchType = configure_search_types(config)
|
|
||||||
original_config = state.config.content_type
|
|
||||||
state.config.content_type = None
|
|
||||||
state.anonymous_mode = True
|
state.anonymous_mode = True
|
||||||
|
if state.config and state.config.content_type:
|
||||||
|
state.config.content_type = None
|
||||||
|
state.search_models = configure_search_types()
|
||||||
|
|
||||||
configure_routes(fastapi_app)
|
configure_routes(fastapi_app)
|
||||||
client = TestClient(fastapi_app)
|
client = TestClient(fastapi_app)
|
||||||
@@ -218,9 +217,6 @@ def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == ["all"]
|
assert response.json() == ["all"]
|
||||||
|
|
||||||
# Restore
|
|
||||||
state.config.content_type = original_config
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
@@ -259,13 +255,30 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
|
|||||||
user_query = quote("How to git install application?")
|
user_query = quote("How to git install application?")
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true", headers=headers)
|
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
# assert actual_data contains "Khoj via Emacs" entry
|
|
||||||
|
assert len(response.json()) == 1, "Expected only 1 result"
|
||||||
search_result = response.json()[0]["entry"]
|
search_result = response.json()[0]["entry"]
|
||||||
assert "git clone https://github.com/khoj-ai/khoj" in search_result
|
assert "git clone https://github.com/khoj-ai/khoj" in search_result, "Expected 'git clone' in search result"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_notes_search_no_results(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
|
||||||
|
# Arrange
|
||||||
|
headers = {"Authorization": "Bearer kk-secret"}
|
||||||
|
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
|
||||||
|
user_query = quote("How to find my goat?")
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == [], "Expected no results"
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user