mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
[Multi-User Part 1]: Enable storage of settings for plaintext files based on user account (#498)
- Partition configuration for indexing local data based on user accounts - Store indexed data in an underlying postgres db using the `pgvector` extension - Add migrations for all relevant user data and embeddings generation. Very little performance optimization has been done for the lookup time - Apply filters using SQL queries - Start removing many server-level configuration settings - Configure GitHub test actions to run during any PR. Update the test action to run in a containerized environment with a DB. - Update the Docker image and docker-compose.yml to work with the new application design
This commit is contained in:
@@ -1,15 +1,19 @@
|
||||
# External Packages
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi import FastAPI
|
||||
import factory
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from app.main import app
|
||||
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils.config import SearchModels
|
||||
@@ -22,8 +26,6 @@ from khoj.utils.rawconfig import (
|
||||
OpenAIProcessorConfig,
|
||||
ProcessorConfig,
|
||||
TextContentConfig,
|
||||
GithubContentConfig,
|
||||
GithubRepoConfig,
|
||||
ImageContentConfig,
|
||||
SearchConfig,
|
||||
TextSearchConfig,
|
||||
@@ -31,11 +33,31 @@ from khoj.utils.rawconfig import (
|
||||
)
|
||||
from khoj.utils import state, fs_syncer
|
||||
from khoj.routers.indexer import configure_content
|
||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from database.models import (
|
||||
LocalOrgConfig,
|
||||
LocalMarkdownConfig,
|
||||
LocalPlaintextConfig,
|
||||
LocalPdfConfig,
|
||||
GithubConfig,
|
||||
KhojUser,
|
||||
GithubRepoConfig,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_db_access_for_all_tests(db):
|
||||
pass
|
||||
|
||||
|
||||
class UserFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = KhojUser
|
||||
|
||||
username = factory.Faker("name")
|
||||
email = factory.Faker("email")
|
||||
password = factory.Faker("password")
|
||||
uuid = factory.Faker("uuid4")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -67,17 +89,28 @@ def search_config() -> SearchConfig:
|
||||
return search_config
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_user():
|
||||
return UserFactory()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_models(search_config: SearchConfig):
|
||||
search_models = SearchModels()
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
return search_models
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
|
||||
@pytest.fixture
|
||||
def anyio_backend():
|
||||
return "asyncio"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture(scope="function")
|
||||
def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
|
||||
content_dir = tmp_path_factory.mktemp("content")
|
||||
|
||||
# Generate Image Embeddings from Test Images
|
||||
@@ -92,94 +125,45 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
|
||||
|
||||
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
|
||||
|
||||
# Generate Notes Embeddings from Test Notes
|
||||
content_config.org = TextContentConfig(
|
||||
LocalOrgConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/org/*.org"],
|
||||
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
|
||||
index_heading_entries=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
get_sample_data("org"),
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
content_config.plugins = {
|
||||
"plugin1": TextContentConfig(
|
||||
input_files=[content_dir.joinpath("notes.jsonl.gz")],
|
||||
input_filter=None,
|
||||
compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
|
||||
)
|
||||
}
|
||||
text_search.setup(OrgToJsonl, get_sample_data("org"), regenerate=False, user=default_user)
|
||||
|
||||
if os.getenv("GITHUB_PAT_TOKEN"):
|
||||
content_config.github = GithubContentConfig(
|
||||
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
||||
repos=[
|
||||
GithubRepoConfig(
|
||||
owner="khoj-ai",
|
||||
name="lantern",
|
||||
branch="master",
|
||||
)
|
||||
],
|
||||
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
||||
GithubConfig.objects.create(
|
||||
pat_token=os.getenv("GITHUB_PAT_TOKEN"),
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
content_config.plaintext = TextContentConfig(
|
||||
GithubRepoConfig.objects.create(
|
||||
owner="khoj-ai",
|
||||
name="lantern",
|
||||
branch="master",
|
||||
github_config=GithubConfig.objects.get(user=default_user),
|
||||
)
|
||||
|
||||
LocalPlaintextConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
|
||||
compressed_jsonl=content_dir.joinpath("plaintext.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("plaintext_embeddings.pt"),
|
||||
)
|
||||
|
||||
content_config.github = GithubContentConfig(
|
||||
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
||||
repos=[
|
||||
GithubRepoConfig(
|
||||
owner="khoj-ai",
|
||||
name="lantern",
|
||||
branch="master",
|
||||
)
|
||||
],
|
||||
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
||||
)
|
||||
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(
|
||||
JsonlToJsonl,
|
||||
None,
|
||||
content_config.plugins["plugin1"],
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
return content_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def md_content_config(tmp_path_factory):
|
||||
content_dir = tmp_path_factory.mktemp("content")
|
||||
|
||||
# Generate Embeddings for Markdown Content
|
||||
content_config = ContentConfig()
|
||||
content_config.markdown = TextContentConfig(
|
||||
def md_content_config():
|
||||
markdown_config = LocalMarkdownConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/markdown/*.markdown"],
|
||||
compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
|
||||
)
|
||||
|
||||
return content_config
|
||||
return markdown_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -220,19 +204,20 @@ def processor_config_offline_chat(tmp_path_factory):
|
||||
@pytest.fixture(scope="session")
|
||||
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
||||
# Initialize app state
|
||||
state.config.content_type = md_content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
# Index Markdown Content for Search
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
all_files = fs_syncer.collect_files(state.config.content_type)
|
||||
all_files = fs_syncer.collect_files()
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models
|
||||
)
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(processor_config)
|
||||
state.anonymous_mode = True
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
@@ -241,33 +226,45 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
||||
def fastapi_app():
|
||||
app = FastAPI()
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(
|
||||
content_config: ContentConfig,
|
||||
search_config: SearchConfig,
|
||||
processor_config: ProcessorConfig,
|
||||
default_user: KhojUser,
|
||||
):
|
||||
state.config.content_type = content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
# These lines help us Mock the Search models for these search types
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
state.content_index.org = text_search.setup(
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
get_sample_data("org"),
|
||||
content_config.org,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
user=default_user,
|
||||
)
|
||||
state.content_index.image = image_search.setup(
|
||||
content_config.image, state.search_models.image_search, regenerate=False
|
||||
)
|
||||
state.content_index.plaintext = text_search.setup(
|
||||
text_search.setup(
|
||||
PlaintextToJsonl,
|
||||
get_sample_data("plaintext"),
|
||||
content_config.plaintext,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
state.processor_config = configure_processor(processor_config)
|
||||
state.anonymous_mode = True
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
@@ -288,7 +285,6 @@ def client_offline_chat(
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
# Index Markdown Content for Search
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
all_files = fs_syncer.collect_files(state.config.content_type)
|
||||
@@ -298,6 +294,7 @@ def client_offline_chat(
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(processor_config_offline_chat)
|
||||
state.anonymous_mode = True
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
@@ -306,9 +303,11 @@ def client_offline_chat(
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def new_org_file(content_config: ContentConfig):
|
||||
def new_org_file(default_user: KhojUser, content_config: ContentConfig):
|
||||
# Setup
|
||||
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
input_filters = org_config.input_filter
|
||||
new_org_file = Path(input_filters[0]).parent / "new_file.org"
|
||||
new_org_file.touch()
|
||||
|
||||
yield new_org_file
|
||||
@@ -319,11 +318,9 @@ def new_org_file(content_config: ContentConfig):
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
|
||||
new_org_config = deepcopy(content_config.org)
|
||||
new_org_config.input_files = [f"{new_org_file}"]
|
||||
new_org_config.input_filter = None
|
||||
return new_org_config
|
||||
def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
|
||||
LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
|
||||
return LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
||||
11
tests/data/config.yml
vendored
11
tests/data/config.yml
vendored
@@ -9,17 +9,6 @@ content-type:
|
||||
input-filter:
|
||||
- '*.org'
|
||||
- ~/notes/*.org
|
||||
plugins:
|
||||
content_plugin_1:
|
||||
compressed-jsonl: content_plugin_1.jsonl.gz
|
||||
embeddings-file: content_plugin_1_embeddings.pt
|
||||
input-files:
|
||||
- content_plugin_1_new.jsonl.gz
|
||||
content_plugin_2:
|
||||
compressed-jsonl: content_plugin_2.jsonl.gz
|
||||
embeddings-file: content_plugin_2_embeddings.pt
|
||||
input-filter:
|
||||
- '*2_new.jsonl.gz'
|
||||
enable-offline-chat: false
|
||||
search-type:
|
||||
asymmetric:
|
||||
|
||||
@@ -48,14 +48,3 @@ def test_cli_config_from_file():
|
||||
Path("~/first_from_config.org"),
|
||||
Path("~/second_from_config.org"),
|
||||
]
|
||||
assert len(actual_args.config.content_type.plugins.keys()) == 2
|
||||
assert actual_args.config.content_type.plugins["content_plugin_1"].input_files == [
|
||||
Path("content_plugin_1_new.jsonl.gz")
|
||||
]
|
||||
assert actual_args.config.content_type.plugins["content_plugin_2"].input_filter == ["*2_new.jsonl.gz"]
|
||||
assert actual_args.config.content_type.plugins["content_plugin_1"].compressed_jsonl == Path(
|
||||
"content_plugin_1.jsonl.gz"
|
||||
)
|
||||
assert actual_args.config.content_type.plugins["content_plugin_2"].embeddings_file == Path(
|
||||
"content_plugin_2_embeddings.pt"
|
||||
)
|
||||
|
||||
@@ -2,22 +2,21 @@
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from urllib.parse import quote
|
||||
|
||||
import pytest
|
||||
|
||||
# External Packages
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
# Internal Packages
|
||||
from app.main import app
|
||||
from khoj.configure import configure_routes, configure_search_types
|
||||
from khoj.utils import state
|
||||
from khoj.utils.state import search_models, content_index, config
|
||||
from khoj.search_type import text_search, image_search
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from database.models import KhojUser
|
||||
from database.adapters import EmbeddingsAdapters
|
||||
|
||||
|
||||
# Test
|
||||
@@ -35,7 +34,7 @@ def test_search_with_invalid_content_type(client):
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_search_with_valid_content_type(client):
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plugin1"]:
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]:
|
||||
# Act
|
||||
response = client.get(f"/api/search?q=random&t={content_type}")
|
||||
# Assert
|
||||
@@ -75,7 +74,7 @@ def test_index_update(client):
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_regenerate_with_valid_content_type(client):
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
|
||||
# Arrange
|
||||
files = get_sample_files_data()
|
||||
headers = {"x-api-key": "secret"}
|
||||
@@ -102,60 +101,42 @@ def test_regenerate_with_github_fails_without_pat(client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.skip(reason="Flaky test on parallel test runs")
|
||||
def test_get_configured_types_via_api(client):
|
||||
def test_get_configured_types_via_api(client, sample_org_data):
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
|
||||
enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
|
||||
assert list(enabled_types) == ["org"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_configured_types_with_only_plugin_content_config(content_config):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data):
|
||||
# Arrange
|
||||
config.content_type = ContentConfig()
|
||||
config.content_type.plugins = content_config.plugins
|
||||
state.SearchType = configure_search_types(config)
|
||||
|
||||
configure_routes(app)
|
||||
client = TestClient(app)
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["all", "plugin1"]
|
||||
assert response.json() == ["all", "org", "markdown", "image"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_configured_types_with_no_plugin_content_config(content_config):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
||||
# Arrange
|
||||
config.content_type = content_config
|
||||
config.content_type.plugins = None
|
||||
state.SearchType = configure_search_types(config)
|
||||
original_config = state.config.content_type
|
||||
state.config.content_type = None
|
||||
|
||||
configure_routes(app)
|
||||
client = TestClient(app)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert "plugin1" not in response.json()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_configured_types_with_no_content_config():
|
||||
# Arrange
|
||||
config.content_type = ContentConfig()
|
||||
state.SearchType = configure_search_types(config)
|
||||
|
||||
configure_routes(app)
|
||||
client = TestClient(app)
|
||||
configure_routes(fastapi_app)
|
||||
client = TestClient(fastapi_app)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
@@ -164,6 +145,9 @@ def test_get_configured_types_with_no_content_config():
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["all"]
|
||||
|
||||
# Restore
|
||||
state.config.content_type = original_config
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
@@ -192,12 +176,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search(client, search_config: SearchConfig, sample_org_data):
|
||||
# Arrange
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
@@ -211,19 +193,15 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_only_filters(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
# Arrange
|
||||
filters = [WordFilter(), FileFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
sample_org_data,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
user_query = quote('+"Emacs" file:"*.org"')
|
||||
|
||||
@@ -238,15 +216,10 @@ def test_notes_search_with_only_filters(
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_include_filter(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_include_filter(client, sample_org_data):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
|
||||
)
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
|
||||
user_query = quote('How to git install application? +"Emacs"')
|
||||
|
||||
# Act
|
||||
@@ -260,19 +233,13 @@ def test_notes_search_with_include_filter(
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_exclude_filter(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_exclude_filter(client, sample_org_data):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
sample_org_data,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
user_query = quote('How to git install application? -"clone"')
|
||||
|
||||
@@ -286,6 +253,22 @@ def test_notes_search_with_exclude_filter(
|
||||
assert "clone" not in search_result
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
|
||||
# Arrange
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
|
||||
assert len(response.json()) == 0
|
||||
|
||||
|
||||
def get_sample_files_data():
|
||||
return {
|
||||
"files": ("path/to/filename.org", "* practicing piano", "text/org"),
|
||||
|
||||
@@ -1,53 +1,12 @@
|
||||
# Standard Packages
|
||||
import re
|
||||
from datetime import datetime
|
||||
from math import inf
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
|
||||
def test_date_filter():
|
||||
entries = [
|
||||
Entry(compiled="Entry with no date", raw="Entry with no date"),
|
||||
Entry(compiled="April Fools entry: 1984-04-01", raw="April Fools entry: 1984-04-01"),
|
||||
Entry(compiled="Entry with date:1984-04-02", raw="Entry with date:1984-04-02"),
|
||||
]
|
||||
|
||||
q_with_no_date_filter = "head tail"
|
||||
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 1, 2}
|
||||
|
||||
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
|
||||
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == set()
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {2}
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {1}
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {2}
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {1, 2}
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
|
||||
@@ -56,8 +15,8 @@ def test_extract_date_range():
|
||||
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
|
||||
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
|
||||
]
|
||||
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
|
||||
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [None, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), None]
|
||||
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
|
||||
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
|
||||
datetime(1984, 1, 2, 0, 0, 0).timestamp(),
|
||||
|
||||
@@ -6,97 +6,73 @@ from khoj.utils.rawconfig import Entry
|
||||
def test_no_file_filter():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = "head tail"
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == False
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_file_filter_with_non_existent_file():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"nonexistent.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {}
|
||||
|
||||
|
||||
def test_single_file_filter():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"file 1.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 2}
|
||||
|
||||
|
||||
def test_file_filter_with_partial_match():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"1.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 2}
|
||||
|
||||
|
||||
def test_file_filter_with_regex_match():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"*.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_multiple_file_filter():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_get_file_filter_terms():
|
||||
@@ -108,7 +84,7 @@ def test_get_file_filter_terms():
|
||||
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
|
||||
|
||||
# Assert
|
||||
assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
|
||||
assert filter_terms == ["file 1\\.org", "/path/to/dir/.*\\.org"]
|
||||
|
||||
|
||||
def arrange_content():
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
# Internal Packages
|
||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
def test_process_entries_from_single_input_jsonl(tmp_path):
|
||||
"Convert multiple jsonl entries from single file to entries."
|
||||
# Arrange
|
||||
input_jsonl = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}
|
||||
{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}
|
||||
"""
|
||||
input_jsonl_file = create_file(tmp_path, input_jsonl)
|
||||
|
||||
# Act
|
||||
# Process Each Entry from All Notes Files
|
||||
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file])
|
||||
entries = list(map(Entry.from_dict, input_jsons))
|
||||
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert output_jsonl == input_jsonl
|
||||
|
||||
|
||||
def test_process_entries_from_multiple_input_jsonls(tmp_path):
|
||||
"Convert multiple jsonl entries from single file to entries."
|
||||
# Arrange
|
||||
input_jsonl_1 = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}"""
|
||||
input_jsonl_2 = """{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}"""
|
||||
input_jsonl_file_1 = create_file(tmp_path, input_jsonl_1, filename="input1.jsonl")
|
||||
input_jsonl_file_2 = create_file(tmp_path, input_jsonl_2, filename="input2.jsonl")
|
||||
|
||||
# Act
|
||||
# Process Each Entry from All Notes Files
|
||||
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file_1, input_jsonl_file_2])
|
||||
entries = list(map(Entry.from_dict, input_jsons))
|
||||
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert output_jsonl == f"{input_jsonl_1}\n{input_jsonl_2}\n"
|
||||
|
||||
|
||||
def test_get_jsonl_files(tmp_path):
|
||||
"Ensure JSONL files specified via input-filter, input-files extracted"
|
||||
# Arrange
|
||||
# Include via input-filter globs
|
||||
group1_file1 = create_file(tmp_path, filename="group1-file1.jsonl")
|
||||
group1_file2 = create_file(tmp_path, filename="group1-file2.jsonl")
|
||||
group2_file1 = create_file(tmp_path, filename="group2-file1.jsonl")
|
||||
group2_file2 = create_file(tmp_path, filename="group2-file2.jsonl")
|
||||
# Include via input-file field
|
||||
file1 = create_file(tmp_path, filename="notes.jsonl")
|
||||
# Not included by any filter
|
||||
create_file(tmp_path, filename="not-included-jsonl.jsonl")
|
||||
create_file(tmp_path, filename="not-included-text.txt")
|
||||
|
||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
||||
|
||||
# Setup input-files, input-filters
|
||||
input_files = [tmp_path / "notes.jsonl"]
|
||||
input_filter = [tmp_path / "group1*.jsonl", tmp_path / "group2*.jsonl"]
|
||||
|
||||
# Act
|
||||
extracted_org_files = JsonlToJsonl.get_jsonl_files(input_files, input_filter)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_org_files) == 5
|
||||
assert extracted_org_files == expected_files
|
||||
|
||||
|
||||
# Helper Functions
|
||||
def create_file(tmp_path, entry=None, filename="test.jsonl"):
|
||||
jsonl_file = tmp_path / filename
|
||||
jsonl_file.touch()
|
||||
if entry:
|
||||
jsonl_file.write_text(entry)
|
||||
return jsonl_file
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.utils.fs_syncer import get_org_files
|
||||
@@ -63,7 +63,7 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
|
||||
|
||||
# Split each entry from specified Org files by max words
|
||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||
TextToJsonl.split_entries_by_max_tokens(
|
||||
TextEmbeddings.split_entries_by_max_tokens(
|
||||
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
|
||||
)
|
||||
)
|
||||
@@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():
|
||||
|
||||
# Act
|
||||
# Split entry by max words and drop words larger than max word length
|
||||
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
|
||||
processed_entry = TextEmbeddings.split_entries_by_max_tokens([entry], max_word_length=5)[0]
|
||||
|
||||
# Assert
|
||||
# "Heading" dropped from compiled version because its over the set max word limit
|
||||
|
||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
from khoj.utils.fs_syncer import get_plaintext_files
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from database.models import LocalPlaintextConfig, KhojUser
|
||||
|
||||
|
||||
def test_plaintext_file(tmp_path):
|
||||
@@ -91,11 +92,12 @@ def test_get_plaintext_files(tmp_path):
|
||||
assert set(extracted_plaintext_files.keys()) == set(expected_files)
|
||||
|
||||
|
||||
def test_parse_html_plaintext_file(content_config):
|
||||
def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
|
||||
"Ensure HTML files are parsed correctly"
|
||||
# Arrange
|
||||
# Setup input-files, input-filters
|
||||
extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
|
||||
config = LocalPlaintextConfig.objects.filter(user=default_user).first()
|
||||
extracted_plaintext_files = get_plaintext_files(config=config)
|
||||
|
||||
# Act
|
||||
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)
|
||||
|
||||
@@ -3,23 +3,30 @@ import logging
|
||||
import locale
|
||||
from pathlib import Path
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.state import content_index, search_models
|
||||
from khoj.search_type import text_search
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.fs_syncer import get_org_files
|
||||
from khoj.utils.fs_syncer import get_org_files, collect_files
|
||||
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
|
||||
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig):
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_missing_file_raises_error(
|
||||
org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
|
||||
):
|
||||
# Arrange
|
||||
# Ensure file mentioned in org.input-files is missing
|
||||
single_new_file = Path(org_config_with_only_new_file.input_files[0])
|
||||
@@ -32,98 +39,126 @@ def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_n
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path: Path):
|
||||
@pytest.mark.django_db
|
||||
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
|
||||
# Arrange
|
||||
orgfile = tmp_path / "directory.org" / "file.org"
|
||||
orgfile.parent.mkdir()
|
||||
with open(orgfile, "w") as f:
|
||||
f.write("* Heading\n- List item\n")
|
||||
org_content_config = TextContentConfig(
|
||||
input_filter=[f"{tmp_path}/**/*"], compressed_jsonl="test.jsonl", embeddings_file="test.pt"
|
||||
|
||||
LocalOrgConfig.objects.create(
|
||||
input_filter=[f"{tmp_path}/**/*"],
|
||||
input_files=None,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
org_files = collect_files(user=default_user)["org"]
|
||||
|
||||
# Act
|
||||
# should not raise IsADirectoryError and return orgfile
|
||||
assert get_org_files(org_content_config) == {f"{orgfile}": "* Heading\n- List item\n"}
|
||||
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_empty_file_raises_error(
|
||||
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with pytest.raises(ValueError, match=r"^No valid entries found*"):
|
||||
text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
verify_embeddings(0, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup(content_config, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
|
||||
# Act
|
||||
# Regenerate notes embeddings during asymmetric setup
|
||||
notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert len(notes_model.entries) == 10
|
||||
assert len(notes_model.corpus_embeddings) == 10
|
||||
assert "Deleting all embeddings for file type org" in caplog.records[1].message
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
|
||||
@pytest.mark.django_db
|
||||
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
caplog.set_level(logging.INFO, logger="khoj")
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# Generate initial notes embeddings during asymmetric setup
|
||||
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
|
||||
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
assert "Creating index from scratch." in initial_logs
|
||||
assert "Creating index from scratch." not in final_logs
|
||||
assert "Deleting all embeddings for file type org" in initial_logs
|
||||
assert "Deleting all embeddings for file type org" not in final_logs
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.anyio
|
||||
async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# @pytest.mark.asyncio
|
||||
async def test_text_search(search_config: SearchConfig):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
default_user = await KhojUser.objects.acreate(
|
||||
username="test_user", password="test_password", email="test@example.com"
|
||||
)
|
||||
# Arrange
|
||||
org_config = await LocalOrgConfig.objects.acreate(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/org/*.org"],
|
||||
index_heading_entries=False,
|
||||
user=default_user,
|
||||
)
|
||||
data = get_org_files(org_config)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
text_search.setup,
|
||||
OrgToJsonl,
|
||||
data,
|
||||
True,
|
||||
True,
|
||||
default_user,
|
||||
)
|
||||
|
||||
query = "How to git install application?"
|
||||
|
||||
# Act
|
||||
hits, entries = await text_search.query(
|
||||
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
|
||||
)
|
||||
|
||||
results = text_search.collate_results(hits, entries, count=1)
|
||||
hits = await text_search.query(default_user, query)
|
||||
|
||||
# Assert
|
||||
results = text_search.collate_results(hits)
|
||||
results = sorted(results, key=lambda x: float(x.score))[:1]
|
||||
# search results should contain "git clone" entry
|
||||
search_result = results[0].entry
|
||||
assert "git clone" in search_result
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
||||
@pytest.mark.django_db
|
||||
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||
max_tokens = 256
|
||||
@@ -137,47 +172,46 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
||||
|
||||
# Act
|
||||
# reload embeddings, entries, notes model after adding new org-mode file
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
|
||||
# Assert
|
||||
# verify newly added org-mode entry is split by max tokens
|
||||
assert len(initial_notes_model.entries) == 2
|
||||
assert len(initial_notes_model.corpus_embeddings) == 2
|
||||
record = caplog.records[1]
|
||||
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
|
||||
@pytest.mark.django_db
|
||||
def test_entry_chunking_by_max_tokens_not_full_corpus(
|
||||
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||
data = {
|
||||
"readme.org": """
|
||||
* Khoj
|
||||
/Allow natural language search on user content like notes, images using transformer based models/
|
||||
/Allow natural language search on user content like notes, images using transformer based models/
|
||||
|
||||
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
|
||||
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
|
||||
|
||||
** Dependencies
|
||||
- Python3
|
||||
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
|
||||
- Python3
|
||||
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
|
||||
|
||||
** Install
|
||||
#+begin_src shell
|
||||
git clone https://github.com/khoj-ai/khoj && cd khoj
|
||||
conda env create -f environment.yml
|
||||
conda activate khoj
|
||||
#+end_src"""
|
||||
#+begin_src shell
|
||||
git clone https://github.com/khoj-ai/khoj && cd khoj
|
||||
conda env create -f environment.yml
|
||||
conda activate khoj
|
||||
#+end_src"""
|
||||
}
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
data,
|
||||
org_config_with_only_new_file,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
max_tokens = 256
|
||||
@@ -191,64 +225,57 @@ def test_entry_chunking_by_max_tokens_not_full_corpus(
|
||||
|
||||
# Act
|
||||
# reload embeddings, entries, notes model after adding new org-mode file
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl,
|
||||
data,
|
||||
org_config_with_only_new_file,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
full_corpus=False,
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
data,
|
||||
regenerate=False,
|
||||
full_corpus=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
record = caplog.records[1]
|
||||
|
||||
# Assert
|
||||
# verify newly added org-mode entry is split by max tokens
|
||||
assert len(initial_notes_model.entries) == 5
|
||||
assert len(initial_notes_model.corpus_embeddings) == 5
|
||||
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_regenerate_index_with_new_entry(
|
||||
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
|
||||
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
assert len(initial_notes_model.entries) == 10
|
||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
content_config.org.input_files = [f"{new_org_file}"]
|
||||
org_config.input_files = [f"{new_org_file}"]
|
||||
with open(new_org_file, "w") as f:
|
||||
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||
regenerated_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert len(regenerated_notes_model.entries) == 11
|
||||
assert len(regenerated_notes_model.corpus_embeddings) == 11
|
||||
|
||||
# verify new entry appended to index, without disrupting order or content of existing entries
|
||||
error_details = compare_index(initial_notes_model, regenerated_notes_model)
|
||||
if error_details:
|
||||
pytest.fail(error_details, False)
|
||||
|
||||
# Cleanup
|
||||
# reset input_files in config to empty list
|
||||
content_config.org.input_files = []
|
||||
assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
|
||||
verify_embeddings(11, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_duplicate_entries_in_stable_order(
|
||||
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
@@ -262,30 +289,26 @@ def test_update_index_with_duplicate_entries_in_stable_order(
|
||||
|
||||
# Act
|
||||
# load embeddings, entries, notes model after adding new org-mode file
|
||||
initial_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# update embeddings, entries, notes model after adding new org-mode file
|
||||
updated_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert len(initial_index.entries) == len(updated_index.entries) == 1
|
||||
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
|
||||
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
|
||||
|
||||
# verify the same entry is added even when there are multiple duplicate entries
|
||||
error_details = compare_index(initial_index, updated_index)
|
||||
if error_details:
|
||||
pytest.fail(error_details)
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
||||
@@ -296,9 +319,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||
initial_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
# update embeddings, entries, notes model after removing an entry from the org file
|
||||
with open(new_file_to_index, "w") as f:
|
||||
@@ -307,87 +329,65 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
updated_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert len(initial_index.entries) == len(updated_index.entries) + 1
|
||||
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
|
||||
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message
|
||||
|
||||
# verify the same entry is added even when there are multiple duplicate entries
|
||||
error_details = compare_index(updated_index, initial_index)
|
||||
if error_details:
|
||||
pytest.fail(error_details)
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
|
||||
)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
with open(new_org_file, "w") as f:
|
||||
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||
f.write(new_entry)
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# update embeddings, entries with the newly added note
|
||||
content_config.org.input_files = [f"{new_org_file}"]
|
||||
final_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
|
||||
assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
|
||||
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
|
||||
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
|
||||
|
||||
# verify new entry appended to index, without disrupting order or content of existing entries
|
||||
error_details = compare_index(initial_notes_model, final_notes_model)
|
||||
if error_details:
|
||||
pytest.fail(error_details, False)
|
||||
|
||||
# Cleanup
|
||||
# reset input_files in config to empty list
|
||||
content_config.org.input_files = []
|
||||
verify_embeddings(11, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
||||
def test_text_search_setup_github(content_config: ContentConfig, search_models: SearchModels):
|
||||
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
|
||||
# Arrange
|
||||
github_config = GithubConfig.objects.filter(user=default_user).first()
|
||||
# Act
|
||||
# Regenerate github embeddings to test asymmetric setup without caching
|
||||
github_model = text_search.setup(
|
||||
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
|
||||
text_search.setup(
|
||||
GithubToJsonl,
|
||||
{},
|
||||
regenerate=True,
|
||||
user=default_user,
|
||||
config=github_config,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(github_model.entries) > 1
|
||||
embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
|
||||
assert embeddings > 1
|
||||
|
||||
|
||||
def compare_index(initial_notes_model, final_notes_model):
|
||||
mismatched_entries, mismatched_embeddings = [], []
|
||||
for index in range(len(initial_notes_model.entries)):
|
||||
if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
|
||||
mismatched_entries.append(index)
|
||||
|
||||
# verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
|
||||
for index in range(len(initial_notes_model.corpus_embeddings)):
|
||||
if not initial_notes_model.corpus_embeddings[index].allclose(final_notes_model.corpus_embeddings[index]):
|
||||
mismatched_embeddings.append(index)
|
||||
|
||||
error_details = ""
|
||||
if mismatched_entries:
|
||||
mismatched_entries_str = ",".join(map(str, mismatched_entries))
|
||||
error_details += f"Entries at {mismatched_entries_str} not equal\n"
|
||||
if mismatched_embeddings:
|
||||
mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
|
||||
error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
|
||||
|
||||
return error_details
|
||||
def verify_embeddings(expected_count, user):
|
||||
embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
|
||||
assert embeddings == expected_count
|
||||
|
||||
@@ -3,68 +3,40 @@ from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
def test_no_word_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = "head tail"
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == False
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_word_exclude_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
q_with_exclude_filter = 'head -"exclude_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(q_with_exclude_filter)
|
||||
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 2}
|
||||
|
||||
|
||||
def test_word_include_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
query_with_include_filter = 'head +"include_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(query_with_include_filter)
|
||||
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {2, 3}
|
||||
|
||||
|
||||
def test_word_include_and_exclude_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
|
||||
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {2}
|
||||
|
||||
|
||||
def test_get_word_filter_terms():
|
||||
|
||||
Reference in New Issue
Block a user