Resolve merge conflicts

sabaimran
2023-11-19 12:57:55 -08:00
172 changed files with 9190 additions and 4823 deletions


@@ -1,61 +1,65 @@
# External Packages
import os
from copy import deepcopy
from fastapi.testclient import TestClient
from pathlib import Path
import pytest
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI
import os
from fastapi import FastAPI
# Internal Packages
from khoj.main import app
from khoj.configure import configure_processor, configure_routes, configure_search_types
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.configure import configure_routes, configure_search_types, configure_middleware
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels
from khoj.utils.constants import web_directory
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import (
ContentConfig,
ConversationProcessorConfig,
OfflineChatProcessorConfig,
OpenAIProcessorConfig,
ProcessorConfig,
TextContentConfig,
GithubContentConfig,
GithubRepoConfig,
ImageContentConfig,
SearchConfig,
TextSearchConfig,
ImageSearchConfig,
)
from khoj.utils import state, fs_syncer
from khoj.routers.indexer import configure_content
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.processor.org_mode.org_to_entries import OrgToEntries
from database.models import (
KhojApiUser,
LocalOrgConfig,
LocalMarkdownConfig,
LocalPlaintextConfig,
GithubConfig,
KhojUser,
GithubRepoConfig,
)
from tests.helpers import (
UserFactory,
ChatModelOptionsFactory,
OpenAIProcessorConversationConfigFactory,
OfflineChatProcessorConversationConfigFactory,
UserConversationProcessorConfigFactory,
SubscriptionFactory,
)
@pytest.fixture(autouse=True)
def enable_db_access_for_all_tests(db):
pass
@pytest.fixture(scope="session")
def search_config() -> SearchConfig:
state.embeddings_model = EmbeddingsModel()
state.cross_encoder_model = CrossEncoderModel()
model_dir = resolve_absolute_path("~/.khoj/search")
model_dir.mkdir(parents=True, exist_ok=True)
search_config = SearchConfig()
search_config.symmetric = TextSearchConfig(
encoder="sentence-transformers/all-MiniLM-L6-v2",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "symmetric/",
encoder_type=None,
)
search_config.asymmetric = TextSearchConfig(
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "asymmetric/",
encoder_type=None,
)
search_config.image = ImageSearchConfig(
encoder="sentence-transformers/clip-ViT-B-32",
model_directory=model_dir / "image/",
@@ -65,17 +69,102 @@ def search_config() -> SearchConfig:
return search_config
@pytest.mark.django_db
@pytest.fixture
def default_user():
user = UserFactory()
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db
@pytest.fixture
def default_user2():
if KhojUser.objects.filter(username="default").exists():
return KhojUser.objects.get(username="default")
user = KhojUser.objects.create(
username="default",
email="default@example.com",
password="default",
)
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db
@pytest.fixture
def default_user3():
"""
This user should not have any data associated with it
"""
if KhojUser.objects.filter(username="default3").exists():
return KhojUser.objects.get(username="default3")
user = KhojUser.objects.create(
username="default3",
email="default3@example.com",
password="default3",
)
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db
@pytest.fixture
def api_user(default_user):
if KhojApiUser.objects.filter(user=default_user).exists():
return KhojApiUser.objects.get(user=default_user)
return KhojApiUser.objects.create(
user=default_user,
name="api-key",
token="kk-secret",
)
@pytest.mark.django_db
@pytest.fixture
def api_user2(default_user2):
if KhojApiUser.objects.filter(user=default_user2).exists():
return KhojApiUser.objects.get(user=default_user2)
return KhojApiUser.objects.create(
user=default_user2,
name="api-key",
token="kk-diff-secret",
)
@pytest.mark.django_db
@pytest.fixture
def api_user3(default_user3):
if KhojApiUser.objects.filter(user=default_user3).exists():
return KhojApiUser.objects.get(user=default_user3)
return KhojApiUser.objects.create(
user=default_user3,
name="api-key",
token="kk-diff-secret-3",
)
@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
search_models = SearchModels()
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
search_models.image_search = image_search.initialize_model(search_config.image)
return search_models
@pytest.fixture(scope="session")
def content_config(tmp_path_factory, search_models: SearchModels, search_config: SearchConfig):
@pytest.fixture
def anyio_backend():
return "asyncio"
@pytest.mark.django_db
@pytest.fixture(scope="function")
def content_config(tmp_path_factory, search_models: SearchModels, default_user: KhojUser):
content_dir = tmp_path_factory.mktemp("content")
# Generate Image Embeddings from Test Images
@@ -90,217 +179,188 @@ def content_config(tmp_path_factory, search_models: SearchModels, search_config:
image_search.setup(content_config.image, search_models.image_search.image_encoder, regenerate=False)
# Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig(
LocalOrgConfig.objects.create(
input_files=None,
input_filter=["tests/data/org/*.org"],
compressed_jsonl=content_dir.joinpath("notes.jsonl.gz"),
embeddings_file=content_dir.joinpath("note_embeddings.pt"),
index_heading_entries=False,
user=default_user,
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
OrgToJsonl,
get_sample_data("org"),
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
)
content_config.plugins = {
"plugin1": TextContentConfig(
input_files=[content_dir.joinpath("notes.jsonl.gz")],
input_filter=None,
compressed_jsonl=content_dir.joinpath("plugin.jsonl.gz"),
embeddings_file=content_dir.joinpath("plugin_embeddings.pt"),
)
}
text_search.setup(OrgToEntries, get_sample_data("org"), regenerate=False, user=default_user)
if os.getenv("GITHUB_PAT_TOKEN"):
content_config.github = GithubContentConfig(
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
repos=[
GithubRepoConfig(
owner="khoj-ai",
name="lantern",
branch="master",
)
],
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
GithubConfig.objects.create(
pat_token=os.getenv("GITHUB_PAT_TOKEN"),
user=default_user,
)
content_config.plaintext = TextContentConfig(
GithubRepoConfig.objects.create(
owner="khoj-ai",
name="lantern",
branch="master",
github_config=GithubConfig.objects.get(user=default_user),
)
LocalPlaintextConfig.objects.create(
input_files=None,
input_filter=["tests/data/plaintext/*.txt", "tests/data/plaintext/*.md", "tests/data/plaintext/*.html"],
compressed_jsonl=content_dir.joinpath("plaintext.jsonl.gz"),
embeddings_file=content_dir.joinpath("plaintext_embeddings.pt"),
)
content_config.github = GithubContentConfig(
pat_token=os.getenv("GITHUB_PAT_TOKEN", ""),
repos=[
GithubRepoConfig(
owner="khoj-ai",
name="lantern",
branch="master",
)
],
compressed_jsonl=content_dir.joinpath("github.jsonl.gz"),
embeddings_file=content_dir.joinpath("github_embeddings.pt"),
)
filters = [DateFilter(), WordFilter(), FileFilter()]
text_search.setup(
JsonlToJsonl,
None,
content_config.plugins["plugin1"],
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
user=default_user,
)
return content_config
@pytest.fixture(scope="session")
def md_content_config(tmp_path_factory):
content_dir = tmp_path_factory.mktemp("content")
# Generate Embeddings for Markdown Content
content_config = ContentConfig()
content_config.markdown = TextContentConfig(
def md_content_config():
markdown_config = LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
compressed_jsonl=content_dir.joinpath("markdown.jsonl.gz"),
embeddings_file=content_dir.joinpath("markdown_embeddings.pt"),
)
return content_config
return markdown_config
@pytest.fixture(scope="session")
def processor_config(tmp_path_factory):
openai_api_key = os.getenv("OPENAI_API_KEY")
processor_dir = tmp_path_factory.mktemp("processor")
# The conversation processor is the only configured processor
# It needs an OpenAI API key to work.
if not openai_api_key:
return
# Setup conversation processor, if OpenAI API key is set
processor_config = ProcessorConfig()
processor_config.conversation = ConversationProcessorConfig(
openai=OpenAIProcessorConfig(api_key=openai_api_key),
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
)
return processor_config
@pytest.fixture(scope="session")
def processor_config_offline_chat(tmp_path_factory):
processor_dir = tmp_path_factory.mktemp("processor")
# Setup conversation processor
processor_config = ProcessorConfig()
offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True, chat_model="mistral-7b-instruct-v0.1.Q4_0.gguf")
processor_config.conversation = ConversationProcessorConfig(
offline_chat=offline_chat,
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
)
return processor_config
@pytest.fixture(scope="session")
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
@pytest.fixture(scope="function")
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.content_type = md_content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
state.SearchType = configure_search_types()
LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
user=default_user2,
)
# Index Markdown Content for Search
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
all_files = fs_syncer.collect_files(state.config.content_type)
state.content_index = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models
all_files = fs_syncer.collect_files(user=default_user2)
state.content_index, _ = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
)
# Initialize Processor from Config
state.processor_config = configure_processor(processor_config)
if os.getenv("OPENAI_API_KEY"):
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
OpenAIProcessorConversationConfigFactory()
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
state.anonymous_mode = True
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return TestClient(app)
@pytest.fixture(scope="function")
def client(content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
state.SearchType = configure_search_types()
# Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"):
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
OpenAIProcessorConversationConfigFactory()
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
state.anonymous_mode = True
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return TestClient(app)
@pytest.fixture(scope="function")
def fastapi_app():
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return app
@pytest.fixture(scope="function")
def client(
content_config: ContentConfig,
search_config: SearchConfig,
api_user: KhojApiUser,
):
state.config.content_type = content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
state.SearchType = configure_search_types()
state.embeddings_model = EmbeddingsModel()
state.cross_encoder_model = CrossEncoderModel()
# These lines help us Mock the Search models for these search types
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
state.content_index.org = text_search.setup(
OrgToJsonl,
text_search.setup(
OrgToEntries,
get_sample_data("org"),
content_config.org,
state.search_models.text_search.bi_encoder,
regenerate=False,
user=api_user.user,
)
state.content_index.image = image_search.setup(
content_config.image, state.search_models.image_search, regenerate=False
)
state.content_index.plaintext = text_search.setup(
PlaintextToJsonl,
text_search.setup(
PlaintextToEntries,
get_sample_data("plaintext"),
content_config.plaintext,
state.search_models.text_search.bi_encoder,
regenerate=False,
user=api_user.user,
)
state.processor_config = configure_processor(processor_config)
state.anonymous_mode = False
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return TestClient(app)
@pytest.fixture(scope="function")
def client_offline_chat(
search_config: SearchConfig,
processor_config_offline_chat: ProcessorConfig,
content_config: ContentConfig,
md_content_config,
):
def client_offline_chat(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.content_type = md_content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
state.SearchType = configure_search_types()
# Index Markdown Content for Search
state.search_models.text_search = text_search.initialize_model(search_config.asymmetric)
state.search_models.image_search = image_search.initialize_model(search_config.image)
LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
user=default_user2,
)
all_files = fs_syncer.collect_files(state.config.content_type)
state.content_index = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models
all_files = fs_syncer.collect_files(user=default_user2)
configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
)
# Initialize Processor from Config
state.processor_config = configure_processor(processor_config_offline_chat)
OfflineChatProcessorConversationConfigFactory(enabled=True)
UserConversationProcessorConfigFactory(user=default_user2)
state.anonymous_mode = True
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return TestClient(app)
@pytest.fixture(scope="function")
def new_org_file(content_config: ContentConfig):
def new_org_file(default_user: KhojUser, content_config: ContentConfig):
# Setup
new_org_file = Path(content_config.org.input_filter[0]).parent / "new_file.org"
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
input_filters = org_config.input_filter
new_org_file = Path(input_filters[0]).parent / "new_file.org"
new_org_file.touch()
yield new_org_file
@@ -311,11 +371,9 @@ def new_org_file(content_config: ContentConfig):
@pytest.fixture(scope="function")
def org_config_with_only_new_file(content_config: ContentConfig, new_org_file: Path):
new_org_config = deepcopy(content_config.org)
new_org_config.input_files = [f"{new_org_file}"]
new_org_config.input_filter = None
return new_org_config
def org_config_with_only_new_file(new_org_file: Path, default_user: KhojUser):
LocalOrgConfig.objects.update(input_files=[str(new_org_file)], input_filter=None)
return LocalOrgConfig.objects.filter(user=default_user).first()
@pytest.fixture(scope="function")

13 tests/data/config.yml vendored

@@ -9,20 +9,9 @@ content-type:
input-filter:
- '*.org'
- ~/notes/*.org
plugins:
content_plugin_1:
compressed-jsonl: content_plugin_1.jsonl.gz
embeddings-file: content_plugin_1_embeddings.pt
input-files:
- content_plugin_1_new.jsonl.gz
content_plugin_2:
compressed-jsonl: content_plugin_2.jsonl.gz
embeddings-file: content_plugin_2_embeddings.pt
input-filter:
- '*2_new.jsonl.gz'
enable-offline-chat: false
search-type:
asymmetric:
cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
encoder: sentence-transformers/msmarco-MiniLM-L-6-v3
version: 0.10.1
version: 0.15.0

BIN tests/data/pdf/ocr_samples.pdf vendored Normal file

Binary file not shown.

92 tests/helpers.py Normal file

@@ -0,0 +1,92 @@
import factory
import os
from database.models import (
KhojUser,
KhojApiUser,
ChatModelOptions,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
SearchModelConfig,
UserConversationConfig,
Conversation,
Subscription,
)
class UserFactory(factory.django.DjangoModelFactory):
class Meta:
model = KhojUser
username = factory.Faker("name")
email = factory.Faker("email")
password = factory.Faker("password")
uuid = factory.Faker("uuid4")
class ApiUserFactory(factory.django.DjangoModelFactory):
class Meta:
model = KhojApiUser
user = None
name = factory.Faker("name")
token = factory.Faker("password")
class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
class Meta:
model = ChatModelOptions
max_prompt_size = 2000
tokenizer = None
chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
model_type = "offline"
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = UserConversationConfig
user = factory.SubFactory(UserFactory)
setting = factory.SubFactory(ChatModelOptionsFactory)
class OfflineChatProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = OfflineChatProcessorConversationConfig
enabled = True
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = OpenAIProcessorConversationConfig
api_key = os.getenv("OPENAI_API_KEY")
class ConversationFactory(factory.django.DjangoModelFactory):
class Meta:
model = Conversation
user = factory.SubFactory(UserFactory)
class SearchModelFactory(factory.django.DjangoModelFactory):
class Meta:
model = SearchModelConfig
name = "default"
model_type = "text"
bi_encoder = "thenlper/gte-small"
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
class SubscriptionFactory(factory.django.DjangoModelFactory):
class Meta:
model = Subscription
user = factory.SubFactory(UserFactory)
type = "standard"
is_recurring = False
renewal_date = "2100-04-01"
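For orientation, here is a minimal sketch of how these factories might be combined in a single test. The test name and asserted values are illustrative only, and it assumes pytest-django is installed so the model instances can be persisted:
import pytest
from tests.helpers import ApiUserFactory, SubscriptionFactory, UserFactory
@pytest.mark.django_db
def test_factories_build_linked_records():
    # Each factory persists a row and wires up the foreign keys for us.
    user = UserFactory()
    api_user = ApiUserFactory(user=user)
    subscription = SubscriptionFactory(user=user)
    assert api_user.user == user
    assert subscription.user == user
    assert subscription.type == "standard"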


@@ -25,7 +25,7 @@ def test_cli_invalid_config_file_path():
non_existent_config_file = f"non-existent-khoj-{random()}.yml"
# Act
actual_args = cli([f"-c={non_existent_config_file}"])
actual_args = cli([f"--config-file={non_existent_config_file}"])
# Assert
assert actual_args.config_file == resolve_absolute_path(non_existent_config_file)
@@ -35,7 +35,7 @@ def test_cli_invalid_config_file_path():
# ----------------------------------------------------------------------------------------------------
def test_cli_config_from_file():
# Act
actual_args = cli(["-c=tests/data/config.yml", "--regenerate", "-vvv"])
actual_args = cli(["--config-file=tests/data/config.yml", "--regenerate", "-vvv"])
# Assert
assert actual_args.config_file == resolve_absolute_path(Path("tests/data/config.yml"))
@@ -48,14 +48,3 @@ def test_cli_config_from_file():
Path("~/first_from_config.org"),
Path("~/second_from_config.org"),
]
assert len(actual_args.config.content_type.plugins.keys()) == 2
assert actual_args.config.content_type.plugins["content_plugin_1"].input_files == [
Path("content_plugin_1_new.jsonl.gz")
]
assert actual_args.config.content_type.plugins["content_plugin_2"].input_filter == ["*2_new.jsonl.gz"]
assert actual_args.config.content_type.plugins["content_plugin_1"].compressed_jsonl == Path(
"content_plugin_1.jsonl.gz"
)
assert actual_args.config.content_type.plugins["content_plugin_2"].embeddings_file == Path(
"content_plugin_2_embeddings.pt"
)
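As a rough sketch, the invocation style these updated CLI tests rely on looks like the snippet below; the cli import path is an assumption made for illustration, and only the config-file assertion mirrors the test above:
from pathlib import Path
from khoj.utils.cli import cli  # import path assumed for illustration
from khoj.utils.helpers import resolve_absolute_path
# The short "-c" alias is replaced by the long-form "--config-file" flag.
args = cli(["--config-file=tests/data/config.yml", "--regenerate", "-vvv"])
assert args.config_file == resolve_absolute_path(Path("tests/data/config.yml"))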


@@ -2,69 +2,135 @@
from io import BytesIO
from PIL import Image
from urllib.parse import quote
import pytest
# External Packages
from fastapi.testclient import TestClient
from fastapi import FastAPI
import pytest
# Internal Packages
from khoj.main import app
from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.processor.org_mode.org_to_entries import OrgToEntries
from database.models import KhojUser, KhojApiUser
from database.adapters import EntryAdapters
# Test
# ----------------------------------------------------------------------------------------------------
def test_search_with_invalid_content_type(client):
@pytest.mark.django_db(transaction=True)
def test_search_with_no_auth_key(client):
# Arrange
user_query = quote("How to call Khoj from Emacs?")
# Act
response = client.get(f"/api/search?q={user_query}&t=invalid_content_type")
response = client.get(f"/api/search?q={user_query}")
# Assert
assert response.status_code == 403
@pytest.mark.django_db(transaction=True)
def test_search_with_invalid_auth_key(client):
# Arrange
headers = {"Authorization": "Bearer invalid-token"}
user_query = quote("How to call Khoj from Emacs?")
# Act
response = client.get(f"/api/search?q={user_query}", headers=headers)
# Assert
assert response.status_code == 403
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_search_with_invalid_content_type(client):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
user_query = quote("How to call Khoj from Emacs?")
# Act
response = client.get(f"/api/search?q={user_query}&t=invalid_content_type", headers=headers)
# Assert
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_search_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plugin1"]:
headers = {"Authorization": "Bearer kk-secret"}
for content_type in ["all", "org", "markdown", "image", "pdf", "github", "notion", "plaintext"]:
# Act
response = client.get(f"/api/search?q=random&t={content_type}")
response = client.get(f"/api/search?q=random&t={content_type}", headers=headers)
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_with_no_auth_key(client):
# Arrange
files = get_sample_files_data()
# Act
response = client.post("/api/v1/index/update", files=files)
# Assert
assert response.status_code == 403
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_with_invalid_auth_key(client):
# Arrange
files = get_sample_files_data()
headers = {"Authorization": "Bearer kk-invalid-token"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 403
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_update_with_invalid_content_type(client):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.get(f"/api/update?t=invalid_content_type")
response = client.get(f"/api/update?t=invalid_content_type", headers=headers)
# Assert
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_regenerate_with_invalid_content_type(client):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.get(f"/api/update?force=true&t=invalid_content_type")
response = client.get(f"/api/update?force=true&t=invalid_content_type", headers=headers)
# Assert
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update(client):
# Arrange
files = get_sample_files_data()
headers = {"x-api-key": "secret"}
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
@@ -74,88 +140,75 @@ def test_index_update(client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_regenerate_with_valid_content_type(client):
for content_type in ["all", "org", "markdown", "image", "pdf", "notion", "plugin1"]:
for content_type in ["all", "org", "markdown", "image", "pdf", "notion"]:
# Arrange
files = get_sample_files_data()
headers = {"x-api-key": "secret"}
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post(f"/api/v1/index/update?t={content_type}", files=files, headers=headers)
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_regenerate_with_github_fails_without_pat(client):
# Act
response = client.get(f"/api/update?force=true&t=github")
headers = {"Authorization": "Bearer kk-secret"}
response = client.get(f"/api/update?force=true&t=github", headers=headers)
# Arrange
files = get_sample_files_data()
headers = {"x-api-key": "secret"}
# Act
response = client.post(f"/api/v1/index/update?t=github", files=files, headers=headers)
# Assert
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: github"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skip(reason="Flaky test on parallel test runs")
def test_get_configured_types_via_api(client):
@pytest.mark.django_db
def test_get_configured_types_via_api(client, sample_org_data):
# Act
response = client.get(f"/api/config/types")
text_search.setup(OrgToEntries, sample_org_data, regenerate=False)
enabled_types = EntryAdapters.get_unique_file_types(user=None).all().values_list("file_type", flat=True)
# Assert
assert response.status_code == 200
assert response.json() == ["all", "org", "image", "plaintext", "plugin1"]
assert list(enabled_types) == ["org"]
# ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_only_plugin_content_config(content_config):
@pytest.mark.django_db(transaction=True)
def test_get_api_config_types(client, sample_org_data, default_user: KhojUser):
# Arrange
config.content_type = ContentConfig()
config.content_type.plugins = content_config.plugins
state.SearchType = configure_search_types(config)
configure_routes(app)
client = TestClient(app)
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
# Act
response = client.get(f"/api/config/types")
response = client.get(f"/api/config/types", headers=headers)
# Assert
assert response.status_code == 200
assert response.json() == ["all", "plugin1"]
assert response.json() == ["all", "org", "image", "plaintext"]
# ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_no_plugin_content_config(content_config):
@pytest.mark.django_db(transaction=True)
def test_get_configured_types_with_no_content_config(fastapi_app: FastAPI):
# Arrange
config.content_type = content_config
config.content_type.plugins = None
state.SearchType = configure_search_types(config)
state.anonymous_mode = True
if state.config and state.config.content_type:
state.config.content_type = None
state.search_models = configure_search_types()
configure_routes(app)
client = TestClient(app)
# Act
response = client.get(f"/api/config/types")
# Assert
assert response.status_code == 200
assert "plugin1" not in response.json()
# ----------------------------------------------------------------------------------------------------
def test_get_configured_types_with_no_content_config():
# Arrange
config.content_type = ContentConfig()
state.SearchType = configure_search_types(config)
configure_routes(app)
client = TestClient(app)
configure_routes(fastapi_app)
client = TestClient(fastapi_app)
# Act
response = client.get(f"/api/config/types")
@@ -166,8 +219,10 @@ def test_get_configured_types_with_no_content_config():
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_image_search(client, content_config: ContentConfig, search_config: SearchConfig):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
search_models.image_search = image_search.initialize_model(search_config.image)
content_index.image = image_search.setup(
content_config.image, search_models.image_search.image_encoder, regenerate=False
@@ -180,7 +235,7 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
for query, expected_image_name in query_expected_image_pairs:
# Act
response = client.get(f"/api/search?q={query}&n=1&t=image")
response = client.get(f"/api/search?q={query}&n=1&t=image", headers=headers)
# Assert
assert response.status_code == 200
@@ -192,43 +247,57 @@ def test_image_search(client, content_config: ContentConfig, search_config: Sear
# ----------------------------------------------------------------------------------------------------
def test_notes_search(client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data):
@pytest.mark.django_db(transaction=True)
def test_notes_search(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
# Arrange
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search.bi_encoder, regenerate=False
)
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true")
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
# Assert
assert response.status_code == 200
# assert actual_data contains "Khoj via Emacs" entry
assert len(response.json()) == 1, "Expected only 1 result"
search_result = response.json()[0]["entry"]
assert "git clone https://github.com/khoj-ai/khoj" in search_result
assert "git clone https://github.com/khoj-ai/khoj" in search_result, "Expected 'git clone' in search result"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_no_results(client, search_config: SearchConfig, sample_org_data, default_user: KhojUser):
# Arrange
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to find my goat?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
# Assert
assert response.status_code == 200
assert response.json() == [], "Expected no results"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_only_filters(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data, default_user: KhojUser
):
# Arrange
filters = [WordFilter(), FileFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl,
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(
OrgToEntries,
sample_org_data,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
user=default_user,
)
user_query = quote('+"Emacs" file:"*.org"')
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
@@ -238,19 +307,15 @@ def test_notes_search_with_only_filters(
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_include_filter(client, sample_org_data, default_user: KhojUser):
# Arrange
filters = [WordFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, sample_org_data, content_config.org, search_models.text_search, regenerate=False, filters=filters
)
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
user_query = quote('How to git install application? +"Emacs"')
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
@@ -260,24 +325,20 @@ def test_notes_search_with_include_filter(
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_exclude_filter(
client, content_config: ContentConfig, search_config: SearchConfig, sample_org_data
):
@pytest.mark.django_db(transaction=True)
def test_notes_search_with_exclude_filter(client, sample_org_data, default_user: KhojUser):
# Arrange
filters = [WordFilter()]
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl,
headers = {"Authorization": "Bearer kk-secret"}
text_search.setup(
OrgToEntries,
sample_org_data,
content_config.org,
search_models.text_search.bi_encoder,
regenerate=False,
filters=filters,
user=default_user,
)
user_query = quote('How to git install application? -"clone"')
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org")
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
@@ -286,22 +347,56 @@ def test_notes_search_with_exclude_filter(
assert "clone" not in search_result
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
# Arrange
headers = {"Authorization": "Bearer kk-token"} # Token for default_user2
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 403
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiUser):
# Arrange
token = api_user3.token
headers = {"Authorization": "Bearer " + token}
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 200
# assert actual response has no data as the default_user3, though other users have data
assert len(response.json()) == 0
assert response.json() == []
def get_sample_files_data():
return {
"files": ("path/to/filename.org", "* practicing piano", "text/org"),
"files": ("path/to/filename1.org", "** top 3 reasons why I moved to SF", "text/org"),
"files": ("path/to/filename2.org", "* how to build a search engine", "text/org"),
"files": ("path/to/filename.pdf", "Moore's law does not apply to consumer hardware", "application/pdf"),
"files": ("path/to/filename1.pdf", "The sun is a ball of helium", "application/pdf"),
"files": ("path/to/filename2.pdf", "Effect of sunshine on baseline human happiness", "application/pdf"),
"files": ("path/to/filename.txt", "data,column,value", "text/plain"),
"files": ("path/to/filename1.txt", "<html>my first web page</html>", "text/plain"),
"files": ("path/to/filename2.txt", "2021-02-02 Journal Entry", "text/plain"),
"files": ("path/to/filename.md", "# Notes from client call", "text/markdown"),
"files": (
"path/to/filename1.md",
"## Studying anthropological records from the Fatimid caliphate",
"text/markdown",
return [
("files", ("path/to/filename.org", "* practicing piano", "text/org")),
("files", ("path/to/filename1.org", "** top 3 reasons why I moved to SF", "text/org")),
("files", ("path/to/filename2.org", "* how to build a search engine", "text/org")),
("files", ("path/to/filename.pdf", "Moore's law does not apply to consumer hardware", "application/pdf")),
("files", ("path/to/filename1.pdf", "The sun is a ball of helium", "application/pdf")),
("files", ("path/to/filename2.pdf", "Effect of sunshine on baseline human happiness", "application/pdf")),
("files", ("path/to/filename.txt", "data,column,value", "text/plain")),
("files", ("path/to/filename1.txt", "<html>my first web page</html>", "text/plain")),
("files", ("path/to/filename2.txt", "2021-02-02 Journal Entry", "text/plain")),
("files", ("path/to/filename.md", "# Notes from client call", "text/markdown")),
(
"files",
("path/to/filename1.md", "## Studying anthropological records from the Fatimid caliphate", "text/markdown"),
),
"files": ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown"),
}
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
]
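The switch from a dict literal to a list of ("files", ...) tuples matters: a Python dict keeps only the last value per repeated key, while the FastAPI TestClient accepts a list of tuples and sends one multipart part per tuple, so every sample file reaches the indexer. A hedged sketch of how the helper is consumed, mirroring the index update test earlier in this file:
# Each ("files", (name, content, content_type)) tuple becomes its own multipart part.
files = get_sample_files_data()
headers = {"Authorization": "Bearer kk-secret"}  # token created by the api_user fixture
response = client.post("/api/v1/index/update", files=files, headers=headers)
assert response.status_code == 200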


@@ -1,53 +1,12 @@
# Standard Packages
import re
from datetime import datetime
from math import inf
# External Packages
import pytest
# Internal Packages
from khoj.search_filter.date_filter import DateFilter
from khoj.utils.rawconfig import Entry
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
def test_date_filter():
entries = [
Entry(compiled="Entry with no date", raw="Entry with no date"),
Entry(compiled="April Fools entry: 1984-04-01", raw="April Fools entry: 1984-04-01"),
Entry(compiled="Entry with date:1984-04-02", raw="Entry with date:1984-04-02"),
]
q_with_no_date_filter = "head tail"
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2}
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
assert ret_query == "head tail"
assert entry_indices == set()
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {1}
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == "head tail"
assert entry_indices == {1, 2}
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
@@ -56,8 +15,8 @@ def test_extract_date_range():
datetime(1984, 1, 5, 0, 0, 0).timestamp(),
datetime(1984, 1, 7, 0, 0, 0).timestamp(),
]
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [0, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), inf]
assert DateFilter().extract_date_range('head dt<="1984-01-01"') == [None, datetime(1984, 1, 2, 0, 0, 0).timestamp()]
assert DateFilter().extract_date_range('head dt>="1984-01-01"') == [datetime(1984, 1, 1, 0, 0, 0).timestamp(), None]
assert DateFilter().extract_date_range('head dt:"1984-01-01"') == [
datetime(1984, 1, 1, 0, 0, 0).timestamp(),
datetime(1984, 1, 2, 0, 0, 0).timestamp(),


@@ -6,97 +6,73 @@ from khoj.utils.rawconfig import Entry
def test_no_file_filter():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = "head tail"
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == False
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_file_filter_with_non_existent_file():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"nonexistent.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {}
def test_single_file_filter():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"file 1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
def test_file_filter_with_partial_match():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
def test_file_filter_with_regex_match():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head file:"*.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_multiple_file_filter():
# Arrange
file_filter = FileFilter()
entries = arrange_content()
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 1, 2, 3}
def test_get_file_filter_terms():
@@ -108,7 +84,7 @@ def test_get_file_filter_terms():
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
# Assert
assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
assert filter_terms == ["file 1\\.org", "/path/to/dir/.*\\.org"]
def arrange_content():


@@ -9,8 +9,7 @@ from faker import Faker
# Internal Packages
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.utils import state
from tests.helpers import ConversationFactory
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
@@ -23,7 +22,7 @@ fake = Faker()
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
def populate_chat_history(message_list, user):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, llm_message, context in message_list:
@@ -33,14 +32,15 @@ def populate_chat_history(message_list):
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
)
# Update Conversation Metadata Logs in Application State
state.processor_config.conversation.meta_log = conversation_log
# Update Conversation Metadata Logs in Database
ConversationFactory(user=user, conversation_log=conversation_log)
# Tests
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
# Act
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
@@ -56,13 +56,14 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -78,7 +79,8 @@ def test_answer_from_chat_history(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_currently_retrieved_content(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@@ -88,7 +90,7 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
@@ -101,7 +103,8 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@@ -111,7 +114,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
@@ -130,13 +133,14 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
reason="Chat director not capable of answering this question yet because it requires extract_questions",
)
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
@@ -154,14 +158,15 @@ def test_answer_from_chat_history_and_currently_retrieved_content(client_offline
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
@@ -177,11 +182,12 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_using_general_command(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_using_general_command(client_offline_chat, default_user2):
# Arrange
query = urllib.parse.quote("/general Where was Xi Li born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@@ -194,11 +200,12 @@ def test_answer_using_general_command(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2):
# Arrange
query = urllib.parse.quote("/notes Where was Xi Li born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@@ -211,12 +218,13 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_using_file_filter(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_using_file_filter(client_offline_chat, default_user2):
# Arrange
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")
@@ -229,11 +237,12 @@ def test_answer_using_file_filter(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_not_known_using_notes_command(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_not_known_using_notes_command(client_offline_chat, default_user2):
# Arrange
query = urllib.parse.quote("/notes Where was Testatron born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@@ -247,6 +256,7 @@ def test_answer_not_known_using_notes_command(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
@freeze_time("2023-04-01")
def test_answer_requires_current_date_awareness(client_offline_chat):
"Chat actor should be able to answer questions relative to current date using provided notes"
@@ -265,6 +275,7 @@ def test_answer_requires_current_date_awareness(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
@freeze_time("2023-04-01")
def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
@@ -280,14 +291,15 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(
@@ -307,7 +319,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(client
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2):
# Act
response = client_offline_chat.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response_message = response.content.decode("utf-8")
@@ -328,14 +341,15 @@ def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -350,11 +364,12 @@ def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
@pytest.mark.chatquality
def test_answer_chat_history_very_long(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_chat_history_very_long(client_offline_chat, default_user2):
# Arrange
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -368,6 +383,7 @@ def test_answer_chat_history_very_long(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_requires_multiple_independent_searches(client_offline_chat):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act

View File

@@ -1,3 +1,14 @@
# Standard Packages
import numpy as np
import psutil
from scipy.stats import linregress
import secrets
# External Packages
import pytest
# Internal Packages
from khoj.processor.embeddings import EmbeddingsModel
from khoj.utils import helpers
@@ -44,3 +55,31 @@ def test_lru_cache():
cache["b"] # accessing 'b' makes it the most recently used item
cache["d"] = 4 # so 'c' is deleted from the cache instead of 'b'
assert cache == {"b": 2, "d": 4}
@pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
def test_encode_docs_memory_leak():
# Arrange
iterations = 50
batch_size = 20
embeddings_model = EmbeddingsModel()
memory_usage_trend = []
device = f"{helpers.get_device()}".upper()
# Act
# Encode random strings repeatedly and record memory usage trend
for iteration in range(iterations):
random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
a = [embeddings_model.embed_documents(random_docs)]
memory_usage_trend += [psutil.Process().memory_info().rss / (1024 * 1024)]
print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)
# Calculate slope of line fitting memory usage history
memory_usage_trend = np.array(memory_usage_trend)
slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
print(f"Memory usage increased at ~{slope:.2f} MB per iteration on {device}")
# Assert
# If the slope is positive, memory utilization is increasing
# Positive threshold of 2 chosen from observing memory usage trends on MPS vs CPU devices
assert slope < 2, f"Memory leak suspected on {device}. Memory usage increased at ~{slope:.2f} MB per iteration"
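
The leak check above boils down to fitting a straight line to the per-iteration RSS samples and comparing its slope against the 2 MB/iteration threshold. A minimal sketch of that calculation on synthetic data (numbers are illustrative; only numpy and scipy are assumed):

import numpy as np
from scipy.stats import linregress

# Hypothetical per-iteration memory samples in MB; a real run records psutil RSS values
memory_usage_mb = np.array([120.0, 121.5, 122.9, 124.6, 126.1])
slope, intercept, rvalue, pvalue, stderr = linregress(np.arange(len(memory_usage_mb)), memory_usage_mb)
print(f"Memory grew at ~{slope:.2f} MB per iteration")  # ~1.53 here, under the 2 MB/iteration threshold
assert slope < 2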

View File

@@ -1,78 +0,0 @@
# Internal Packages
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.utils.rawconfig import Entry
def test_process_entries_from_single_input_jsonl(tmp_path):
"Convert multiple jsonl entries from single file to entries."
# Arrange
input_jsonl = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}
{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}
"""
input_jsonl_file = create_file(tmp_path, input_jsonl)
# Act
# Process Each Entry from All Notes Files
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file])
entries = list(map(Entry.from_dict, input_jsons))
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Assert
assert len(entries) == 2
assert output_jsonl == input_jsonl
def test_process_entries_from_multiple_input_jsonls(tmp_path):
"Convert multiple jsonl entries from single file to entries."
# Arrange
input_jsonl_1 = """{"raw": "raw input data 1", "compiled": "compiled input data 1", "heading": null, "file": "source/file/path1"}"""
input_jsonl_2 = """{"raw": "raw input data 2", "compiled": "compiled input data 2", "heading": null, "file": "source/file/path2"}"""
input_jsonl_file_1 = create_file(tmp_path, input_jsonl_1, filename="input1.jsonl")
input_jsonl_file_2 = create_file(tmp_path, input_jsonl_2, filename="input2.jsonl")
# Act
# Process Each Entry from All Notes Files
input_jsons = JsonlToJsonl.extract_jsonl_entries([input_jsonl_file_1, input_jsonl_file_2])
entries = list(map(Entry.from_dict, input_jsons))
output_jsonl = JsonlToJsonl.convert_entries_to_jsonl(entries)
# Assert
assert len(entries) == 2
assert output_jsonl == f"{input_jsonl_1}\n{input_jsonl_2}\n"
def test_get_jsonl_files(tmp_path):
"Ensure JSONL files specified via input-filter, input-files extracted"
# Arrange
# Include via input-filter globs
group1_file1 = create_file(tmp_path, filename="group1-file1.jsonl")
group1_file2 = create_file(tmp_path, filename="group1-file2.jsonl")
group2_file1 = create_file(tmp_path, filename="group2-file1.jsonl")
group2_file2 = create_file(tmp_path, filename="group2-file2.jsonl")
# Include via input-file field
file1 = create_file(tmp_path, filename="notes.jsonl")
# Not included by any filter
create_file(tmp_path, filename="not-included-jsonl.jsonl")
create_file(tmp_path, filename="not-included-text.txt")
expected_files = sorted(map(str, [group1_file1, group1_file2, group2_file1, group2_file2, file1]))
# Setup input-files, input-filters
input_files = [tmp_path / "notes.jsonl"]
input_filter = [tmp_path / "group1*.jsonl", tmp_path / "group2*.jsonl"]
# Act
extracted_org_files = JsonlToJsonl.get_jsonl_files(input_files, input_filter)
# Assert
assert len(extracted_org_files) == 5
assert extracted_org_files == expected_files
# Helper Functions
def create_file(tmp_path, entry=None, filename="test.jsonl"):
jsonl_file = tmp_path / filename
jsonl_file.touch()
if entry:
jsonl_file.write_text(entry)
return jsonl_file

View File

@@ -4,7 +4,7 @@ from pathlib import Path
import os
# Internal Packages
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.markdown.markdown_to_entries import MarkdownToEntries
from khoj.utils.fs_syncer import get_markdown_files
from khoj.utils.rawconfig import TextContentConfig
@@ -23,11 +23,11 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entry_nodes, file_to_entries = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
entry_nodes, file_to_entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
MarkdownToEntries.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@@ -52,11 +52,11 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entries, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
entries, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(
MarkdownToJsonl.convert_markdown_entries_to_maps(entries, entry_to_file_map)
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
MarkdownToEntries.convert_markdown_entries_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@@ -81,11 +81,11 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Markdown files
entry_strings, entry_to_file_map = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
entries = MarkdownToJsonl.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
entry_strings, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
entries = MarkdownToEntries.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
# Process Each Entry from All Notes Files
jsonl_string = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -144,7 +144,7 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Act
# Extract Entries from specified Markdown files
entries, _ = MarkdownToJsonl.extract_markdown_entries(markdown_files=data)
entries, _ = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
# Assert
assert len(entries) == 2

View File

@@ -0,0 +1,111 @@
# Standard Modules
from io import BytesIO
from PIL import Image
from urllib.parse import quote
import pytest
# External Packages
from fastapi.testclient import TestClient
from fastapi import FastAPI, UploadFile
from io import BytesIO
import pytest
# Internal Packages
from khoj.configure import configure_routes, configure_search_types
from khoj.utils import state
from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_entries import OrgToEntries
from database.models import KhojUser, KhojApiUser
from database.adapters import EntryAdapters
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_search_for_user2_returns_empty(client, api_user2: KhojApiUser):
token = api_user2.token
headers = {"Authorization": f"Bearer {token}"}
for content_type in ["all", "org", "markdown", "pdf", "github", "notion", "plaintext"]:
# Act
response = client.get(f"/api/search?q=random&t={content_type}", headers=headers)
# Assert
assert response.text == "[]"
assert response.status_code == 200, f"Returned status: {response.status_code} for content type: {content_type}"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_with_user2(client, api_user2: KhojApiUser):
# Arrange
files = get_sample_files_data()
source_file_symbol = set([f[1][0] for f in files])
headers = {"Authorization": f"Bearer {api_user2.token}"}
update_response = client.post("/api/v1/index/update", files=files, headers=headers)
search_response = client.get("/api/search?q=hardware&t=all", headers=headers)
results = search_response.json()
# Assert
assert update_response.status_code == 200
assert len(results) == 5
for result in results:
assert result["additional"]["file"] in source_file_symbol
@pytest.mark.django_db(transaction=True)
def test_index_update_with_user2_inaccessible_user1(client, api_user2: KhojApiUser, api_user: KhojApiUser):
# Arrange
files = get_sample_files_data()
source_file_symbol = set([f[1][0] for f in files])
headers = {"Authorization": f"Bearer {api_user2.token}"}
update_response = client.post("/api/v1/index/update", files=files, headers=headers)
# Act
headers = {"Authorization": f"Bearer {api_user.token}"}
search_response = client.get("/api/search?q=hardware&t=all", headers=headers)
results = search_response.json()
# Assert
assert update_response.status_code == 200
assert len(results) == 4
for result in results:
assert result["additional"]["file"] not in source_file_symbol
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_different_user_data_not_accessed(client, sample_org_data, default_user: KhojUser):
# Arrange
headers = {"Authorization": "Bearer kk-token"} # Token for default_user2
text_search.setup(OrgToEntries, sample_org_data, regenerate=False, user=default_user)
user_query = quote("How to git install application?")
# Act
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
# Assert
assert response.status_code == 403
# assert actual response has no data as the default_user is different from the user making the query (anonymous)
assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
def get_sample_files_data():
return [
("files", ("path/to/filename.org", "* practicing piano", "text/org")),
("files", ("path/to/filename1.org", "** top 3 reasons why I moved to SF", "text/org")),
("files", ("path/to/filename2.org", "* how to build a search engine", "text/org")),
("files", ("path/to/filename.pdf", "Moore's law does not apply to consumer hardware", "application/pdf")),
("files", ("path/to/filename1.pdf", "The sun is a ball of helium", "application/pdf")),
("files", ("path/to/filename2.pdf", "Effect of sunshine on baseline human happiness", "application/pdf")),
("files", ("path/to/filename.txt", "data,column,value", "text/plain")),
("files", ("path/to/filename1.txt", "<html>my first web page</html>", "text/plain")),
("files", ("path/to/filename2.txt", "2021-02-02 Journal Entry", "text/plain")),
("files", ("path/to/filename.md", "# Notes from client call", "text/markdown")),
(
"files",
("path/to/filename1.md", "## Studying anthropological records from the Fatimid caliphate", "text/markdown"),
),
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
]
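
The ("files", (name, content, mime)) tuples above are standard multipart form fields, so the same payload shape works from any HTTP client hitting the indexing endpoint. A minimal sketch, assuming the requests package, a placeholder local server address, and a placeholder API token:

import requests

# Placeholder host and token; the files list mirrors the tuples returned by get_sample_files_data()
files = [("files", ("path/to/filename.org", "* practicing piano", "text/org"))]
headers = {"Authorization": "Bearer <khoj-api-token>"}
response = requests.post("http://localhost:42110/api/v1/index/update", files=files, headers=headers)
print(response.status_code)  # 200 expected on a successful index update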

View File

@@ -9,8 +9,8 @@ from khoj.processor.conversation import prompts
# Internal Packages
from khoj.processor.conversation.utils import message_to_log
from khoj.utils import state
from tests.helpers import ConversationFactory
from database.models import KhojUser
# Initialize variables for tests
api_key = os.getenv("OPENAI_API_KEY")
@@ -23,7 +23,7 @@ if api_key is None:
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
def populate_chat_history(message_list, user=None):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, gpt_message, context in message_list:
@@ -33,13 +33,14 @@ def populate_chat_history(message_list):
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
)
# Update Conversation Metadata Logs in Application State
state.processor_config.conversation.meta_log = conversation_log
# Update Conversation Metadata Logs in Database
ConversationFactory(user=user, conversation_log=conversation_log)
# Tests
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
@@ -54,14 +55,15 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_from_chat_history(chat_client):
def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -76,8 +78,9 @@ def test_answer_from_chat_history(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_from_currently_retrieved_content(chat_client):
def test_answer_from_currently_retrieved_content(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@@ -87,7 +90,7 @@ def test_answer_from_currently_retrieved_content(chat_client):
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
@@ -99,8 +102,9 @@ def test_answer_from_currently_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_no_background, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@@ -110,10 +114,10 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8")
# Assert
@@ -125,14 +129,15 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
@@ -148,15 +153,16 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
@@ -171,12 +177,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_using_general_command(chat_client):
def test_answer_using_general_command(chat_client, default_user2: KhojUser):
# Arrange
query = urllib.parse.quote("/general Where was Xi Li born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
@@ -188,12 +195,13 @@ def test_answer_using_general_command(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_from_retrieved_content_using_notes_command(chat_client):
def test_answer_from_retrieved_content_using_notes_command(chat_client, default_user2: KhojUser):
# Arrange
query = urllib.parse.quote("/notes Where was Xi Li born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
@@ -205,24 +213,26 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_not_known_using_notes_command(chat_client):
def test_answer_not_known_using_notes_command(chat_client_no_background, default_user2: KhojUser):
# Arrange
query = urllib.parse.quote("/notes Where was Testatron born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
assert response_message == prompts.no_notes_found.format()
assert response_message == prompts.no_entries_found.format()
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
@freeze_time("2023-04-01")
def test_answer_requires_current_date_awareness(chat_client):
@@ -240,11 +250,13 @@ def test_answer_requires_current_date_awareness(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
@freeze_time("2023-04-01")
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
response_message = response.content.decode("utf-8")
@@ -254,15 +266,16 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client):
def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(
@@ -280,10 +293,12 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
# Act
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@@ -292,6 +307,8 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
"which one is",
"which of namita's sons",
"the birth order",
"provide more context",
"provide me with more context",
]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
@@ -301,15 +318,16 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_in_chat_history_beyond_lookback_window(chat_client):
def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
@@ -324,6 +342,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information"
@@ -340,10 +359,12 @@ def test_answer_requires_multiple_independent_searches(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_answer_using_file_filter(chat_client):
"Chat should be able to use search filters in the query"
# Act
query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
response = chat_client.get(f"/api/chat?q={query}&stream=true")
response_message = response.content.decode("utf-8")

View File

@@ -3,8 +3,8 @@ import json
import os
# Internal Packages
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextToJsonl
from khoj.processor.org_mode.org_to_entries import OrgToEntries
from khoj.processor.text_to_entries import TextToEntries
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import Entry
from khoj.utils.fs_syncer import get_org_files
@@ -29,9 +29,9 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
for index_heading_entries in [True, False]:
# Act
# Extract entries into jsonl from specified Org files
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
OrgToJsonl.convert_org_nodes_to_entries(
*OrgToJsonl.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
OrgToEntries.convert_org_nodes_to_entries(
*OrgToEntries.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@@ -59,12 +59,12 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
# Act
# Extract Entries from specified Org files
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=data)
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
# Split each entry from specified Org files by max words
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
TextToJsonl.split_entries_by_max_tokens(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
TextToEntries.split_entries_by_max_tokens(
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():
# Act
# Split entry by max words and drop words larger than max word length
processed_entry = TextToJsonl.split_entries_by_max_tokens([entry], max_word_length=5)[0]
processed_entry = TextToEntries.split_entries_by_max_tokens([entry], max_word_length=5)[0]
# Assert
# "Heading" dropped from compiled version because its over the set max word limit
@@ -109,11 +109,11 @@ def test_entry_with_body_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entries, entry_to_file_map = OrgToJsonl.extract_org_entries(org_files=data)
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map)
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@@ -136,11 +136,11 @@ Intro text
# Act
# Extract Entries from specified Org files
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=data)
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(entries)
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -160,11 +160,11 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entry_nodes, file_to_entries = OrgToJsonl.extract_org_entries(org_files=data)
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
# Process Each Entry from All Notes Files
entries = OrgToJsonl.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(entries)
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -224,7 +224,7 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Act
# Extract Entries from specified Org files
entries, _ = OrgToJsonl.extract_org_entries(org_files=data)
entries, _ = OrgToEntries.extract_org_entries(org_files=data)
# Assert
assert len(entries) == 2
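
Outside of pytest, the three OrgToEntries calls exercised in these tests chain together into a small indexing pipeline. A minimal sketch with a hypothetical in-memory org file:

from khoj.processor.org_mode.org_to_entries import OrgToEntries

# Hypothetical org file content keyed by file path, matching the org_files dict shape used above
data = {"notes.org": "* Heading\n- List item\n"}
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
print(jsonl_string.splitlines())  # one JSON line per extracted entry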

View File

@@ -3,7 +3,7 @@ import json
import os
# Internal Packages
from khoj.processor.pdf.pdf_to_jsonl import PdfToJsonl
from khoj.processor.pdf.pdf_to_entries import PdfToEntries
from khoj.utils.fs_syncer import get_pdf_files
from khoj.utils.rawconfig import TextContentConfig
@@ -18,11 +18,11 @@ def test_single_page_pdf_to_jsonl():
pdf_bytes = f.read()
data = {"tests/data/pdf/singlepage.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=data)
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
jsonl_string = PdfToJsonl.convert_pdf_maps_to_jsonl(
PdfToJsonl.convert_pdf_entries_to_maps(entries, entry_to_file_map)
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@@ -38,11 +38,11 @@ def test_multi_page_pdf_to_jsonl():
pdf_bytes = f.read()
data = {"tests/data/pdf/multipage.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToJsonl.extract_pdf_entries(pdf_files=data)
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
jsonl_string = PdfToJsonl.convert_pdf_maps_to_jsonl(
PdfToJsonl.convert_pdf_entries_to_maps(entries, entry_to_file_map)
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
@@ -50,6 +50,23 @@ def test_multi_page_pdf_to_jsonl():
assert len(jsonl_data) == 6
def test_ocr_page_pdf_to_jsonl():
"Convert multiple pages from single PDF file to jsonl."
# Act
# Extract Entries from specified Pdf files
with open("tests/data/pdf/ocr_samples.pdf", "rb") as f:
pdf_bytes = f.read()
data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes}
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
# Process Each Entry from All Pdf Files
entries = PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
assert len(entries) == 1
assert "playing on a strip of marsh" in entries[0].raw
def test_get_pdf_files(tmp_path):
"Ensure Pdf files specified via input-filter, input-files extracted"
# Arrange

View File

@@ -6,7 +6,8 @@ from pathlib import Path
# Internal Packages
from khoj.utils.fs_syncer import get_plaintext_files
from khoj.utils.rawconfig import TextContentConfig
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
from database.models import LocalPlaintextConfig, KhojUser
def test_plaintext_file(tmp_path):
@@ -26,14 +27,14 @@ def test_plaintext_file(tmp_path):
f"{plaintextfile}": entry,
}
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(entry_to_file_map=data)
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(entry_to_file_map=data)
# Convert each entry.file to absolute path to make them JSON serializable
for map in maps:
map.file = str(Path(map.file).absolute())
# Process Each Entry from All Notes Files
jsonl_string = PlaintextToJsonl.convert_entries_to_jsonl(maps)
jsonl_string = PlaintextToEntries.convert_entries_to_jsonl(maps)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert
@@ -91,14 +92,15 @@ def test_get_plaintext_files(tmp_path):
assert set(extracted_plaintext_files.keys()) == set(expected_files)
def test_parse_html_plaintext_file(content_config):
def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
"Ensure HTML files are parsed correctly"
# Arrange
# Setup input-files, input-filters
extracted_plaintext_files = get_plaintext_files(content_config.plaintext)
config = LocalPlaintextConfig.objects.filter(user=default_user).first()
extracted_plaintext_files = get_plaintext_files(config=config)
# Act
maps = PlaintextToJsonl.convert_plaintext_entries_to_maps(extracted_plaintext_files)
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(extracted_plaintext_files)
# Assert
assert len(maps) == 1
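
As with the org pipeline, the plaintext converter takes a dict mapping file path to raw text. A minimal sketch reusing the convert_plaintext_entries_to_maps call from these tests, with a hypothetical file path and content:

from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries

# Hypothetical path and content; mirrors the entry_to_file_map dict built in the test above
entry_to_file_map = {"/tmp/notes.txt": "Hello, this is a plaintext note."}
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(entry_to_file_map=entry_to_file_map)
print(len(maps))  # one entry map per input file in this simple case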

View File

@@ -1,25 +1,27 @@
# System Packages
import logging
import locale
from pathlib import Path
import os
import asyncio
# External Packages
import pytest
# Internal Packages
from khoj.utils.state import content_index, search_models
from khoj.search_type import text_search
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.utils.config import SearchModels
from khoj.utils.fs_syncer import get_org_files
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_entries import OrgToEntries
from khoj.processor.github.github_to_entries import GithubToEntries
from khoj.utils.fs_syncer import collect_files, get_org_files
from database.models import LocalOrgConfig, KhojUser, Entry, GithubConfig
logger = logging.getLogger(__name__)
# Test
# ----------------------------------------------------------------------------------------------------
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: TextContentConfig):
@pytest.mark.django_db
def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_new_file: LocalOrgConfig):
# Arrange
# Ensure file mentioned in org.input-files is missing
single_new_file = Path(org_config_with_only_new_file.input_files[0])
@@ -32,98 +34,148 @@ def test_text_search_setup_with_missing_file_raises_error(org_config_with_only_n
# ----------------------------------------------------------------------------------------------------
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path: Path):
@pytest.mark.django_db
def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, default_user: KhojUser):
# Arrange
orgfile = tmp_path / "directory.org" / "file.org"
orgfile.parent.mkdir()
with open(orgfile, "w") as f:
f.write("* Heading\n- List item\n")
org_content_config = TextContentConfig(
input_filter=[f"{tmp_path}/**/*"], compressed_jsonl="test.jsonl", embeddings_file="test.pt"
LocalOrgConfig.objects.create(
input_filter=[f"{tmp_path}/**/*"],
input_files=None,
user=default_user,
)
# Act
# should not raise IsADirectoryError and return orgfile
assert get_org_files(org_content_config) == {f"{orgfile}": "* Heading\n- List item\n"}
org_files = collect_files(user=default_user)["org"]
# Assert
# should return orgfile and not raise IsADirectoryError
assert org_files == {f"{orgfile}": "* Heading\n- List item\n"}
# ----------------------------------------------------------------------------------------------------
def test_text_search_setup_with_empty_file_raises_error(
org_config_with_only_new_file: TextContentConfig, search_config: SearchConfig
@pytest.mark.django_db
def test_text_search_setup_with_empty_file_creates_no_entries(
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
data = get_org_files(org_config_with_only_new_file)
# Act
# Generate notes embeddings during asymmetric setup
with pytest.raises(ValueError, match=r"^No valid entries found*"):
text_search.setup(OrgToJsonl, data, org_config_with_only_new_file, search_config.asymmetric, regenerate=True)
# ----------------------------------------------------------------------------------------------------
def test_text_search_setup(content_config: ContentConfig, search_models: SearchModels):
# Arrange
data = get_org_files(content_config.org)
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert len(notes_model.entries) == 10
assert len(notes_model.corpus_embeddings) == 10
assert "Deleted 3 entries. Created 0 new entries for user " in caplog.records[-1].message
verify_embeddings(0, default_user)
# ----------------------------------------------------------------------------------------------------
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, search_models: SearchModels, caplog):
@pytest.mark.django_db
def test_text_indexer_deletes_embedding_before_regenerate(
content_config: ContentConfig, default_user: KhojUser, caplog
):
# Arrange
caplog.set_level(logging.INFO, logger="khoj")
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
data = get_org_files(content_config.org)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert "Deleting all entries for file type org" in caplog.text
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act
# Generate notes embeddings during asymmetric setup
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
# Assert
assert "Deleted 3 entries. Created 10 new entries for user " in caplog.records[-1].message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
# Arrange
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
# Act
# Generate initial notes embeddings during asymmetric setup
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True)
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# Run asymmetric setup again with no changes to data source. Ensure index is not updated
text_search.setup(OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False)
with caplog.at_level(logging.DEBUG):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
assert "Creating index from scratch." in initial_logs
assert "Creating index from scratch." not in final_logs
assert "Deleting all entries for file type org" in initial_logs
assert "Deleting all entries for file type org" not in final_logs
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
@pytest.mark.anyio
async def test_text_search(content_config: ContentConfig, search_config: SearchConfig):
# @pytest.mark.asyncio
async def test_text_search(search_config: SearchConfig):
# Arrange
data = get_org_files(content_config.org)
search_models.text_search = text_search.initialize_model(search_config.asymmetric)
content_index.org = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
default_user = await KhojUser.objects.acreate(
username="test_user", password="test_password", email="test@example.com"
)
org_config = await LocalOrgConfig.objects.acreate(
input_files=None,
input_filter=["tests/data/org/*.org"],
index_heading_entries=False,
user=default_user,
)
data = get_org_files(org_config)
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
text_search.setup,
OrgToEntries,
data,
True,
True,
default_user,
)
query = "How to git install application?"
# Act
hits, entries = await text_search.query(
query, search_model=search_models.text_search, content=content_index.org, rank_results=True
)
results = text_search.collate_results(hits, entries, count=1)
hits = await text_search.query(default_user, query)
results = text_search.collate_results(hits)
results = sorted(results, key=lambda x: float(x.score))[:1]
# Assert
# search results should contain "git clone" entry
search_result = results[0].entry
assert "git clone" in search_result
assert "git clone" in search_result, 'search result did not contain "git clone" entry'
# ----------------------------------------------------------------------------------------------------
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
@pytest.mark.django_db
def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
max_tokens = 256
@@ -137,47 +189,45 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: TextContent
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
# Assert
# verify newly added org-mode entry is split by max tokens
assert len(initial_notes_model.entries) == 2
assert len(initial_notes_model.corpus_embeddings) == 2
assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"
# ----------------------------------------------------------------------------------------------------
# @pytest.mark.skip(reason="Flaky due to compressed_jsonl file being rewritten by other tests")
@pytest.mark.django_db
def test_entry_chunking_by_max_tokens_not_full_corpus(
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
# Insert org-mode entry with size exceeding max token limit to new org file
data = {
"readme.org": """
* Khoj
/Allow natural language search on user content like notes, images using transformer based models/
/Allow natural language search on user content like notes, images using transformer based models/
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
All data is processed locally. User can interface with khoj app via [[./interface/emacs/khoj.el][Emacs]], API or Commandline
** Dependencies
- Python3
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
- Python3
- [[https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links][Miniconda]]
** Install
#+begin_src shell
git clone https://github.com/khoj-ai/khoj && cd khoj
conda env create -f environment.yml
conda activate khoj
#+end_src"""
#+begin_src shell
git clone https://github.com/khoj-ai/khoj && cd khoj
conda env create -f environment.yml
conda activate khoj
#+end_src"""
}
text_search.setup(
OrgToJsonl,
OrgToEntries,
data,
org_config_with_only_new_file,
search_models.text_search.bi_encoder,
regenerate=False,
user=default_user,
)
max_tokens = 256
@@ -191,64 +241,58 @@ def test_entry_chunking_by_max_tokens_not_full_corpus(
# Act
# reload embeddings, entries, notes model after adding new org-mode file
initial_notes_model = text_search.setup(
OrgToJsonl,
data,
org_config_with_only_new_file,
search_models.text_search.bi_encoder,
regenerate=False,
full_corpus=False,
)
with caplog.at_level(logging.INFO):
text_search.setup(
OrgToEntries,
data,
regenerate=False,
full_corpus=False,
user=default_user,
)
# Assert
# verify newly added org-mode entry is split by max tokens
assert len(initial_notes_model.entries) == 5
assert len(initial_notes_model.corpus_embeddings) == 5
assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_regenerate_index_with_new_entry(
content_config: ContentConfig, search_models: SearchModels, new_org_file: Path
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
):
# Arrange
data = get_org_files(content_config.org)
initial_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
initial_data = get_org_files(org_config)
# append org-mode entry to first org input file in config
content_config.org.input_files = [f"{new_org_file}"]
org_config.input_files = [f"{new_org_file}"]
with open(new_org_file, "w") as f:
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
data = get_org_files(content_config.org)
final_data = get_org_files(org_config)
# Act
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
final_logs = caplog.text
# Assert
assert len(regenerated_notes_model.entries) == 11
assert len(regenerated_notes_model.corpus_embeddings) == 11
# verify new entry appended to index, without disrupting order or content of existing entries
error_details = compare_index(initial_notes_model, regenerated_notes_model)
if error_details:
pytest.fail(error_details, False)
# Cleanup
# reset input_files in config to empty list
content_config.org.input_files = []
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
assert "Deleted 10 entries. Created 11 new entries for user " in final_logs
verify_embeddings(11, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db
def test_update_index_with_duplicate_entries_in_stable_order(
org_config_with_only_new_file: TextContentConfig, search_models: SearchModels
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
):
# Arrange
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@@ -261,31 +305,30 @@ def test_update_index_with_duplicate_entries_in_stable_order(
data = get_org_files(org_config_with_only_new_file)
# Act
# load embeddings, entries, notes model after adding new org-mode file
initial_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
# generate embeddings, entries, notes model from scratch after adding new org-mode file
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model after adding new org-mode file
updated_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# update embeddings, entries, notes model with no new changes
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert len(initial_index.entries) == len(updated_index.entries) == 1
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) == 1
assert "Deleted 3 entries. Created 1 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
# verify the same entry is added even when there are multiple duplicate entries
error_details = compare_index(initial_index, updated_index)
if error_details:
pytest.fail(error_details)
verify_embeddings(1, default_user)
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextContentConfig, search_models: SearchModels):
@pytest.mark.django_db
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
# Arrange
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
@@ -293,101 +336,84 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: TextCont
new_entry = "* TODO A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}{new_entry} -- Tatooine")
data = get_org_files(org_config_with_only_new_file)
# load embeddings, entries, notes model after adding new org file with 2 entries
initial_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=True
)
initial_data = get_org_files(org_config_with_only_new_file)
# update embeddings, entries, notes model after removing an entry from the org file
with open(new_file_to_index, "w") as f:
f.write(f"{new_entry}")
data = get_org_files(org_config_with_only_new_file)
final_data = get_org_files(org_config_with_only_new_file)
# Act
updated_index = text_search.setup(
OrgToJsonl, data, org_config_with_only_new_file, search_models.text_search.bi_encoder, regenerate=False
)
# load embeddings, entries, notes model after adding new org file with 2 entries
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
# verify only 1 entry added even if there are multiple duplicate entries
assert len(initial_index.entries) == len(updated_index.entries) + 1
assert len(initial_index.corpus_embeddings) == len(updated_index.corpus_embeddings) + 1
assert "Deleted 3 entries. Created 2 new entries for user " in initial_logs
assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
# verify the same entry is added even when there are multiple duplicate entries
error_details = compare_index(updated_index, initial_index)
if error_details:
pytest.fail(error_details)
verify_embeddings(1, default_user)
# ----------------------------------------------------------------------------------------------------
def test_update_index_with_new_entry(content_config: ContentConfig, search_models: SearchModels, new_org_file: Path):
@pytest.mark.django_db
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
# Arrange
data = get_org_files(content_config.org)
initial_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=True, normalize=False
)
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
data = get_org_files(org_config)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
initial_logs = caplog.text
caplog.clear() # Clear logs
# append org-mode entry to first org input file in config
with open(new_org_file, "w") as f:
new_entry = "\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n"
f.write(new_entry)
data = get_org_files(content_config.org)
data = get_org_files(org_config)
# Act
# update embeddings, entries with the newly added note
content_config.org.input_files = [f"{new_org_file}"]
final_notes_model = text_search.setup(
OrgToJsonl, data, content_config.org, search_models.text_search.bi_encoder, regenerate=False, normalize=False
)
with caplog.at_level(logging.INFO):
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
final_logs = caplog.text
# Assert
assert len(final_notes_model.entries) == len(initial_notes_model.entries) + 1
assert len(final_notes_model.corpus_embeddings) == len(initial_notes_model.corpus_embeddings) + 1
# verify new entry appended to index, without disrupting order or content of existing entries
error_details = compare_index(initial_notes_model, final_notes_model)
if error_details:
pytest.fail(error_details, False)
# Cleanup
# reset input_files in config to empty list
content_config.org.input_files = []
assert "Deleted 3 entries. Created 10 new entries for user " in initial_logs
assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
verify_embeddings(11, default_user)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("GITHUB_PAT_TOKEN") is None, reason="GITHUB_PAT_TOKEN not set")
def test_text_search_setup_github(content_config: ContentConfig, search_models: SearchModels):
def test_text_search_setup_github(content_config: ContentConfig, default_user: KhojUser):
# Arrange
github_config = GithubConfig.objects.filter(user=default_user).first()
# Act
# Regenerate github embeddings to test asymmetric setup without caching
github_model = text_search.setup(
GithubToJsonl, content_config.github, search_models.text_search.bi_encoder, regenerate=True
text_search.setup(
GithubToEntries,
{},
regenerate=True,
user=default_user,
config=github_config,
)
# Assert
assert len(github_model.entries) > 1
embeddings = Entry.objects.filter(user=default_user, file_type="github").count()
assert embeddings > 1
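# Helper to compare two in-memory note models entry-by-entry and embedding-by-embedding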
def compare_index(initial_notes_model, final_notes_model):
mismatched_entries, mismatched_embeddings = [], []
for index in range(len(initial_notes_model.entries)):
if initial_notes_model.entries[index].to_json() != final_notes_model.entries[index].to_json():
mismatched_entries.append(index)
# verify embeddings at each index match between the initial and final models
for index in range(len(initial_notes_model.corpus_embeddings)):
if not initial_notes_model.corpus_embeddings[index].allclose(final_notes_model.corpus_embeddings[index]):
mismatched_embeddings.append(index)
error_details = ""
if mismatched_entries:
mismatched_entries_str = ",".join(map(str, mismatched_entries))
error_details += f"Entries at {mismatched_entries_str} not equal\n"
if mismatched_embeddings:
mismatched_embeddings_str = ", ".join(map(str, mismatched_embeddings))
error_details += f"Embeddings at {mismatched_embeddings_str} not equal\n"
return error_details
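# Helper to assert the number of org entries indexed for the given user matches the expected count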
def verify_embeddings(expected_count, user):
embeddings = Entry.objects.filter(user=user, file_type="org").count()
assert embeddings == expected_count

View File

@@ -2,6 +2,7 @@
from khoj.search_filter.word_filter import WordFilter
from khoj.utils.rawconfig import Entry
# Test
# ----------------------------------------------------------------------------------------------------
def test_no_word_filter():
@@ -21,54 +22,44 @@ def test_no_word_filter():
# ----------------------------------------------------------------------------------------------------
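# The word filter strips +"word"/-"word" terms from the query and keeps only entries containing every included word and none of the excluded words, as the assertions below show.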
def test_word_exclude_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
q_with_exclude_filter = 'head -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(q_with_exclude_filter)
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {0, 2}
# ----------------------------------------------------------------------------------------------------
def test_word_include_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
query_with_include_filter = 'head +"include_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_filter)
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {2, 3}
# ----------------------------------------------------------------------------------------------------
def test_word_include_and_exclude_filter():
# Arrange
word_filter = WordFilter()
entries = arrange_content()
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
# Assert
assert can_filter == True
assert ret_query == "head tail"
assert entry_indices == {2}
# ----------------------------------------------------------------------------------------------------