[Multi-User Part 1]: Enable storage of settings for plaintext files based on user account (#498)

- Partition configuration for indexing local data based on user accounts
- Store indexed data in an underlying postgres db using the `pgvector` extension
- Add migrations for all relevant user data and embeddings generation. Very little performance optimization has been done for the lookup time
- Apply filters using SQL queries
- Start removing many server-level configuration settings
- Configure GitHub test actions to run during any PR. Update the test action to run in a containerized environment with a DB.
- Update the Docker image and docker-compose.yml to work with the new application design
This commit is contained in:
sabaimran
2023-10-26 09:42:29 -07:00
committed by GitHub
parent 963cd165eb
commit 216acf545f
60 changed files with 1827 additions and 1792 deletions

View File

@@ -1,15 +1,19 @@
# External Packages
import os
from copy import deepcopy
from fastapi.testclient import TestClient
from pathlib import Path
import pytest
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI
import factory
import os
from fastapi import FastAPI
app = FastAPI()
# Internal Packages
from app.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels
@@ -22,8 +26,6 @@ from khoj.utils.rawconfig import (
OpenAIProcessorConfig,
ProcessorConfig,
TextContentConfig,
GithubContentConfig,
GithubRepoConfig,
ImageContentConfig,
SearchConfig,
TextSearchConfig,
@@ -31,11 +33,31 @@ from khoj.utils.rawconfig import (
)
from khoj.utils import state, fs_syncer
from khoj.routers.indexer import configure_content
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from database.models import (
LocalOrgConfig,
LocalMarkdownConfig,
LocalPlaintextConfig,
LocalPdfConfig,
GithubConfig,
KhojUser,
GithubRepoConfig,
)
@pytest.fixture(autouse=True)
def enable_db_access_for_all_tests(db):
pass
class UserFactory(factory.django.DjangoModelFactory):
class Meta:
model = KhojUser
username = factory.Faker("name")
email = factory.Faker("email")
password = factory.Faker("password")
uuid = factory.Faker("uuid4")
@pytest.fixture(scope="session")
@@ -67,17 +89,28 @@ def search_config() -> SearchConfig:
return search_config
@pytest.mark.django_db
@pytest.fixture
def default_user():
return UserFactory()
@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
search_models = SearchModels()
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
@pytest.fixture
def anyio_backend():
return "asyncio"
@pytest.mark.django_db
@pytest.fixture(scope="function")
def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images
@@ -92,94 +125,45 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
# Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig(
LocalOrgConfig.objects.create(
input_files=None,
input_filter=["tests/data/org/*.org"],
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
index_heading_entries=False,
user=default_user,
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
OrgToJsonl,
get_sample_data("org"),
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
content_config.plugins = {
"plugin1": TextContentConfig(
input_files=[content_dir.joinpath("notes.jsonl.gz")],
input_filter=None,
compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
)
}
text_search.setup(OrgToJsonl, get_sample_data("org"), regenerate=False, user=default_user)
if os.getenv("GITHUB_PAT_TOKEN"):
content_config.github = GithubContentConfig(
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
repos=[
GithubRepoConfig(
owner="khoj-ai",
name="lantern",
branch="master",
)
],
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
GithubConfig.objects.create(
pat_token=os.getenv("GITHUB_PAT_TOKEN"),
user=default_user,
)
content_config.plaintext = TextContentConfig(
GithubRepoConfig.objects.create(
owner="khoj-ai",
name="lantern",
branch="master",
github_config=GithubConfig.objects.get(user=default_user),
)
LocalPlaintextConfig.objects.create(
input_files=None,
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
compressed_jsonl=content_dir.joinpath("plaintext.jsonl.gz"),
embeddings_file=content_dir.joinpath("plaintext_embeddings.pt"),
)
content_config.github = GithubContentConfig(
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
repos=[
GithubRepoConfig(
owner="khoj-ai",
name="lantern",
branch="master",
)
],
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
JsonlToJsonl,
None,
content_config.plugins["plugin1"],
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
user=default_user,
)
return content_config
@pytest.fixture(scope="session")
def md_content_config(tmp_path_factory):
content_dir = tmp_path_factory.mktemp("content")
# Generate Embeddings for Markdown Content
content_config = ContentConfig()
content_config.markdown = TextContentConfig(
def md_content_config():
markdown_config = LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
)
return content_config
return markdown_config
@pytest.fixture(scope="session")
@@ -220,19 +204,20 @@ def processor_config_offline_chat(tmp_path_factory):
@pytest.fixture(scope="session")
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
# Initialize app state
state.config.content_type = md_content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
# Index Markdown Content for Search
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
all_files = fs_syncer.collect_files(state.config.content_type)
all_files = fs_syncer.collect_files()
state.content_index = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models
)
# Initialize Processor from Config
state.processor_config = configure_processor(processor_config)
state.anonymous_mode = True
app = FastAPI()
configure_routes(app)
configure_middleware(app)
@@ -241,33 +226,45 @@ def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, p
@pytest.fixture(scope="function")
def client(content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
def fastapi_app():
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return app
@pytest.fixture(scope="function")
def client(
content_config: ContentConfig,
search_config: SearchConfig,
processor_config: ProcessorConfig,
default_user: KhojUser,
):
state.config.content_type = content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
# These lines help us Mock the Search models for these search types
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
text_search.setup(
OrgToJsonl,
get_sample_data("org"),
content_config.org,
state.search_models.text_search.bi_encoder,
regenerate=False,
user=default_user,
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
state.content_index.plaintext = text_search.setup(
text_search.setup(
PlaintextToJsonl,
get_sample_data("plaintext"),
content_config.plaintext,
state.search_models.text_search.bi_encoder,
regenerate=False,
user=default_user,
)
state.processor_config = configure_processor(processor_config)
state.anonymous_mode = True
configure_routes(app)
configure_middleware(app)
@@ -288,7 +285,6 @@ def client_offline_chat(
state.SearchType = configure_search_types(state.config)
# Index Markdown Content for Search
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
all_files = fs_syncer.collect_files(state.config.content_type)
@@ -298,6 +294,7 @@ def client_offline_chat(
# Initialize Processor from Config
state.processor_config = configure_processor(processor_config_offline_chat)
state.anonymous_mode = True
configure_routes(app)
configure_middleware(app)
@@ -306,9 +303,11 @@ def client_offline_chat(
@pytest.fixture(scope="function")
def new_org_file(content_config: ContentConfig):
def new_org_file(default_user: KhojUser, content_config: ContentConfig):
# Setup
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
input_filters = org_config.input_filter
new_org_file = Path(input_filters[0]).parent / "new_file.org"
new_org_file.touch()
yield new_org_file
@@ -319,11 +318,9 @@ def new_org_file(content_config: ContentConfig):
@pytest.fixture(scope="function")
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
new_org_config = deepcopy(content_config.org)
new_org_config.input_files = [f"{new_org_file}"]
new_org_config.input_filter = None
return new_org_config
def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
return LocalOrgConfig.objects.filter(user=default_user).first()
@pytest.fixture(scope="function")

11
tests/data/config.yml vendored
View File

@@ -9,17 +9,6 @@ content-type:
input-filter:
- '*.org'
- ~/notes/*.org
plugins:
content_plugin_1:
compressed-jsonl: content_plugin_1.jsonl.gz
embeddings-file: content_plugin_1_embeddings.pt
input-files:
- content_plugin_1_new.jsonl.gz
content_plugin_2:
compressed-jsonl: content_plugin_2.jsonl.gz
embeddings-file: content_plugin_2_embeddings.pt
input-filter:
- '*2_new.jsonl.gz'
enable-offline-chat: false
search-type:
asymmetric:

View File

@@ -48,14 +48,3 @@ def test_cli_config_from_file():
Path("~/first_from_config.org"),
Path("~/second_from_config.org"),
]
assert len(actual_args.config.content_type.plugins.keys()) == 2
assert actual_args.config.content_type.plugins["content_plugin_1"].input_files == [
Path("content_plugin_1_new.jsonl.gz")
]
assert actual_args.config.content_type.plugins["content_plugin_2"].input_filter == ["*2_new.jsonl.gz"]
assert actual_args.config.content_type.plugins["content_plugin_1"].compressed_jsonl == Path(
"content_plugin_1.jsonl.gz"
)
assert actual_args.config.content_type.plugins["content_plugin_2"].embeddings_file == Path(
"content_plugin_2_embeddings.pt"
)

View File

@@ -2,22 +2,21 @@
from io import BytesIO
from PIL import Image
from urllib.parse import quote
import pytest
# External Packages
from fastapi.testclient import TestClient
import pytest
from fastapi import FastAPI
# Internal Packages
from app.main import app
from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from database.models import KhojUser
from database.adapters import EmbeddingsAdapters
# Test
@@ -35,7 +34,7 @@ def test_search_with_invalid_content_type(client):
# ----------------------------------------------------------------------------------------------------
def test_search_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plugin1"]:
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion"]:
# Act
response = client.get(f"/api/search?q=random&t={content_type}")
# Assert
@@ -75,7 +74,7 @@ def test_index_update(client):
# ----------------------------------------------------------------------------------------------------
def test_regenerate_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
# Arrange
files = get_sample_files_data()
headers = {"x-api-key": "secret"}
@@ -102,60 +101,42 @@ def test_regenerate_with_github_fails_without_pat(client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.skip(reason="Flaky test on parallel test runs")
def test_get_configured_types_via_api(client):
def test_get_configured_types_via_api(client, sample_org_data):
# Act
response = client.get(f"/api/config/types")
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
enabled_types = EmbeddingsAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
# Assert
assert response.status_code == 200
assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
assert list(enabled_types) == ["org"]
# ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_only_plugin_content_config(content_config):
@pytest.mark.django_db(transaction=True)
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data):
# Arrange
config.content_type = ContentConfig()
config.content_type.plugins = content_config.plugins
state.SearchType = configure_search_types(config)
configure_routes(app)
client = TestClient(app)
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
# Act
response = client.get(f"/api/config/types")
# Assert
assert response.status_code == 200
assert response.json() == ["all", "plugin1"]
assert response.json() == ["all", "org", "markdown", "image"]
# ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_no_plugin_content_config(content_config):
@pytest.mark.django_db(transaction=True)
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
# Arrange
config.content_type = content_config
config.content_type.plugins = None
state.SearchType = configure_search_types(config)
original_config = state.config.content_type
state.config.content_type = None
configure_routes(app)
client = TestClient(app)
# Act
response = client.get(f"/api/config/types")
# Assert
assert response.status_code == 200
assert "plugin1" not in response.json()
# ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_no_content_config():
# Arrange
config.content_type = ContentConfig()
state.SearchType = configure_search_types(config)
configure_routes(app)
client = TestClient(app)
configure_routes(fastapi_app)
client = TestClient(fastapi_app)
# Act
response = client.get(f"/api/config/types")
@@ -164,6 +145,9 @@ def test_get_configured_types_with_no_content_config():
assert response.status_code == 200
assert response.json() == ["all"]
# Restore
state.config.content_type = original_config
# ----------------------------------------------------------------------------------------------------
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
@@ -192,12 +176,10 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data):
@pytest.mark.django_db(transaction=True)
def test_notes_search(client, search_config: SearchConfig, sample_org_data):
# Arrange
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
user_query = quote("How to git install application?")
# Act
@@ -211,19 +193,15 @@ def test_notes_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_only_filters(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
# Arrange
filters = [WordFilter(), FileFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
text_search.setup(
OrgToJsonl,
sample_org_data,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
user_query = quote('+"Emacs" file:"*.org"')
@@ -238,15 +216,10 @@ def test_notes_search_with_only_filters(
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_include_filter(client, sample_org_data):
# Arrange
filters = [WordFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
)
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False)
user_query = quote('How to git install application? +"Emacs"')
# Act
@@ -260,19 +233,13 @@ def test_notes_search_with_include_filter(
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_exclude_filter(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_exclude_filter(client, sample_org_data):
# Arrange
filters = [WordFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
text_search.setup(
OrgToJsonl,
sample_org_data,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
user_query = quote('How to git install application? -"clone"')
@@ -286,6 +253,22 @@ def test_notes_search_with_exclude_filter(
assert "clone" not in search_result
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
# Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
# Assert
assert response.status_code == 200
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
assert len(response.json()) == 0
def get_sample_files_data():
return {
"files": ("path/to/filename.org", "* practicing piano", "text/org"),

View File

@@ -1,53 +1,12 @@
# Standard Packages
import re
from datetime import datetime
from math import inf
# External Packages
import pytest
# Internal Packages
from khoj.search_filter.date_filter import DateFilter
from khoj.utils.rawconfig import Entry
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
def test_date_filter():
entries = [
Entry(compiled="Entry with no date", raw="Entry with no date"),
Entry(compiled="April Fools entry: 1984-04-01", raw="April Fools entry: 1984-04-01"),
Entry(compiled="Entry with date:1984-04-02", raw="Entry with date:1984-04-02"),
]
q_with_no_date_filter = "head tail"
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2}
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
assert ret_query == "head tail"
assert entry_indices == set()
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {1}
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {1, 2}
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
@@ -56,8 +15,8 @@ def test_extract_date_range():
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
]
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [None, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), None]
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
datetime(1984, 1, 2, 0, 0, 0).timestamp(),

View File

@@ -6,97 +6,73 @@ from khoj.utils.rawconfig import Entry
def test_no_file_filter():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = "head tail"
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == False
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_file_filter_with_non_existent_file():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"nonexistent.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {}
def test_single_file_filter():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"file 1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
def test_file_filter_with_partial_match():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
def test_file_filter_with_regex_match():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"*.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_multiple_file_filter():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_get_file_filter_terms():
@@ -108,7 +84,7 @@ def test_get_file_filter_terms():
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
# Assert
assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
assert filter_terms == ["file 1\\.org", "/path/to/dir/.*\\.org"]
def arrange_content():

View File

@@ -1,78 +0,0 @@
# Internal Packages
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.utils.rawconfig import Entry
def test_process_entries_from_single_input_jsonl(tmp_path):
"Convert multiple jsonl entries from single file to entries."
# Arrange
input_jsonl = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}
{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}
"""
input_jsonl_file = create_file(tmp_path, input_jsonl)
# Act
# Process Each Entry from All Notes Files
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file])
entries = list(map(Entry.from_dict, input_jsons))
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Assert
assert len(entries) == 2
assert output_jsonl == input_jsonl
def test_process_entries_from_multiple_input_jsonls(tmp_path):
"Convert multiple jsonl entries from single file to entries."
# Arrange
input_jsonl_1 = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}"""
input_jsonl_2 = """{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}"""
input_jsonl_file_1 = create_file(tmp_path, input_jsonl_1, filename="input1.jsonl")
input_jsonl_file_2 = create_file(tmp_path, input_jsonl_2, filename="input2.jsonl")
# Act
# Process Each Entry from All Notes Files
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file_1, input_jsonl_file_2])
entries = list(map(Entry.from_dict, input_jsons))
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Assert
assert len(entries) == 2
assert output_jsonl == f"{input_jsonl_1}\n{input_jsonl_2}\n"
def test_get_jsonl_files(tmp_path):
"Ensure JSONL files specified via input-filter, input-files extracted"
# Arrange
# Include via input-filter globs
group1_file1 = create_file(tmp_path, filename="group1-file1.jsonl")
group1_file2 = create_file(tmp_path, filename="group1-file2.jsonl")
group2_file1 = create_file(tmp_path, filename="group2-file1.jsonl")
group2_file2 = create_file(tmp_path, filename="group2-file2.jsonl")
# Include via input-file field
file1 = create_file(tmp_path, filename="notes.jsonl")
# Not included by any filter
create_file(tmp_path, filename="not-included-jsonl.jsonl")
create_file(tmp_path, filename="not-included-text.txt")
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
# Setup input-files, input-filters
input_files = [tmp_path / "notes.jsonl"]
input_filter = [tmp_path / "group1*.jsonl", tmp_path / "group2*.jsonl"]
# Act
extracted_org_files = JsonlToJsonl.get_jsonl_files(input_files, input_filter)
# Assert
assert len(extracted_org_files) == 5
assert extracted_org_files == expected_files
# Helper Functions
def create_file(tmp_path, entry=None, filename="test.jsonl"):
jsonl_file = tmp_path / filename
jsonl_file.touch()
if entry:
jsonl_file.write_text(entry)
return jsonl_file

View File

@@ -4,7 +4,7 @@ import os
# Internal Packages
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import Entry
from khoj.utils.fs_syncer import get_org_files
@@ -63,7 +63,7 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
# Split each entry from specified Org files by max words
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
TextToJsonl.split_entries_by_max_tokens(
TextEmbeddings.split_entries_by_max_tokens(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
)
)
@@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():
# Act
# Split entry by max words and drop words larger than max word length
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
processed_entry = TextEmbeddings.split_entries_by_max_tokens([entry], max_word_length=5)[0]
# Assert
# "Heading" dropped from compiled version because its over the set max word limit

View File

@@ -7,6 +7,7 @@ from pathlib import Path
from khoj.utils.fs_syncer import get_plaintext_files
from khoj.utils.rawconfig import TextContentConfig
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from database.models import LocalPlaintextConfig, KhojUser
def test_plaintext_file(tmp_path):
@@ -91,11 +92,12 @@ def test_get_plaintext_files(tmp_path):
assert set(extracted_plaintext_files.keys()) == set(expected_files)
def test_parse_html_plaintext_file(content_config):
def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
"Ensure HTML files are parsed correctly"
# Arrange
# Setup input-files, input-filters
extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
config = LocalPlaintextConfig.objects.filter(user=default_user).first()
extracted_plaintext_files = get_plaintext_files(config=config)
# Act
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)

View File

@@ -3,23 +3,30 @@ import logging
import locale
from pathlib import Path
import os
import asyncio
# External Packages
import pytest
# Internal Packages
from khoj.utils.state import content_index, search_models
from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.utils.config import SearchModels
from khoj.utils.fs_syncer import get_org_files
from khoj.utils.fs_syncer import get_org_files, collect_files
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
logger = logging.getLogger(__name__)
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
# Test
# ----------------------------------------------------------------------------------------------------
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig):
@pytest.mark.django_db
def test_text_search_setup_with_missing_file_raises_error(
org_config_with_only_new_file: LocalOrgConfig, search_config: SearchConfig
):
# Arrange
# Ensure file mentioned in org.input-files is missing
single_new_file = Path(org_config_with_only_new_file.input_files[0])
@@ -32,98 +39,126 @@ def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_n
# ----------------------------------------------------------------------------------------------------
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path: Path):
@pytest.mark.django_db
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
# Arrange
orgfile = tmp_path / "directory.org" / "file.org"
orgfile.parent.mkdir()
with open(orgfile, "w") as f:
f.write("* Heading\n- List item\n")
org_content_config = TextContentConfig(
input_filter=[f"{tmp_path}/**/*"], compressed_jsonl="test.jsonl", embeddings_file="test.pt"
LocalOrgConfig.objects.create(
input_filter=[f"{tmp_path}/**/*"],
input_files=None,
user=default_user,
)
org_files = collect_files(user=default_user)["org"]
# Act
# should not raise IsADirectoryError and return orgfile
assert get_org_files(org_content_config) == {f"{orgfile}": "* Heading\n- List item\n"}
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_with_empty_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
data = get_org_files(org_config_with_only_new_file)
# Act
# Generate notes embeddings during asymmetric setup
with pytest.raises(ValueError, match=r"^No valid entries found*"):
text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
assert "Created 0 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
verify_embeddings(0, default_user)
# ----------------------------------------------------------------------------------------------------
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
@pytest.mark.django_db
def test_text_search_setup(content_config, default_user: KhojUser, caplog):
# Arrange
data = get_org_files(content_config.org)
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# Assert
assert len(notes_model.entries) == 10
assert len(notes_model.corpus_embeddings) == 10
assert "Deleting all embeddings for file type org" in caplog.records[1].message
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
# ----------------------------------------------------------------------------------------------------
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
@pytest.mark.django_db
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
caplog.set_level(logging.INFO, logger="khoj")
data = get_org_files(content_config.org)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act
# Generate initial notes embeddings during asymmetric setup
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
assert "Creating index from scratch." in initial_logs
assert "Creating index from scratch." not in final_logs
assert "Deleting all embeddings for file type org" in initial_logs
assert "Deleting all embeddings for file type org" not in final_logs
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.anyio
async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
# @pytest.mark.asyncio
async def test_text_search(search_config: SearchConfig):
# Arrange
data = get_org_files(content_config.org)
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
default_user = await KhojUser.objects.acreate(
username="test_user", password="test_password", email="test@example.com"
)
# Arrange
org_config = await LocalOrgConfig.objects.acreate(
input_files=None,
input_filter=["tests/data/org/*.org"],
index_heading_entries=False,
user=default_user,
)
data = get_org_files(org_config)
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
text_search.setup,
OrgToJsonl,
data,
True,
True,
default_user,
)
query = "How to git install application?"
# Act
hits, entries = await text_search.query(
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
)
results = text_search.collate_results(hits, entries, count=1)
hits = await text_search.query(default_user, query)
# Assert
results = text_search.collate_results(hits)
results = sorted(results, key=lambda x: float(x.score))[:1]
# search results should contain "git clone" entry
search_result = results[0].entry
assert "git clone" in search_result
# ----------------------------------------------------------------------------------------------------
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
@pytest.mark.django_db
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256
@@ -137,47 +172,46 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
# verify newly added org-mode entry is split by max tokens
assert len(initial_notes_model.entries) == 2
assert len(initial_notes_model.corpus_embeddings) == 2
record = caplog.records[1]
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
# ----------------------------------------------------------------------------------------------------
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
@pytest.mark.django_db
def test_entry_chunking_by_max_tokens_not_full_corpus(
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
data = {
"readme.org": """
* Khoj
/Allow natural language search on user content like notes, images using transformer based models/
/Allow natural language search on user content like notes, images using transformer based models/
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
** Dependencies
- Python3
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
- Python3
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
** Install
#+begin_src shell
git clone https://github.com/khoj-ai/khoj && cd khoj
conda env create -f environment.yml
conda activate khoj
#+end_src"""
#+begin_src shell
git clone https://github.com/khoj-ai/khoj && cd khoj
conda env create -f environment.yml
conda activate khoj
#+end_src"""
}
text_search.setup(
OrgToJsonl,
data,
org_config_with_only_new_file,
search_models.text_search.bi_encoder,
regenerate=False,
user=default_user,
)
max_tokens = 256
@@ -191,64 +225,57 @@ def test_entry_chunking_by_max_tokens_not_full_corpus(
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(
OrgToJsonl,
data,
org_config_with_only_new_file,
search_models.text_search.bi_encoder,
regenerate=False,
full_corpus=False,
)
with caplog.at_level(logging.INFO):
text_search.setup(
OrgToJsonl,
data,
regenerate=False,
full_corpus=False,
user=default_user,
)
record = caplog.records[1]
# Assert
# verify newly added org-mode entry is split by max tokens
assert len(initial_notes_model.entries) == 5
assert len(initial_notes_model.corpus_embeddings) == 5
assert "Created 2 new embeddings. Deleted 0 embeddings for user " in record.message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_regenerate_index_with_new_entry(
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
):
# Arrange
data = get_org_files(content_config.org)
initial_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
# append org-mode entry to first org input file in config
content_config.org.input_files = [f"{new_org_file}"]
org_config.input_files = [f"{new_org_file}"]
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
data = get_org_files(content_config.org)
data = get_org_files(org_config)
# Act
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# Assert
assert len(regenerated_notes_model.entries) == 11
assert len(regenerated_notes_model.corpus_embeddings) == 11
# verify new entry appended to index, without disrupting order or content of existing entries
error_details = compare_index(initial_notes_model, regenerated_notes_model)
if error_details:
pytest.fail(error_details, False)
# Cleanup
# reset input_files in config to empty list
content_config.org.input_files = []
assert "Created 11 new embeddings. Deleted 10 embeddings for user " in caplog.records[-1].message
verify_embeddings(11, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_duplicate_entries_in_stable_order(
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@@ -262,30 +289,26 @@ def test_update_index_with_duplicate_entries_in_stable_order(
# Act
# load embeddings, entries, notes model after adding new org-mode file
initial_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model after adding new org-mode file
updated_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert len(initial_index.entries) == len(updated_index.entries) == 1
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
assert "Created 1 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 0 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
# verify the same entry is added even when there are multiple duplicate entries
error_details = compare_index(initial_index, updated_index)
if error_details:
pytest.fail(error_details)
verify_embeddings(1, default_user)
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
@pytest.mark.django_db
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
# Arrange
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@@ -296,9 +319,8 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
data = get_org_files(org_config_with_only_new_file)
# load embeddings, entries, notes model after adding new org file with 2 entries
initial_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
@@ -307,87 +329,65 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
data = get_org_files(org_config_with_only_new_file)
# Act
updated_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert len(initial_index.entries) == len(updated_index.entries) + 1
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
assert "Created 2 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 0 new embeddings. Deleted 1 embeddings for user " in caplog.records[4].message
# verify the same entry is added even when there are multiple duplicate entries
error_details = compare_index(updated_index, initial_index)
if error_details:
pytest.fail(error_details)
verify_embeddings(1, default_user)
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
@pytest.mark.django_db
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
# Arrange
data = get_org_files(content_config.org)
initial_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=True, user=default_user)
# append org-mode entry to first org input file in config
with open(new_org_file, "w") as f:
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
f.write(new_entry)
data = get_org_files(content_config.org)
data = get_org_files(org_config)
# Act
# update embeddings, entries with the newly added note
content_config.org.input_files = [f"{new_org_file}"]
final_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToJsonl, data, regenerate=False, user=default_user)
# Assert
assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
assert "Created 10 new embeddings. Deleted 3 embeddings for user " in caplog.records[2].message
assert "Created 1 new embeddings. Deleted 0 embeddings for user " in caplog.records[4].message
# verify new entry appended to index, without disrupting order or content of existing entries
error_details = compare_index(initial_notes_model, final_notes_model)
if error_details:
pytest.fail(error_details, False)
# Cleanup
# reset input_files in config to empty list
content_config.org.input_files = []
verify_embeddings(11, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_text_search_setup_github(content_config: ContentConfig, search_models: SearchModels):
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
# Arrange
github_config = GithubConfig.objects.filter(user=default_user).first()
# Act
# Regenerate github embeddings to test asymmetric setup without caching
github_model = text_search.setup(
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
text_search.setup(
GithubToJsonl,
{},
regenerate=True,
user=default_user,
config=github_config,
)
# Assert
assert len(github_model.entries) > 1
embeddings = Embeddings.objects.filter(user=default_user, file_type="github").count()
assert embeddings > 1
def compare_index(initial_notes_model, final_notes_model):
mismatched_entries, mismatched_embeddings = [], []
for index in range(len(initial_notes_model.entries)):
if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
mismatched_entries.append(index)
# verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
for index in range(len(initial_notes_model.corpus_embeddings)):
if not initial_notes_model.corpus_embeddings[index].allclose(final_notes_model.corpus_embeddings[index]):
mismatched_embeddings.append(index)
error_details = ""
if mismatched_entries:
mismatched_entries_str = ",".join(map(str, mismatched_entries))
error_details += f"Entries at {mismatched_entries_str} not equal\n"
if mismatched_embeddings:
mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
return error_details
def verify_embeddings(expected_count, user):
embeddings = Embeddings.objects.filter(user=user, file_type="org").count()
assert embeddings == expected_count

View File

@@ -3,68 +3,40 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.utils.rawconfig import Entry
def test_no_word_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
q_with_no_filter = "head tail"
# Act
can_filter = word_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == False
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_word_exclude_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
q_with_exclude_filter = 'head -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(q_with_exclude_filter)
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
def test_word_include_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
query_with_include_filter = 'head +"include_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_filter)
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {2, 3}
def test_word_include_and_exclude_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {2}
def test_get_word_filter_terms():