mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 13:18:18 +00:00
Resolve merge conflicts
This commit is contained in:
@@ -1,61 +1,65 @@
|
||||
# External Packages
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi import FastAPI
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.main import app
|
||||
from khoj.configure import configure_processor, configure_routes, configure_search_types
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from khoj.configure import configure_routes, configure_search_types, configure_middleware
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.constants import web_directory
|
||||
from khoj.utils.helpers import resolve_absolute_path
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
ConversationProcessorConfig,
|
||||
OfflineChatProcessorConfig,
|
||||
OpenAIProcessorConfig,
|
||||
ProcessorConfig,
|
||||
TextContentConfig,
|
||||
GithubContentConfig,
|
||||
GithubRepoConfig,
|
||||
ImageContentConfig,
|
||||
SearchConfig,
|
||||
TextSearchConfig,
|
||||
ImageSearchConfig,
|
||||
)
|
||||
from khoj.utils import state, fs_syncer
|
||||
from khoj.routers.indexer import configure_content
|
||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.processor.org_mode.org_to_entries import OrgToEntries
|
||||
from database.models import (
|
||||
KhojApiUser,
|
||||
LocalOrgConfig,
|
||||
LocalMarkdownConfig,
|
||||
LocalPlaintextConfig,
|
||||
GithubConfig,
|
||||
KhojUser,
|
||||
GithubRepoConfig,
|
||||
)
|
||||
|
||||
from tests.helpers import (
|
||||
UserFactory,
|
||||
ChatModelOptionsFactory,
|
||||
OpenAIProcessorConversationConfigFactory,
|
||||
OfflineChatProcessorConversationConfigFactory,
|
||||
UserConversationProcessorConfigFactory,
|
||||
SubscriptionFactory,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_db_access_for_all_tests(db):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_config() -> SearchConfig:
|
||||
state.embeddings_model = EmbeddingsModel()
|
||||
state.cross_encoder_model = CrossEncoderModel()
|
||||
|
||||
model_dir = resolve_absolute_path("~/.khoj/search")
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
search_config = SearchConfig()
|
||||
|
||||
search_config.symmetric = TextSearchConfig(
|
||||
encoder="sentence-transformers/all-MiniLM-L6-v2",
|
||||
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
model_directory=model_dir / "symmetric/",
|
||||
encoder_type=None,
|
||||
)
|
||||
|
||||
search_config.asymmetric = TextSearchConfig(
|
||||
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
||||
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
model_directory=model_dir / "asymmetric/",
|
||||
encoder_type=None,
|
||||
)
|
||||
|
||||
search_config.image = ImageSearchConfig(
|
||||
encoder="sentence-transformers/clip-ViT-B-32",
|
||||
model_directory=model_dir / "image/",
|
||||
@@ -65,17 +69,102 @@ def search_config() -> SearchConfig:
|
||||
return search_config
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_user():
|
||||
user = UserFactory()
|
||||
SubscriptionFactory(user=user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_user2():
|
||||
if KhojUser.objects.filter(username="default").exists():
|
||||
return KhojUser.objects.get(username="default")
|
||||
|
||||
user = KhojUser.objects.create(
|
||||
username="default",
|
||||
email="default@example.com",
|
||||
password="default",
|
||||
)
|
||||
SubscriptionFactory(user=user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_user3():
|
||||
"""
|
||||
This user should not have any data associated with it
|
||||
"""
|
||||
if KhojUser.objects.filter(username="default3").exists():
|
||||
return KhojUser.objects.get(username="default3")
|
||||
|
||||
user = KhojUser.objects.create(
|
||||
username="default3",
|
||||
email="default3@example.com",
|
||||
password="default3",
|
||||
)
|
||||
SubscriptionFactory(user=user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def api_user(default_user):
|
||||
if KhojApiUser.objects.filter(user=default_user).exists():
|
||||
return KhojApiUser.objects.get(user=default_user)
|
||||
|
||||
return KhojApiUser.objects.create(
|
||||
user=default_user,
|
||||
name="api-key",
|
||||
token="kk-secret",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def api_user2(default_user2):
|
||||
if KhojApiUser.objects.filter(user=default_user2).exists():
|
||||
return KhojApiUser.objects.get(user=default_user2)
|
||||
|
||||
return KhojApiUser.objects.create(
|
||||
user=default_user2,
|
||||
name="api-key",
|
||||
token="kk-diff-secret",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def api_user3(default_user3):
|
||||
if KhojApiUser.objects.filter(user=default_user3).exists():
|
||||
return KhojApiUser.objects.get(user=default_user3)
|
||||
|
||||
return KhojApiUser.objects.create(
|
||||
user=default_user3,
|
||||
name="api-key",
|
||||
token="kk-diff-secret-3",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_models(search_config: SearchConfig):
|
||||
search_models = SearchModels()
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
return search_models
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
|
||||
@pytest.fixture
|
||||
def anyio_backend():
|
||||
return "asyncio"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture(scope="function")
|
||||
def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
|
||||
content_dir = tmp_path_factory.mktemp("content")
|
||||
|
||||
# Generate Image Embeddings from Test Images
|
||||
@@ -90,217 +179,188 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
|
||||
|
||||
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
|
||||
|
||||
# Generate Notes Embeddings from Test Notes
|
||||
content_config.org = TextContentConfig(
|
||||
LocalOrgConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/org/*.org"],
|
||||
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
|
||||
index_heading_entries=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
get_sample_data("org"),
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
content_config.plugins = {
|
||||
"plugin1": TextContentConfig(
|
||||
input_files=[content_dir.joinpath("notes.jsonl.gz")],
|
||||
input_filter=None,
|
||||
compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
|
||||
)
|
||||
}
|
||||
text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=False, user=default_user)
|
||||
|
||||
if os.getenv("GITHUB_PAT_TOKEN"):
|
||||
content_config.github = GithubContentConfig(
|
||||
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
||||
repos=[
|
||||
GithubRepoConfig(
|
||||
owner="khoj-ai",
|
||||
name="lantern",
|
||||
branch="master",
|
||||
)
|
||||
],
|
||||
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
||||
GithubConfig.objects.create(
|
||||
pat_token=os.getenv("GITHUB_PAT_TOKEN"),
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
content_config.plaintext = TextContentConfig(
|
||||
GithubRepoConfig.objects.create(
|
||||
owner="khoj-ai",
|
||||
name="lantern",
|
||||
branch="master",
|
||||
github_config=GithubConfig.objects.get(user=default_user),
|
||||
)
|
||||
|
||||
LocalPlaintextConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
|
||||
compressed_jsonl=content_dir.joinpath("plaintext.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("plaintext_embeddings.pt"),
|
||||
)
|
||||
|
||||
content_config.github = GithubContentConfig(
|
||||
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
|
||||
repos=[
|
||||
GithubRepoConfig(
|
||||
owner="khoj-ai",
|
||||
name="lantern",
|
||||
branch="master",
|
||||
)
|
||||
],
|
||||
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
|
||||
)
|
||||
|
||||
filters = [DateFilter(), WordFilter(), FileFilter()]
|
||||
text_search.setup(
|
||||
JsonlToJsonl,
|
||||
None,
|
||||
content_config.plugins["plugin1"],
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
return content_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def md_content_config(tmp_path_factory):
|
||||
content_dir = tmp_path_factory.mktemp("content")
|
||||
|
||||
# Generate Embeddings for Markdown Content
|
||||
content_config = ContentConfig()
|
||||
content_config.markdown = TextContentConfig(
|
||||
def md_content_config():
|
||||
markdown_config = LocalMarkdownConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/markdown/*.markdown"],
|
||||
compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
|
||||
embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
|
||||
)
|
||||
|
||||
return content_config
|
||||
return markdown_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def processor_config(tmp_path_factory):
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
processor_dir = tmp_path_factory.mktemp("processor")
|
||||
|
||||
# The conversation processor is the only configured processor
|
||||
# It needs an OpenAI API key to work.
|
||||
if not openai_api_key:
|
||||
return
|
||||
|
||||
# Setup conversation processor, if OpenAI API key is set
|
||||
processor_config = ProcessorConfig()
|
||||
processor_config.conversation = ConversationProcessorConfig(
|
||||
openai=OpenAIProcessorConfig(api_key=openai_api_key),
|
||||
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
|
||||
)
|
||||
|
||||
return processor_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def processor_config_offline_chat(tmp_path_factory):
|
||||
processor_dir = tmp_path_factory.mktemp("processor")
|
||||
|
||||
# Setup conversation processor
|
||||
processor_config = ProcessorConfig()
|
||||
offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True, chat_model="mistral-7b-instruct-v0.1.Q4_0.gguf")
|
||||
processor_config.conversation = ConversationProcessorConfig(
|
||||
offline_chat=offline_chat,
|
||||
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
|
||||
)
|
||||
|
||||
return processor_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
||||
@pytest.fixture(scope="function")
|
||||
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
||||
# Initialize app state
|
||||
state.config.content_type = md_content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.SearchType = configure_search_types()
|
||||
|
||||
LocalMarkdownConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/markdown/*.markdown"],
|
||||
user=default_user2,
|
||||
)
|
||||
|
||||
# Index Markdown Content for Search
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
all_files = fs_syncer.collect_files(state.config.content_type)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models
|
||||
all_files = fs_syncer.collect_files(user=default_user2)
|
||||
state.content_index, _ = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
|
||||
)
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(processor_config)
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
|
||||
OpenAIProcessorConversationConfigFactory()
|
||||
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
|
||||
|
||||
state.anonymous_mode = True
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
||||
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
|
||||
# Initialize app state
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types()
|
||||
|
||||
# Initialize Processor from Config
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
|
||||
OpenAIProcessorConversationConfigFactory()
|
||||
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
|
||||
|
||||
state.anonymous_mode = True
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def fastapi_app():
|
||||
app = FastAPI()
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(
|
||||
content_config: ContentConfig,
|
||||
search_config: SearchConfig,
|
||||
api_user: KhojApiUser,
|
||||
):
|
||||
state.config.content_type = content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.SearchType = configure_search_types()
|
||||
state.embeddings_model = EmbeddingsModel()
|
||||
state.cross_encoder_model = CrossEncoderModel()
|
||||
|
||||
# These lines help us Mock the Search models for these search types
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
state.content_index.org = text_search.setup(
|
||||
OrgToJsonl,
|
||||
text_search.setup(
|
||||
OrgToEntries,
|
||||
get_sample_data("org"),
|
||||
content_config.org,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
user=api_user.user,
|
||||
)
|
||||
state.content_index.image = image_search.setup(
|
||||
content_config.image, state.search_models.image_search, regenerate=False
|
||||
)
|
||||
state.content_index.plaintext = text_search.setup(
|
||||
PlaintextToJsonl,
|
||||
text_search.setup(
|
||||
PlaintextToEntries,
|
||||
get_sample_data("plaintext"),
|
||||
content_config.plaintext,
|
||||
state.search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
user=api_user.user,
|
||||
)
|
||||
|
||||
state.processor_config = configure_processor(processor_config)
|
||||
state.anonymous_mode = False
|
||||
|
||||
app = FastAPI()
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client_offline_chat(
|
||||
search_config: SearchConfig,
|
||||
processor_config_offline_chat: ProcessorConfig,
|
||||
content_config: ContentConfig,
|
||||
md_content_config,
|
||||
):
|
||||
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
|
||||
# Initialize app state
|
||||
state.config.content_type = md_content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.SearchType = configure_search_types()
|
||||
|
||||
# Index Markdown Content for Search
|
||||
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
LocalMarkdownConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/markdown/*.markdown"],
|
||||
user=default_user2,
|
||||
)
|
||||
|
||||
all_files = fs_syncer.collect_files(state.config.content_type)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models
|
||||
all_files = fs_syncer.collect_files(user=default_user2)
|
||||
configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
|
||||
)
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(processor_config_offline_chat)
|
||||
OfflineChatProcessorConversationConfigFactory(enabled=True)
|
||||
UserConversationProcessorConfigFactory(user=default_user2)
|
||||
|
||||
state.anonymous_mode = True
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def new_org_file(content_config: ContentConfig):
|
||||
def new_org_file(default_user: KhojUser, content_config: ContentConfig):
|
||||
# Setup
|
||||
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
input_filters = org_config.input_filter
|
||||
new_org_file = Path(input_filters[0]).parent / "new_file.org"
|
||||
new_org_file.touch()
|
||||
|
||||
yield new_org_file
|
||||
@@ -311,11 +371,9 @@ def new_org_file(content_config: ContentConfig):
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
|
||||
new_org_config = deepcopy(content_config.org)
|
||||
new_org_config.input_files = [f"{new_org_file}"]
|
||||
new_org_config.input_filter = None
|
||||
return new_org_config
|
||||
def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
|
||||
LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
|
||||
return LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
||||
13
tests/data/config.yml
vendored
13
tests/data/config.yml
vendored
@@ -9,20 +9,9 @@ content-type:
|
||||
input-filter:
|
||||
- '*.org'
|
||||
- ~/notes/*.org
|
||||
plugins:
|
||||
content_plugin_1:
|
||||
compressed-jsonl: content_plugin_1.jsonl.gz
|
||||
embeddings-file: content_plugin_1_embeddings.pt
|
||||
input-files:
|
||||
- content_plugin_1_new.jsonl.gz
|
||||
content_plugin_2:
|
||||
compressed-jsonl: content_plugin_2.jsonl.gz
|
||||
embeddings-file: content_plugin_2_embeddings.pt
|
||||
input-filter:
|
||||
- '*2_new.jsonl.gz'
|
||||
enable-offline-chat: false
|
||||
search-type:
|
||||
asymmetric:
|
||||
cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
|
||||
encoder: sentence-transformers/msmarco-MiniLM-L-6-v3
|
||||
version: 0.10.1
|
||||
version: 0.15.0
|
||||
|
||||
BIN
tests/data/pdf/ocr_samples.pdf
vendored
Normal file
BIN
tests/data/pdf/ocr_samples.pdf
vendored
Normal file
Binary file not shown.
92
tests/helpers.py
Normal file
92
tests/helpers.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import factory
|
||||
import os
|
||||
|
||||
from database.models import (
|
||||
KhojUser,
|
||||
KhojApiUser,
|
||||
ChatModelOptions,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
SearchModelConfig,
|
||||
UserConversationConfig,
|
||||
Conversation,
|
||||
Subscription,
|
||||
)
|
||||
|
||||
|
||||
class UserFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = KhojUser
|
||||
|
||||
username = factory.Faker("name")
|
||||
email = factory.Faker("email")
|
||||
password = factory.Faker("password")
|
||||
uuid = factory.Faker("uuid4")
|
||||
|
||||
|
||||
class ApiUserFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = KhojApiUser
|
||||
|
||||
user = None
|
||||
name = factory.Faker("name")
|
||||
token = factory.Faker("password")
|
||||
|
||||
|
||||
class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = ChatModelOptions
|
||||
|
||||
max_prompt_size = 2000
|
||||
tokenizer = None
|
||||
chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
|
||||
model_type = "offline"
|
||||
|
||||
|
||||
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = UserConversationConfig
|
||||
|
||||
user = factory.SubFactory(UserFactory)
|
||||
setting = factory.SubFactory(ChatModelOptionsFactory)
|
||||
|
||||
|
||||
class OfflineChatProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = OfflineChatProcessorConversationConfig
|
||||
|
||||
enabled = True
|
||||
|
||||
|
||||
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = OpenAIProcessorConversationConfig
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
class ConversationFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = Conversation
|
||||
|
||||
user = factory.SubFactory(UserFactory)
|
||||
|
||||
|
||||
class SearchModelFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = SearchModelConfig
|
||||
|
||||
name = "default"
|
||||
model_type = "text"
|
||||
bi_encoder = "thenlper/gte-small"
|
||||
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
|
||||
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = Subscription
|
||||
|
||||
user = factory.SubFactory(UserFactory)
|
||||
type = "standard"
|
||||
is_recurring = False
|
||||
renewal_date = "2100-04-01"
|
||||
@@ -25,7 +25,7 @@ def test_cli_invalid_config_file_path():
|
||||
non_existent_config_file = f"non-existent-khoj-{random()}.yml"
|
||||
|
||||
# Act
|
||||
actual_args = cli([f"-c={non_existent_config_file}"])
|
||||
actual_args = cli([f"--config-file={non_existent_config_file}"])
|
||||
|
||||
# Assert
|
||||
assert actual_args.config_file == resolve_absolute_path(non_existent_config_file)
|
||||
@@ -35,7 +35,7 @@ def test_cli_invalid_config_file_path():
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_cli_config_from_file():
|
||||
# Act
|
||||
actual_args = cli(["-c=tests/data/config.yml", "--regenerate", "-vvv"])
|
||||
actual_args = cli(["--config-file=tests/data/config.yml", "--regenerate", "-vvv"])
|
||||
|
||||
# Assert
|
||||
assert actual_args.config_file == resolve_absolute_path(Path("tests/data/config.yml"))
|
||||
@@ -48,14 +48,3 @@ def test_cli_config_from_file():
|
||||
Path("~/first_from_config.org"),
|
||||
Path("~/second_from_config.org"),
|
||||
]
|
||||
assert len(actual_args.config.content_type.plugins.keys()) == 2
|
||||
assert actual_args.config.content_type.plugins["content_plugin_1"].input_files == [
|
||||
Path("content_plugin_1_new.jsonl.gz")
|
||||
]
|
||||
assert actual_args.config.content_type.plugins["content_plugin_2"].input_filter == ["*2_new.jsonl.gz"]
|
||||
assert actual_args.config.content_type.plugins["content_plugin_1"].compressed_jsonl == Path(
|
||||
"content_plugin_1.jsonl.gz"
|
||||
)
|
||||
assert actual_args.config.content_type.plugins["content_plugin_2"].embeddings_file == Path(
|
||||
"content_plugin_2_embeddings.pt"
|
||||
)
|
||||
|
||||
@@ -2,69 +2,135 @@
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from urllib.parse import quote
|
||||
|
||||
import pytest
|
||||
|
||||
# External Packages
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.main import app
|
||||
from khoj.configure import configure_routes, configure_search_types
|
||||
from khoj.utils import state
|
||||
from khoj.utils.state import search_models, content_index, config
|
||||
from khoj.search_type import text_search, image_search
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.processor.org_mode.org_to_entries import OrgToEntries
|
||||
from database.models import KhojUser, KhojApiUser
|
||||
from database.adapters import EntryAdapters
|
||||
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_search_with_invalid_content_type(client):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_search_with_no_auth_key(client):
|
||||
# Arrange
|
||||
user_query = quote("How to call Khoj from Emacs?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&t=invalid_content_type")
|
||||
response = client.get(f"/api/search?q={user_query}")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_search_with_invalid_auth_key(client):
|
||||
# Arrange
|
||||
headers = {"Authorization": "Bearer invalid-token"}
|
||||
user_query = quote("How to call Khoj from Emacs?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_search_with_invalid_content_type(client):
|
||||
# Arrange
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
user_query = quote("How to call Khoj from Emacs?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&t=invalid_content_type", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_search_with_valid_content_type(client):
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plugin1"]:
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plaintext"]:
|
||||
# Act
|
||||
response = client.get(f"/api/search?q=random&t={content_type}")
|
||||
response = client.get(f"/api/search?q=random&t={content_type}", headers=headers)
|
||||
# Assert
|
||||
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update_with_no_auth_key(client):
|
||||
# Arrange
|
||||
files = get_sample_files_data()
|
||||
|
||||
# Act
|
||||
response = client.post("/api/v1/index/update", files=files)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update_with_invalid_auth_key(client):
|
||||
# Arrange
|
||||
files = get_sample_files_data()
|
||||
headers = {"Authorization": "Bearer kk-invalid-token"}
|
||||
|
||||
# Act
|
||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_update_with_invalid_content_type(client):
|
||||
# Arrange
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/update?t=invalid_content_type")
|
||||
response = client.get(f"/api/update?t=invalid_content_type", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_regenerate_with_invalid_content_type(client):
|
||||
# Arrange
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/update?force=true&t=invalid_content_type")
|
||||
response = client.get(f"/api/update?force=true&t=invalid_content_type", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update(client):
|
||||
# Arrange
|
||||
files = get_sample_files_data()
|
||||
headers = {"x-api-key": "secret"}
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
|
||||
# Act
|
||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||
@@ -74,88 +140,75 @@ def test_index_update(client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_regenerate_with_valid_content_type(client):
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
|
||||
for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
|
||||
# Arrange
|
||||
files = get_sample_files_data()
|
||||
headers = {"x-api-key": "secret"}
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
|
||||
# Act
|
||||
response = client.post(f"/api/v1/index/update?t={content_type}", files=files, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_regenerate_with_github_fails_without_pat(client):
|
||||
# Act
|
||||
response = client.get(f"/api/update?force=true&t=github")
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
response = client.get(f"/api/update?force=true&t=github", headers=headers)
|
||||
|
||||
# Arrange
|
||||
files = get_sample_files_data()
|
||||
headers = {"x-api-key": "secret"}
|
||||
|
||||
# Act
|
||||
response = client.post(f"/api/v1/index/update?t=github", files=files, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.skip(reason="Flaky test on parallel test runs")
|
||||
def test_get_configured_types_via_api(client):
|
||||
@pytest.mark.django_db
|
||||
def test_get_configured_types_via_api(client, sample_org_data):
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False)
|
||||
|
||||
enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
|
||||
assert list(enabled_types) == ["org"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_configured_types_with_only_plugin_content_config(content_config):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
|
||||
# Arrange
|
||||
config.content_type = ContentConfig()
|
||||
config.content_type.plugins = content_config.plugins
|
||||
state.SearchType = configure_search_types(config)
|
||||
|
||||
configure_routes(app)
|
||||
client = TestClient(app)
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
response = client.get(f"/api/config/types", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["all", "plugin1"]
|
||||
assert response.json() == ["all", "org", "image", "plaintext"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_configured_types_with_no_plugin_content_config(content_config):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
|
||||
# Arrange
|
||||
config.content_type = content_config
|
||||
config.content_type.plugins = None
|
||||
state.SearchType = configure_search_types(config)
|
||||
state.anonymous_mode = True
|
||||
if state.config and state.config.content_type:
|
||||
state.config.content_type = None
|
||||
state.search_models = configure_search_types()
|
||||
|
||||
configure_routes(app)
|
||||
client = TestClient(app)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert "plugin1" not in response.json()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_configured_types_with_no_content_config():
|
||||
# Arrange
|
||||
config.content_type = ContentConfig()
|
||||
state.SearchType = configure_search_types(config)
|
||||
|
||||
configure_routes(app)
|
||||
client = TestClient(app)
|
||||
configure_routes(fastapi_app)
|
||||
client = TestClient(fastapi_app)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
@@ -166,8 +219,10 @@ def test_get_configured_types_with_no_content_config():
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
|
||||
# Arrange
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
content_index.image = image_search.setup(
|
||||
content_config.image, search_models.image_search.image_encoder, regenerate=False
|
||||
@@ -180,7 +235,7 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
|
||||
|
||||
for query, expected_image_name in query_expected_image_pairs:
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={query}&n=1&t=image")
|
||||
response = client.get(f"/api/search?q={query}&n=1&t=image", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
@@ -192,43 +247,57 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
|
||||
# Arrange
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true")
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# assert actual_data contains "Khoj via Emacs" entry
|
||||
|
||||
assert len(response.json()) == 1, "Expected only 1 result"
|
||||
search_result = response.json()[0]["entry"]
|
||||
assert "git clone https://github.com/khoj-ai/khoj" in search_result
|
||||
assert "git clone https://github.com/khoj-ai/khoj" in search_result, "Expected 'git clone' in search result"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_no_results(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
|
||||
# Arrange
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
|
||||
user_query = quote("How to find my goat?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [], "Expected no results"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_only_filters(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user: KhojUser
|
||||
):
|
||||
# Arrange
|
||||
filters = [WordFilter(), FileFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl,
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
text_search.setup(
|
||||
OrgToEntries,
|
||||
sample_org_data,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
user=default_user,
|
||||
)
|
||||
user_query = quote('+"Emacs" file:"*.org"')
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
@@ -238,19 +307,15 @@ def test_notes_search_with_only_filters(
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_include_filter(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_include_filter(client, sample_org_data, default_user: KhojUser):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
|
||||
)
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
|
||||
user_query = quote('How to git install application? +"Emacs"')
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
@@ -260,24 +325,20 @@ def test_notes_search_with_include_filter(
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_notes_search_with_exclude_filter(
|
||||
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
|
||||
):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_notes_search_with_exclude_filter(client, sample_org_data, default_user: KhojUser):
|
||||
# Arrange
|
||||
filters = [WordFilter()]
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl,
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
text_search.setup(
|
||||
OrgToEntries,
|
||||
sample_org_data,
|
||||
content_config.org,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
filters=filters,
|
||||
user=default_user,
|
||||
)
|
||||
user_query = quote('How to git install application? -"clone"')
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
@@ -286,22 +347,56 @@ def test_notes_search_with_exclude_filter(
|
||||
assert "clone" not in search_result
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
|
||||
# Arrange
|
||||
headers = {"Authorization": "Bearer kk-token"} # Token for default_user2
|
||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 403
|
||||
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
|
||||
assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiUser):
|
||||
# Arrange
|
||||
token = api_user3.token
|
||||
headers = {"Authorization": "Bearer " + token}
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# assert actual response has no data as the default_user3, though other users have data
|
||||
assert len(response.json()) == 0
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
def get_sample_files_data():
|
||||
return {
|
||||
"files": ("path/to/filename.org", "* practicing piano", "text/org"),
|
||||
"files": ("path/to/filename1.org", "** top 3 reasons why I moved to SF", "text/org"),
|
||||
"files": ("path/to/filename2.org", "* how to build a search engine", "text/org"),
|
||||
"files": ("path/to/filename.pdf", "Moore's law does not apply to consumer hardware", "application/pdf"),
|
||||
"files": ("path/to/filename1.pdf", "The sun is a ball of helium", "application/pdf"),
|
||||
"files": ("path/to/filename2.pdf", "Effect of sunshine on baseline human happiness", "application/pdf"),
|
||||
"files": ("path/to/filename.txt", "data,column,value", "text/plain"),
|
||||
"files": ("path/to/filename1.txt", "<html>my first web page</html>", "text/plain"),
|
||||
"files": ("path/to/filename2.txt", "2021-02-02 Journal Entry", "text/plain"),
|
||||
"files": ("path/to/filename.md", "# Notes from client call", "text/markdown"),
|
||||
"files": (
|
||||
"path/to/filename1.md",
|
||||
"## Studying anthropological records from the Fatimid caliphate",
|
||||
"text/markdown",
|
||||
return [
|
||||
("files", ("path/to/filename.org", "* practicing piano", "text/org")),
|
||||
("files", ("path/to/filename1.org", "** top 3 reasons why I moved to SF", "text/org")),
|
||||
("files", ("path/to/filename2.org", "* how to build a search engine", "text/org")),
|
||||
("files", ("path/to/filename.pdf", "Moore's law does not apply to consumer hardware", "application/pdf")),
|
||||
("files", ("path/to/filename1.pdf", "The sun is a ball of helium", "application/pdf")),
|
||||
("files", ("path/to/filename2.pdf", "Effect of sunshine on baseline human happiness", "application/pdf")),
|
||||
("files", ("path/to/filename.txt", "data,column,value", "text/plain")),
|
||||
("files", ("path/to/filename1.txt", "<html>my first web page</html>", "text/plain")),
|
||||
("files", ("path/to/filename2.txt", "2021-02-02 Journal Entry", "text/plain")),
|
||||
("files", ("path/to/filename.md", "# Notes from client call", "text/markdown")),
|
||||
(
|
||||
"files",
|
||||
("path/to/filename1.md", "## Studying anthropological records from the Fatimid caliphate", "text/markdown"),
|
||||
),
|
||||
"files": ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown"),
|
||||
}
|
||||
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
|
||||
]
|
||||
|
||||
@@ -1,53 +1,12 @@
|
||||
# Standard Packages
|
||||
import re
|
||||
from datetime import datetime
|
||||
from math import inf
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
|
||||
def test_date_filter():
|
||||
entries = [
|
||||
Entry(compiled="Entry with no date", raw="Entry with no date"),
|
||||
Entry(compiled="April Fools entry: 1984-04-01", raw="April Fools entry: 1984-04-01"),
|
||||
Entry(compiled="Entry with date:1984-04-02", raw="Entry with date:1984-04-02"),
|
||||
]
|
||||
|
||||
q_with_no_date_filter = "head tail"
|
||||
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 1, 2}
|
||||
|
||||
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
|
||||
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == set()
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {2}
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {1}
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {2}
|
||||
|
||||
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
|
||||
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {1, 2}
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
|
||||
@@ -56,8 +15,8 @@ def test_extract_date_range():
|
||||
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
|
||||
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
|
||||
]
|
||||
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
|
||||
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [None, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
|
||||
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), None]
|
||||
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
|
||||
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
|
||||
datetime(1984, 1, 2, 0, 0, 0).timestamp(),
|
||||
|
||||
@@ -6,97 +6,73 @@ from khoj.utils.rawconfig import Entry
|
||||
def test_no_file_filter():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = "head tail"
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == False
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_file_filter_with_non_existent_file():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"nonexistent.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {}
|
||||
|
||||
|
||||
def test_single_file_filter():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"file 1.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 2}
|
||||
|
||||
|
||||
def test_file_filter_with_partial_match():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"1.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 2}
|
||||
|
||||
|
||||
def test_file_filter_with_regex_match():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head file:"*.org" tail'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_multiple_file_filter():
|
||||
# Arrange
|
||||
file_filter = FileFilter()
|
||||
entries = arrange_content()
|
||||
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
|
||||
|
||||
# Act
|
||||
can_filter = file_filter.can_filter(q_with_no_filter)
|
||||
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 1, 2, 3}
|
||||
|
||||
|
||||
def test_get_file_filter_terms():
|
||||
@@ -108,7 +84,7 @@ def test_get_file_filter_terms():
|
||||
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
|
||||
|
||||
# Assert
|
||||
assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
|
||||
assert filter_terms == ["file 1\\.org", "/path/to/dir/.*\\.org"]
|
||||
|
||||
|
||||
def arrange_content():
|
||||
|
||||
@@ -9,8 +9,7 @@ from faker import Faker
|
||||
# Internal Packages
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.utils import state
|
||||
|
||||
from tests.helpers import ConversationFactory
|
||||
|
||||
SKIP_TESTS = True
|
||||
pytestmark = pytest.mark.skipif(
|
||||
@@ -23,7 +22,7 @@ fake = Faker()
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
def populate_chat_history(message_list, user):
|
||||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, llm_message, context in message_list:
|
||||
@@ -33,14 +32,15 @@ def populate_chat_history(message_list):
|
||||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
)
|
||||
|
||||
# Update Conversation Metadata Logs in Application State
|
||||
state.processor_config.conversation.meta_log = conversation_log
|
||||
# Update Conversation Metadata Logs in Database
|
||||
ConversationFactory(user=user, conversation_log=conversation_log)
|
||||
|
||||
|
||||
# Tests
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
|
||||
@@ -56,13 +56,14 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_chat_history(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
@@ -78,7 +79,8 @@ def test_answer_from_chat_history(client_offline_chat):
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_currently_retrieved_content(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
@@ -88,7 +90,7 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
|
||||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
|
||||
@@ -101,7 +103,8 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
@@ -111,7 +114,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
|
||||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
||||
@@ -130,13 +133,14 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
|
||||
reason="Chat director not capable of answering this question yet because it requires extract_questions",
|
||||
)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
||||
@@ -154,14 +158,15 @@ def test_answer_from_chat_history_and_currently_retrieved_content(client_offline
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
|
||||
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
|
||||
@@ -177,11 +182,12 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_using_general_command(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_using_general_command(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/general Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
|
||||
@@ -194,11 +200,12 @@ def test_answer_using_general_command(client_offline_chat):
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
|
||||
@@ -211,12 +218,13 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_using_file_filter(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_using_file_filter(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
|
||||
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")
|
||||
@@ -229,11 +237,12 @@ def test_answer_using_file_filter(client_offline_chat):
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_not_known_using_notes_command(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_not_known_using_notes_command(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Testatron born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
|
||||
@@ -247,6 +256,7 @@ def test_answer_not_known_using_notes_command(client_offline_chat):
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@freeze_time("2023-04-01")
|
||||
def test_answer_requires_current_date_awareness(client_offline_chat):
|
||||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||
@@ -265,6 +275,7 @@ def test_answer_requires_current_date_awareness(client_offline_chat):
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@freeze_time("2023-04-01")
|
||||
def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
|
||||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
@@ -280,14 +291,15 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(
|
||||
@@ -307,7 +319,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(client
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
|
||||
@pytest.mark.chatquality
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2):
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
@@ -328,14 +341,15 @@ def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
@@ -350,11 +364,12 @@ def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
|
||||
|
||||
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_chat_history_very_long(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_chat_history_very_long(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
|
||||
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
@@ -368,6 +383,7 @@ def test_answer_chat_history_very_long(client_offline_chat):
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_requires_multiple_independent_searches(client_offline_chat):
|
||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||
# Act
|
||||
|
||||
@@ -1,3 +1,14 @@
|
||||
# Standard Packages
|
||||
import numpy as np
|
||||
import psutil
|
||||
from scipy.stats import linregress
|
||||
import secrets
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
from khoj.utils import helpers
|
||||
|
||||
|
||||
@@ -44,3 +55,31 @@ def test_lru_cache():
|
||||
cache["b"] # accessing 'b' makes it the most recently used item
|
||||
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
|
||||
assert cache == {"b": 2, "d": 4}
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
|
||||
def test_encode_docs_memory_leak():
|
||||
# Arrange
|
||||
iterations = 50
|
||||
batch_size = 20
|
||||
embeddings_model = EmbeddingsModel()
|
||||
memory_usage_trend = []
|
||||
device = f"{helpers.get_device()}".upper()
|
||||
|
||||
# Act
|
||||
# Encode random strings repeatedly and record memory usage trend
|
||||
for iteration in range(iterations):
|
||||
random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
|
||||
a = [embeddings_model.embed_documents(random_docs)]
|
||||
memory_usage_trend += [psutil.Process().memory_info().rss / (1024 * 1024)]
|
||||
print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)
|
||||
|
||||
# Calculate slope of line fitting memory usage history
|
||||
memory_usage_trend = np.array(memory_usage_trend)
|
||||
slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
|
||||
print(f"Memory usage increased at ~{slope:.2f} MB per iteration on {device}")
|
||||
|
||||
# Assert
|
||||
# If slope is positive memory utilization is increasing
|
||||
# Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
|
||||
assert slope < 2, f"Memory leak suspected on {device}. Memory usage increased at ~{slope:.2f} MB per iteration"
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
# Internal Packages
|
||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
def test_process_entries_from_single_input_jsonl(tmp_path):
|
||||
"Convert multiple jsonl entries from single file to entries."
|
||||
# Arrange
|
||||
input_jsonl = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}
|
||||
{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}
|
||||
"""
|
||||
input_jsonl_file = create_file(tmp_path, input_jsonl)
|
||||
|
||||
# Act
|
||||
# Process Each Entry from All Notes Files
|
||||
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file])
|
||||
entries = list(map(Entry.from_dict, input_jsons))
|
||||
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert output_jsonl == input_jsonl
|
||||
|
||||
|
||||
def test_process_entries_from_multiple_input_jsonls(tmp_path):
|
||||
"Convert multiple jsonl entries from single file to entries."
|
||||
# Arrange
|
||||
input_jsonl_1 = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}"""
|
||||
input_jsonl_2 = """{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}"""
|
||||
input_jsonl_file_1 = create_file(tmp_path, input_jsonl_1, filename="input1.jsonl")
|
||||
input_jsonl_file_2 = create_file(tmp_path, input_jsonl_2, filename="input2.jsonl")
|
||||
|
||||
# Act
|
||||
# Process Each Entry from All Notes Files
|
||||
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file_1, input_jsonl_file_2])
|
||||
entries = list(map(Entry.from_dict, input_jsons))
|
||||
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert output_jsonl == f"{input_jsonl_1}\n{input_jsonl_2}\n"
|
||||
|
||||
|
||||
def test_get_jsonl_files(tmp_path):
|
||||
"Ensure JSONL files specified via input-filter, input-files extracted"
|
||||
# Arrange
|
||||
# Include via input-filter globs
|
||||
group1_file1 = create_file(tmp_path, filename="group1-file1.jsonl")
|
||||
group1_file2 = create_file(tmp_path, filename="group1-file2.jsonl")
|
||||
group2_file1 = create_file(tmp_path, filename="group2-file1.jsonl")
|
||||
group2_file2 = create_file(tmp_path, filename="group2-file2.jsonl")
|
||||
# Include via input-file field
|
||||
file1 = create_file(tmp_path, filename="notes.jsonl")
|
||||
# Not included by any filter
|
||||
create_file(tmp_path, filename="not-included-jsonl.jsonl")
|
||||
create_file(tmp_path, filename="not-included-text.txt")
|
||||
|
||||
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
|
||||
|
||||
# Setup input-files, input-filters
|
||||
input_files = [tmp_path / "notes.jsonl"]
|
||||
input_filter = [tmp_path / "group1*.jsonl", tmp_path / "group2*.jsonl"]
|
||||
|
||||
# Act
|
||||
extracted_org_files = JsonlToJsonl.get_jsonl_files(input_files, input_filter)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_org_files) == 5
|
||||
assert extracted_org_files == expected_files
|
||||
|
||||
|
||||
# Helper Functions
|
||||
def create_file(tmp_path, entry=None, filename="test.jsonl"):
|
||||
jsonl_file = tmp_path / filename
|
||||
jsonl_file.touch()
|
||||
if entry:
|
||||
jsonl_file.write_text(entry)
|
||||
return jsonl_file
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
import os
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.markdown.markdown_to_entries import MarkdownToEntries
|
||||
from khoj.utils.fs_syncer import get_markdown_files
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
|
||||
@@ -23,11 +23,11 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entry_nodes, file_to_entries = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
|
||||
entry_nodes, file_to_entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
|
||||
MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
|
||||
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
|
||||
MarkdownToEntries.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
@@ -52,11 +52,11 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
|
||||
entries, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
|
||||
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)
|
||||
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
|
||||
MarkdownToEntries.convert_markdown_entries_to_maps(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
@@ -81,11 +81,11 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entry_strings, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
|
||||
entries = MarkdownToJsonl.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
|
||||
entry_strings, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
entries = MarkdownToEntries.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
|
||||
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
@@ -144,7 +144,7 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries, _ = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
|
||||
entries, _ = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
111
tests/test_multiple_users.py
Normal file
111
tests/test_multiple_users.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Standard Modules
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from urllib.parse import quote
|
||||
import pytest
|
||||
|
||||
# External Packages
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI, UploadFile
|
||||
from io import BytesIO
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_routes, configure_search_types
|
||||
from khoj.utils import state
|
||||
from khoj.utils.state import search_models, content_index, config
|
||||
from khoj.search_type import text_search, image_search
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_entries import OrgToEntries
|
||||
from database.models import KhojUser, KhojApiUser
|
||||
from database.adapters import EntryAdapters
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_search_for_user2_returns_empty(client, api_user2: KhojApiUser):
|
||||
token = api_user2.token
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
for content_type in ["all", "org", "markdown", "pdf", "github", "notion", "plaintext"]:
|
||||
# Act
|
||||
response = client.get(f"/api/search?q=random&t={content_type}", headers=headers)
|
||||
# Assert
|
||||
assert response.text == "[]"
|
||||
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update_with_user2(client, api_user2: KhojApiUser):
|
||||
# Arrange
|
||||
files = get_sample_files_data()
|
||||
source_file_symbol = set([f[1][0] for f in files])
|
||||
|
||||
headers = {"Authorization": f"Bearer {api_user2.token}"}
|
||||
update_response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||
search_response = client.get("/api/search?q=hardware&t=all", headers=headers)
|
||||
results = search_response.json()
|
||||
|
||||
# Assert
|
||||
assert update_response.status_code == 200
|
||||
assert len(results) == 5
|
||||
for result in results:
|
||||
assert result["additional"]["file"] in source_file_symbol
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update_with_user2_inaccessible_user1(client, api_user2: KhojApiUser, api_user: KhojApiUser):
|
||||
# Arrange
|
||||
files = get_sample_files_data()
|
||||
source_file_symbol = set([f[1][0] for f in files])
|
||||
|
||||
headers = {"Authorization": f"Bearer {api_user2.token}"}
|
||||
update_response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||
|
||||
# Act
|
||||
headers = {"Authorization": f"Bearer {api_user.token}"}
|
||||
search_response = client.get("/api/search?q=hardware&t=all", headers=headers)
|
||||
results = search_response.json()
|
||||
|
||||
# Assert
|
||||
assert update_response.status_code == 200
|
||||
assert len(results) == 4
|
||||
for result in results:
|
||||
assert result["additional"]["file"] not in source_file_symbol
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
|
||||
# Arrange
|
||||
headers = {"Authorization": "Bearer kk-token"} # Token for default_user2
|
||||
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 403
|
||||
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
|
||||
assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
|
||||
|
||||
|
||||
def get_sample_files_data():
|
||||
return [
|
||||
("files", ("path/to/filename.org", "* practicing piano", "text/org")),
|
||||
("files", ("path/to/filename1.org", "** top 3 reasons why I moved to SF", "text/org")),
|
||||
("files", ("path/to/filename2.org", "* how to build a search engine", "text/org")),
|
||||
("files", ("path/to/filename.pdf", "Moore's law does not apply to consumer hardware", "application/pdf")),
|
||||
("files", ("path/to/filename1.pdf", "The sun is a ball of helium", "application/pdf")),
|
||||
("files", ("path/to/filename2.pdf", "Effect of sunshine on baseline human happiness", "application/pdf")),
|
||||
("files", ("path/to/filename.txt", "data,column,value", "text/plain")),
|
||||
("files", ("path/to/filename1.txt", "<html>my first web page</html>", "text/plain")),
|
||||
("files", ("path/to/filename2.txt", "2021-02-02 Journal Entry", "text/plain")),
|
||||
("files", ("path/to/filename.md", "# Notes from client call", "text/markdown")),
|
||||
(
|
||||
"files",
|
||||
("path/to/filename1.md", "## Studying anthropological records from the Fatimid caliphate", "text/markdown"),
|
||||
),
|
||||
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
|
||||
]
|
||||
@@ -9,8 +9,8 @@ from khoj.processor.conversation import prompts
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.utils import state
|
||||
|
||||
from tests.helpers import ConversationFactory
|
||||
from database.models import KhojUser
|
||||
|
||||
# Initialize variables for tests
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
@@ -23,7 +23,7 @@ if api_key is None:
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
def populate_chat_history(message_list, user=None):
|
||||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, gpt_message, context in message_list:
|
||||
@@ -33,13 +33,14 @@ def populate_chat_history(message_list):
|
||||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
)
|
||||
|
||||
# Update Conversation Metadata Logs in Application State
|
||||
state.processor_config.conversation.meta_log = conversation_log
|
||||
# Update Conversation Metadata Logs in Database
|
||||
ConversationFactory(user=user, conversation_log=conversation_log)
|
||||
|
||||
|
||||
# Tests
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
|
||||
@@ -54,14 +55,15 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history(chat_client):
|
||||
def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
@@ -76,8 +78,9 @@ def test_answer_from_chat_history(chat_client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_currently_retrieved_content(chat_client):
|
||||
def test_answer_from_currently_retrieved_content(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
@@ -87,7 +90,7 @@ def test_answer_from_currently_retrieved_content(chat_client):
|
||||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
|
||||
@@ -99,8 +102,9 @@ def test_answer_from_currently_retrieved_content(chat_client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_no_background, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
@@ -110,10 +114,10 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
|
||||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
@@ -125,14 +129,15 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||
@@ -148,15 +153,16 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
|
||||
def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
|
||||
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
|
||||
@@ -171,12 +177,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_using_general_command(chat_client):
|
||||
def test_answer_using_general_command(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/general Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
@@ -188,12 +195,13 @@ def test_answer_using_general_command(chat_client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_retrieved_content_using_notes_command(chat_client):
|
||||
def test_answer_from_retrieved_content_using_notes_command(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
@@ -205,24 +213,26 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_not_known_using_notes_command(chat_client):
|
||||
def test_answer_not_known_using_notes_command(chat_client_no_background, default_user2: KhojUser):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Testatron born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response_message == prompts.no_notes_found.format()
|
||||
assert response_message == prompts.no_entries_found.format()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("2023-04-01")
|
||||
def test_answer_requires_current_date_awareness(chat_client):
|
||||
@@ -240,11 +250,13 @@ def test_answer_requires_current_date_awareness(chat_client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("2023-04-01")
|
||||
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
|
||||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
# Act
|
||||
|
||||
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
@@ -254,15 +266,16 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client):
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(
|
||||
@@ -280,10 +293,12 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
|
||||
|
||||
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
@@ -292,6 +307,8 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
|
||||
"which one is",
|
||||
"which of namita's sons",
|
||||
"the birth order",
|
||||
"provide more context",
|
||||
"provide me with more context",
|
||||
]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
@@ -301,15 +318,16 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_in_chat_history_beyond_lookback_window(chat_client):
|
||||
def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
@@ -324,6 +342,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_requires_multiple_independent_searches(chat_client):
|
||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||
@@ -340,10 +359,12 @@ def test_answer_requires_multiple_independent_searches(chat_client):
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_using_file_filter(chat_client):
|
||||
"Chat should be able to use search filters in the query"
|
||||
# Act
|
||||
query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
|
||||
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import json
|
||||
import os
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.text_to_jsonl import TextToJsonl
|
||||
from khoj.processor.org_mode.org_to_entries import OrgToEntries
|
||||
from khoj.processor.text_to_entries import TextToEntries
|
||||
from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.utils.fs_syncer import get_org_files
|
||||
@@ -29,9 +29,9 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
|
||||
for index_heading_entries in [True, False]:
|
||||
# Act
|
||||
# Extract entries into jsonl from specified Org files
|
||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||
OrgToJsonl.convert_org_nodes_to_entries(
|
||||
*OrgToJsonl.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
|
||||
OrgToEntries.convert_org_nodes_to_entries(
|
||||
*OrgToEntries.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
|
||||
)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
@@ -59,12 +59,12 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=data)
|
||||
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Split each entry from specified Org files by max words
|
||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||
TextToJsonl.split_entries_by_max_tokens(
|
||||
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
|
||||
TextToEntries.split_entries_by_max_tokens(
|
||||
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
|
||||
)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
@@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():
|
||||
|
||||
# Act
|
||||
# Split entry by max words and drop words larger than max word length
|
||||
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
|
||||
processed_entry = TextToEntries.split_entries_by_max_tokens([entry], max_word_length=5)[0]
|
||||
|
||||
# Assert
|
||||
# "Heading" dropped from compiled version because its over the set max word limit
|
||||
@@ -109,11 +109,11 @@ def test_entry_with_body_to_jsonl(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=data)
|
||||
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
|
||||
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map)
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
|
||||
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
@@ -136,11 +136,11 @@ Intro text
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=data)
|
||||
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(entries)
|
||||
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
@@ -160,11 +160,11 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=data)
|
||||
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(entries)
|
||||
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
@@ -224,7 +224,7 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, _ = OrgToJsonl.extract_org_entries(org_files=data)
|
||||
entries, _ = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import os
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
|
||||
from khoj.processor.pdf.pdf_to_entries import PdfToEntries
|
||||
|
||||
from khoj.utils.fs_syncer import get_pdf_files
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
@@ -18,11 +18,11 @@ def test_single_page_pdf_to_jsonl():
|
||||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/singlepage.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=data)
|
||||
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
jsonl_string = PdfToJsonl.convert_pdf_maps_to_jsonl(
|
||||
PdfToJsonl.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
|
||||
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
@@ -38,11 +38,11 @@ def test_multi_page_pdf_to_jsonl():
|
||||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/multipage.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=data)
|
||||
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
jsonl_string = PdfToJsonl.convert_pdf_maps_to_jsonl(
|
||||
PdfToJsonl.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
|
||||
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
@@ -50,6 +50,23 @@ def test_multi_page_pdf_to_jsonl():
|
||||
assert len(jsonl_data) == 6
|
||||
|
||||
|
||||
def test_ocr_page_pdf_to_jsonl():
|
||||
"Convert multiple pages from single PDF file to jsonl."
|
||||
# Act
|
||||
# Extract Entries from specified Pdf files
|
||||
with open("tests/data/pdf/ocr_samples.pdf", "rb") as f:
|
||||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
entries = PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
|
||||
assert len(entries) == 1
|
||||
assert "playing on a strip of marsh" in entries[0].raw
|
||||
|
||||
|
||||
def test_get_pdf_files(tmp_path):
|
||||
"Ensure Pdf files specified via input-filter, input-files extracted"
|
||||
# Arrange
|
||||
@@ -6,7 +6,8 @@ from pathlib import Path
|
||||
# Internal Packages
|
||||
from khoj.utils.fs_syncer import get_plaintext_files
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
|
||||
from database.models import LocalPlaintextConfig, KhojUser
|
||||
|
||||
|
||||
def test_plaintext_file(tmp_path):
|
||||
@@ -26,14 +27,14 @@ def test_plaintext_file(tmp_path):
|
||||
f"{plaintextfile}": entry,
|
||||
}
|
||||
|
||||
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(entry_to_file_map=data)
|
||||
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(entry_to_file_map=data)
|
||||
|
||||
# Convert each entry.file to absolute path to make them JSON serializable
|
||||
for map in maps:
|
||||
map.file = str(Path(map.file).absolute())
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = PlaintextToJsonl.convert_entries_to_jsonl(maps)
|
||||
jsonl_string = PlaintextToEntries.convert_entries_to_jsonl(maps)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
@@ -91,14 +92,15 @@ def test_get_plaintext_files(tmp_path):
|
||||
assert set(extracted_plaintext_files.keys()) == set(expected_files)
|
||||
|
||||
|
||||
def test_parse_html_plaintext_file(content_config):
|
||||
def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
|
||||
"Ensure HTML files are parsed correctly"
|
||||
# Arrange
|
||||
# Setup input-files, input-filters
|
||||
extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
|
||||
config = LocalPlaintextConfig.objects.filter(user=default_user).first()
|
||||
extracted_plaintext_files = get_plaintext_files(config=config)
|
||||
|
||||
# Act
|
||||
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)
|
||||
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(extracted_plaintext_files)
|
||||
|
||||
# Assert
|
||||
assert len(maps) == 1
|
||||
@@ -1,25 +1,27 @@
|
||||
# System Packages
|
||||
import logging
|
||||
import locale
|
||||
from pathlib import Path
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.state import content_index, search_models
|
||||
from khoj.search_type import text_search
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.fs_syncer import get_org_files
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_entries import OrgToEntries
|
||||
from khoj.processor.github.github_to_entries import GithubToEntries
|
||||
from khoj.utils.fs_syncer import collect_files, get_org_files
|
||||
from database.models import LocalOrgConfig, KhojUser, Entry, GithubConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig):
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: LocalOrgConfig):
|
||||
# Arrange
|
||||
# Ensure file mentioned in org.input-files is missing
|
||||
single_new_file = Path(org_config_with_only_new_file.input_files[0])
|
||||
@@ -32,98 +34,148 @@ def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_n
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path: Path):
|
||||
@pytest.mark.django_db
|
||||
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
|
||||
# Arrange
|
||||
orgfile = tmp_path / "directory.org" / "file.org"
|
||||
orgfile.parent.mkdir()
|
||||
with open(orgfile, "w") as f:
|
||||
f.write("* Heading\n- List item\n")
|
||||
org_content_config = TextContentConfig(
|
||||
input_filter=[f"{tmp_path}/**/*"], compressed_jsonl="test.jsonl", embeddings_file="test.pt"
|
||||
|
||||
LocalOrgConfig.objects.create(
|
||||
input_filter=[f"{tmp_path}/**/*"],
|
||||
input_files=None,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
# Act
|
||||
# should not raise IsADirectoryError and return orgfile
|
||||
assert get_org_files(org_content_config) == {f"{orgfile}": "* Heading\n- List item\n"}
|
||||
org_files = collect_files(user=default_user)["org"]
|
||||
|
||||
# Assert
|
||||
# should return orgfile and not raise IsADirectoryError
|
||||
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_search_setup_with_empty_file_raises_error(
|
||||
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_empty_file_creates_no_entries(
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with pytest.raises(ValueError, match=r"^No valid entries found*"):
|
||||
text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
|
||||
# Act
|
||||
# Regenerate notes embeddings during asymmetric setup
|
||||
notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert len(notes_model.entries) == 10
|
||||
assert len(notes_model.corpus_embeddings) == 10
|
||||
assert "Deleted 3 entries. Created 0 new entries for user " in caplog.records[-1].message
|
||||
verify_embeddings(0, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
|
||||
@pytest.mark.django_db
|
||||
def test_text_indexer_deletes_embedding_before_regenerate(
|
||||
content_config: ContentConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
caplog.set_level(logging.INFO, logger="khoj")
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert "Deleting all entries for file type org" in caplog.text
|
||||
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# Generate initial notes embeddings during asymmetric setup
|
||||
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
|
||||
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
assert "Creating index from scratch." in initial_logs
|
||||
assert "Creating index from scratch." not in final_logs
|
||||
assert "Deleting all entries for file type org" in initial_logs
|
||||
assert "Deleting all entries for file type org" not in final_logs
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.anyio
|
||||
async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
|
||||
# @pytest.mark.asyncio
|
||||
async def test_text_search(search_config: SearchConfig):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
|
||||
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
|
||||
content_index.org = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
default_user = await KhojUser.objects.acreate(
|
||||
username="test_user", password="test_password", email="test@example.com"
|
||||
)
|
||||
org_config = await LocalOrgConfig.objects.acreate(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/org/*.org"],
|
||||
index_heading_entries=False,
|
||||
user=default_user,
|
||||
)
|
||||
data = get_org_files(org_config)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
text_search.setup,
|
||||
OrgToEntries,
|
||||
data,
|
||||
True,
|
||||
True,
|
||||
default_user,
|
||||
)
|
||||
|
||||
query = "How to git install application?"
|
||||
|
||||
# Act
|
||||
hits, entries = await text_search.query(
|
||||
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
|
||||
)
|
||||
|
||||
results = text_search.collate_results(hits, entries, count=1)
|
||||
hits = await text_search.query(default_user, query)
|
||||
results = text_search.collate_results(hits)
|
||||
results = sorted(results, key=lambda x: float(x.score))[:1]
|
||||
|
||||
# Assert
|
||||
# search results should contain "git clone" entry
|
||||
search_result = results[0].entry
|
||||
assert "git clone" in search_result
|
||||
assert "git clone" in search_result, 'search result did not contain "git clone" entry'
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
||||
@pytest.mark.django_db
|
||||
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||
max_tokens = 256
|
||||
@@ -137,47 +189,45 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
|
||||
|
||||
# Act
|
||||
# reload embeddings, entries, notes model after adding new org-mode file
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
|
||||
# Assert
|
||||
# verify newly added org-mode entry is split by max tokens
|
||||
assert len(initial_notes_model.entries) == 2
|
||||
assert len(initial_notes_model.corpus_embeddings) == 2
|
||||
assert (
|
||||
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||
), "new entry not split by max tokens"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
|
||||
@pytest.mark.django_db
|
||||
def test_entry_chunking_by_max_tokens_not_full_corpus(
|
||||
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
# Insert org-mode entry with size exceeding max token limit to new org file
|
||||
data = {
|
||||
"readme.org": """
|
||||
* Khoj
|
||||
/Allow natural language search on user content like notes, images using transformer based models/
|
||||
/Allow natural language search on user content like notes, images using transformer based models/
|
||||
|
||||
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
|
||||
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
|
||||
|
||||
** Dependencies
|
||||
- Python3
|
||||
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
|
||||
- Python3
|
||||
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
|
||||
|
||||
** Install
|
||||
#+begin_src shell
|
||||
git clone https://github.com/khoj-ai/khoj && cd khoj
|
||||
conda env create -f environment.yml
|
||||
conda activate khoj
|
||||
#+end_src"""
|
||||
#+begin_src shell
|
||||
git clone https://github.com/khoj-ai/khoj && cd khoj
|
||||
conda env create -f environment.yml
|
||||
conda activate khoj
|
||||
#+end_src"""
|
||||
}
|
||||
text_search.setup(
|
||||
OrgToJsonl,
|
||||
OrgToEntries,
|
||||
data,
|
||||
org_config_with_only_new_file,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
max_tokens = 256
|
||||
@@ -191,64 +241,58 @@ def test_entry_chunking_by_max_tokens_not_full_corpus(
|
||||
|
||||
# Act
|
||||
# reload embeddings, entries, notes model after adding new org-mode file
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl,
|
||||
data,
|
||||
org_config_with_only_new_file,
|
||||
search_models.text_search.bi_encoder,
|
||||
regenerate=False,
|
||||
full_corpus=False,
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(
|
||||
OrgToEntries,
|
||||
data,
|
||||
regenerate=False,
|
||||
full_corpus=False,
|
||||
user=default_user,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# verify newly added org-mode entry is split by max tokens
|
||||
assert len(initial_notes_model.entries) == 5
|
||||
assert len(initial_notes_model.corpus_embeddings) == 5
|
||||
assert (
|
||||
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||
), "new entry not split by max tokens"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_regenerate_index_with_new_entry(
|
||||
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
|
||||
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
|
||||
assert len(initial_notes_model.entries) == 10
|
||||
assert len(initial_notes_model.corpus_embeddings) == 10
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
initial_data = get_org_files(org_config)
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
content_config.org.input_files = [f"{new_org_file}"]
|
||||
org_config.input_files = [f"{new_org_file}"]
|
||||
with open(new_org_file, "w") as f:
|
||||
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
final_data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||
regenerated_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
assert len(regenerated_notes_model.entries) == 11
|
||||
assert len(regenerated_notes_model.corpus_embeddings) == 11
|
||||
|
||||
# verify new entry appended to index, without disrupting order or content of existing entries
|
||||
error_details = compare_index(initial_notes_model, regenerated_notes_model)
|
||||
if error_details:
|
||||
pytest.fail(error_details, False)
|
||||
|
||||
# Cleanup
|
||||
# reset input_files in config to empty list
|
||||
content_config.org.input_files = []
|
||||
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
|
||||
assert "Deleted 10 entries. Created 11 new entries for user " in final_logs
|
||||
verify_embeddings(11, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_duplicate_entries_in_stable_order(
|
||||
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
@@ -261,31 +305,30 @@ def test_update_index_with_duplicate_entries_in_stable_order(
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
# load embeddings, entries, notes model after adding new org-mode file
|
||||
initial_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
# generate embeddings, entries, notes model from scratch after adding new org-mode file
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# update embeddings, entries, notes model after adding new org-mode file
|
||||
updated_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
# update embeddings, entries, notes model with no new changes
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert len(initial_index.entries) == len(updated_index.entries) == 1
|
||||
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
|
||||
assert "Deleted 3 entries. Created 1 new entries for user " in initial_logs
|
||||
assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
|
||||
|
||||
# verify the same entry is added even when there are multiple duplicate entries
|
||||
error_details = compare_index(initial_index, updated_index)
|
||||
if error_details:
|
||||
pytest.fail(error_details)
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
||||
@@ -293,101 +336,84 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
|
||||
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||
with open(new_file_to_index, "w") as f:
|
||||
f.write(f"{new_entry}{new_entry} -- Tatooine")
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||
initial_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
|
||||
)
|
||||
initial_data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# update embeddings, entries, notes model after removing an entry from the org file
|
||||
with open(new_file_to_index, "w") as f:
|
||||
f.write(f"{new_entry}")
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
final_data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
updated_index = text_search.setup(
|
||||
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
|
||||
)
|
||||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert len(initial_index.entries) == len(updated_index.entries) + 1
|
||||
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
|
||||
assert "Deleted 3 entries. Created 2 new entries for user " in initial_logs
|
||||
assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
|
||||
|
||||
# verify the same entry is added even when there are multiple duplicate entries
|
||||
error_details = compare_index(updated_index, initial_index)
|
||||
if error_details:
|
||||
pytest.fail(error_details)
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
data = get_org_files(content_config.org)
|
||||
initial_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
|
||||
)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
with open(new_org_file, "w") as f:
|
||||
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
|
||||
f.write(new_entry)
|
||||
|
||||
data = get_org_files(content_config.org)
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# update embeddings, entries with the newly added note
|
||||
content_config.org.input_files = [f"{new_org_file}"]
|
||||
final_notes_model = text_search.setup(
|
||||
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
|
||||
assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
|
||||
|
||||
# verify new entry appended to index, without disrupting order or content of existing entries
|
||||
error_details = compare_index(initial_notes_model, final_notes_model)
|
||||
if error_details:
|
||||
pytest.fail(error_details, False)
|
||||
|
||||
# Cleanup
|
||||
# reset input_files in config to empty list
|
||||
content_config.org.input_files = []
|
||||
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
|
||||
assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
|
||||
verify_embeddings(11, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
|
||||
def test_text_search_setup_github(content_config: ContentConfig, search_models: SearchModels):
|
||||
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
|
||||
# Arrange
|
||||
github_config = GithubConfig.objects.filter(user=default_user).first()
|
||||
|
||||
# Act
|
||||
# Regenerate github embeddings to test asymmetric setup without caching
|
||||
github_model = text_search.setup(
|
||||
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
|
||||
text_search.setup(
|
||||
GithubToEntries,
|
||||
{},
|
||||
regenerate=True,
|
||||
user=default_user,
|
||||
config=github_config,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(github_model.entries) > 1
|
||||
embeddings = Entry.objects.filter(user=default_user, file_type="github").count()
|
||||
assert embeddings > 1
|
||||
|
||||
|
||||
def compare_index(initial_notes_model, final_notes_model):
|
||||
mismatched_entries, mismatched_embeddings = [], []
|
||||
for index in range(len(initial_notes_model.entries)):
|
||||
if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
|
||||
mismatched_entries.append(index)
|
||||
|
||||
# verify new entry embedding appended to embeddings tensor, without disrupting order or content of existing embeddings
|
||||
for index in range(len(initial_notes_model.corpus_embeddings)):
|
||||
if not initial_notes_model.corpus_embeddings[index].allclose(final_notes_model.corpus_embeddings[index]):
|
||||
mismatched_embeddings.append(index)
|
||||
|
||||
error_details = ""
|
||||
if mismatched_entries:
|
||||
mismatched_entries_str = ",".join(map(str, mismatched_entries))
|
||||
error_details += f"Entries at {mismatched_entries_str} not equal\n"
|
||||
if mismatched_embeddings:
|
||||
mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
|
||||
error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
|
||||
|
||||
return error_details
|
||||
def verify_embeddings(expected_count, user):
|
||||
embeddings = Entry.objects.filter(user=user, file_type="org").count()
|
||||
assert embeddings == expected_count
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_no_word_filter():
|
||||
@@ -21,54 +22,44 @@ def test_no_word_filter():
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_word_exclude_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
q_with_exclude_filter = 'head -"exclude_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(q_with_exclude_filter)
|
||||
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {0, 2}
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_word_include_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
query_with_include_filter = 'head +"include_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(query_with_include_filter)
|
||||
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {2, 3}
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_word_include_and_exclude_filter():
|
||||
# Arrange
|
||||
word_filter = WordFilter()
|
||||
entries = arrange_content()
|
||||
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
|
||||
|
||||
# Act
|
||||
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
|
||||
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
|
||||
|
||||
# Assert
|
||||
assert can_filter == True
|
||||
assert ret_query == "head tail"
|
||||
assert entry_indices == {2}
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user