mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-09 13:25:11 +00:00
Update Text Chunking Strategy to Improve Search Context (#645)
## Major - Parse markdown, org parent entries as single entry if fit within max tokens - Parse a file as single entry if it fits with max token limits - Add parent heading ancestry to extracted markdown entries for context - Chunk text in preference order of para, sentence, word, character ## Minor - Create wrapper function to get entries from org, md, pdf & text files - Remove unused Entry to Jsonl converter from text to entry class, tests - Dedupe code by using single func to process an org file into entries Resolves #620
This commit is contained in:
@@ -306,7 +306,7 @@ def test_notes_search(client, search_config: SearchConfig, sample_org_data, defa
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
@@ -325,7 +325,7 @@ def test_notes_search_no_results(client, search_config: SearchConfig, sample_org
|
||||
user_query = quote("How to find my goat?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
@@ -409,7 +409,7 @@ def test_notes_search_requires_parent_context(
|
||||
user_query = quote("Install Khoj on Emacs")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.18", headers=headers)
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org&r=true&max_distance=0.22", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -7,8 +6,8 @@ from khoj.utils.fs_syncer import get_markdown_files
|
||||
from khoj.utils.rawconfig import TextContentConfig
|
||||
|
||||
|
||||
def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
||||
"Convert files with no heading to jsonl."
|
||||
def test_extract_markdown_with_no_headings(tmp_path):
|
||||
"Convert markdown file with no heading to entry format."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
- Bullet point 1
|
||||
@@ -17,30 +16,24 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
expected_heading = f"# {tmp_path.stem}"
|
||||
expected_heading = f"# {tmp_path}"
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entry_nodes, file_to_entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
|
||||
MarkdownToEntries.convert_markdown_entries_to_maps(entry_nodes, file_to_entries)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
# Ensure raw entry with no headings do not get heading prefix prepended
|
||||
assert not jsonl_data[0]["raw"].startswith("#")
|
||||
assert not entries[0].raw.startswith("#")
|
||||
# Ensure compiled entry has filename prepended as top level heading
|
||||
assert expected_heading in jsonl_data[0]["compiled"]
|
||||
assert entries[0].compiled.startswith(expected_heading)
|
||||
# Ensure compiled entry also includes the file name
|
||||
assert str(tmp_path) in jsonl_data[0]["compiled"]
|
||||
assert str(tmp_path) in entries[0].compiled
|
||||
|
||||
|
||||
def test_single_markdown_entry_to_jsonl(tmp_path):
|
||||
"Convert markdown entry from single file to jsonl."
|
||||
def test_extract_single_markdown_entry(tmp_path):
|
||||
"Convert markdown from single file to entry format."
|
||||
# Arrange
|
||||
entry = f"""### Heading
|
||||
\t\r
|
||||
@@ -52,20 +45,14 @@ def test_single_markdown_entry_to_jsonl(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(
|
||||
MarkdownToEntries.convert_markdown_entries_to_maps(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
|
||||
|
||||
def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
||||
"Convert multiple markdown entries from single file to jsonl."
|
||||
def test_extract_multiple_markdown_entries(tmp_path):
|
||||
"Convert multiple markdown from single file to entry format."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
### Heading 1
|
||||
@@ -81,19 +68,139 @@ def test_multiple_markdown_entries_to_jsonl(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entry_strings, entry_to_file_map = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
entries = MarkdownToEntries.convert_markdown_entries_to_maps(entry_strings, entry_to_file_map)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = MarkdownToEntries.convert_markdown_maps_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 2
|
||||
assert len(entries) == 2
|
||||
# Ensure entry compiled strings include the markdown files they originate from
|
||||
assert all([tmp_path.stem in entry.compiled for entry in entries])
|
||||
|
||||
|
||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
"Extract markdown entries with different level headings."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
# Heading 1
|
||||
## Sub-Heading 1.1
|
||||
# Heading 2
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert entries[0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
|
||||
assert entries[1].raw == "# Heading 2\n"
|
||||
|
||||
|
||||
def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
|
||||
"Extract markdown entries when deeper child level before shallower child level."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
# Heading 1
|
||||
#### Sub-Heading 1.1
|
||||
## Sub-Heading 1.2
|
||||
# Heading 2
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 3
|
||||
assert entries[0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
|
||||
assert entries[1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
|
||||
assert entries[2].raw == "# Heading 2\n"
|
||||
|
||||
|
||||
def test_extract_entries_with_text_before_headings(tmp_path):
|
||||
"Extract markdown entries with some text before any headings."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
Text before headings
|
||||
# Heading 1
|
||||
body line 1
|
||||
## Heading 2
|
||||
body line 2
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 3
|
||||
assert entries[0].raw == "\nText before headings"
|
||||
assert entries[1].raw == "# Heading 1\nbody line 1"
|
||||
assert entries[2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", "Ensure raw entry includes heading ancestory"
|
||||
|
||||
|
||||
def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
|
||||
"Parse markdown file into single entry if it fits within the token limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
# Heading 1
|
||||
body line 1
|
||||
## Subheading 1.1
|
||||
body line 1.1
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 1
|
||||
assert entries[0].raw == entry
|
||||
|
||||
|
||||
def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path):
|
||||
"Parse markdown entry with child headings as single entry if it fits within the tokens limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
# Heading 1
|
||||
body line 1
|
||||
## Subheading 1.1
|
||||
body line 1.1
|
||||
# Heading 2
|
||||
body line 2
|
||||
## Subheading 2.1
|
||||
longer body line 2.1
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 3
|
||||
assert (
|
||||
entries[0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
|
||||
), "First entry includes children headings"
|
||||
assert entries[1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
|
||||
assert (
|
||||
entries[2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
|
||||
), "Third entry is second entries child heading"
|
||||
|
||||
|
||||
def test_get_markdown_files(tmp_path):
|
||||
"Ensure Markdown files specified via input-filter, input-files extracted"
|
||||
# Arrange
|
||||
@@ -131,27 +238,6 @@ def test_get_markdown_files(tmp_path):
|
||||
assert set(extracted_org_files.keys()) == expected_files
|
||||
|
||||
|
||||
def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
"Extract markdown entries with different level headings."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
# Heading 1
|
||||
## Heading 2
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Markdown files
|
||||
entries, _ = MarkdownToEntries.extract_markdown_entries(markdown_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert entries[0] == "# Heading 1"
|
||||
assert entries[1] == "## Heading 2"
|
||||
|
||||
|
||||
# Helper Functions
|
||||
def create_file(tmp_path: Path, entry=None, filename="test.md"):
|
||||
markdown_file = tmp_path / filename
|
||||
|
||||
@@ -56,7 +56,7 @@ def test_index_update_with_user2_inaccessible_user1(client, api_user2: KhojApiUs
|
||||
|
||||
# Assert
|
||||
assert update_response.status_code == 200
|
||||
assert len(results) == 5
|
||||
assert len(results) == 3
|
||||
for result in results:
|
||||
assert result["additional"]["file"] not in source_file_symbol
|
||||
|
||||
|
||||
@@ -470,10 +470,6 @@ async def test_websearch_with_operators(chat_client):
|
||||
["site:reddit.com" in response for response in responses]
|
||||
), "Expected a search query to include site:reddit.com but got: " + str(responses)
|
||||
|
||||
assert any(
|
||||
["after:2024/04/01" in response for response in responses]
|
||||
), "Expected a search query to include after:2024/04/01 but got: " + str(responses)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
@@ -8,7 +8,7 @@ from khoj.utils.helpers import is_none_or_empty
|
||||
from khoj.utils.rawconfig import Entry, TextContentConfig
|
||||
|
||||
|
||||
def test_configure_heading_entry_to_jsonl(tmp_path):
|
||||
def test_configure_indexing_heading_only_entries(tmp_path):
|
||||
"""Ensure entries with empty body are ignored, unless explicitly configured to index heading entries.
|
||||
Property drawers not considered Body. Ignore control characters for evaluating if Body empty."""
|
||||
# Arrange
|
||||
@@ -26,24 +26,21 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
|
||||
for index_heading_entries in [True, False]:
|
||||
# Act
|
||||
# Extract entries into jsonl from specified Org files
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
|
||||
OrgToEntries.convert_org_nodes_to_entries(
|
||||
*OrgToEntries.extract_org_entries(org_files=data), index_heading_entries=index_heading_entries
|
||||
)
|
||||
entries = OrgToEntries.extract_org_entries(
|
||||
org_files=data, index_heading_entries=index_heading_entries, max_tokens=3
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
|
||||
# Assert
|
||||
if index_heading_entries:
|
||||
# Entry with empty body indexed when index_heading_entries set to True
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
else:
|
||||
# Entry with empty body ignored when index_heading_entries set to False
|
||||
assert is_none_or_empty(jsonl_data)
|
||||
assert is_none_or_empty(entries)
|
||||
|
||||
|
||||
def test_entry_split_when_exceeds_max_words():
|
||||
"Ensure entries with compiled words exceeding max_words are split."
|
||||
def test_entry_split_when_exceeds_max_tokens():
|
||||
"Ensure entries with compiled words exceeding max_tokens are split."
|
||||
# Arrange
|
||||
tmp_path = "/tmp/test.org"
|
||||
entry = f"""*** Heading
|
||||
@@ -57,29 +54,26 @@ def test_entry_split_when_exceeds_max_words():
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Split each entry from specified Org files by max words
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
|
||||
TextToEntries.split_entries_by_max_tokens(
|
||||
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
|
||||
)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
# Split each entry from specified Org files by max tokens
|
||||
entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=6)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 2
|
||||
# Ensure compiled entries split by max_words start with entry heading (for search context)
|
||||
assert all([entry["compiled"].startswith(expected_heading) for entry in jsonl_data])
|
||||
assert len(entries) == 2
|
||||
# Ensure compiled entries split by max tokens start with entry heading (for search context)
|
||||
assert all([entry.compiled.startswith(expected_heading) for entry in entries])
|
||||
|
||||
|
||||
def test_entry_split_drops_large_words():
|
||||
"Ensure entries drops words larger than specified max word length from compiled version."
|
||||
# Arrange
|
||||
entry_text = f"""*** Heading
|
||||
\t\r
|
||||
Body Line 1
|
||||
"""
|
||||
entry_text = f"""First Line
|
||||
dog=1\n\r\t
|
||||
cat=10
|
||||
car=4
|
||||
book=2
|
||||
"""
|
||||
entry = Entry(raw=entry_text, compiled=entry_text)
|
||||
|
||||
# Act
|
||||
@@ -87,11 +81,158 @@ def test_entry_split_drops_large_words():
|
||||
processed_entry = TextToEntries.split_entries_by_max_tokens([entry], max_word_length=5)[0]
|
||||
|
||||
# Assert
|
||||
# "Heading" dropped from compiled version because its over the set max word limit
|
||||
assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 1
|
||||
# Ensure words larger than max word length are dropped
|
||||
# Ensure newline characters are considered as word boundaries for splitting words. See #620
|
||||
words_to_keep = ["First", "Line", "dog=1", "car=4"]
|
||||
words_to_drop = ["cat=10", "book=2"]
|
||||
assert all([word for word in words_to_keep if word in processed_entry.compiled])
|
||||
assert not any([word for word in words_to_drop if word in processed_entry.compiled])
|
||||
assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 2
|
||||
|
||||
|
||||
def test_entry_with_body_to_jsonl(tmp_path):
|
||||
def test_parse_org_file_into_single_entry_if_small(tmp_path):
|
||||
"Parse org file into single entry if it fits within the token limits."
|
||||
# Arrange
|
||||
original_entry = f"""
|
||||
* Heading 1
|
||||
body line 1
|
||||
** Subheading 1.1
|
||||
body line 1.1
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": original_entry,
|
||||
}
|
||||
expected_entry = f"""
|
||||
* Heading 1
|
||||
body line 1
|
||||
|
||||
** Subheading 1.1
|
||||
body line 1.1
|
||||
|
||||
""".lstrip()
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
|
||||
for entry in extracted_entries:
|
||||
entry.raw = clean(entry.raw)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_entries) == 1
|
||||
assert entry.raw == expected_entry
|
||||
|
||||
|
||||
def test_parse_org_entry_with_children_as_single_entry_if_small(tmp_path):
|
||||
"Parse org entry with child headings as single entry only if it fits within the tokens limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
* Heading 1
|
||||
body line 1
|
||||
** Subheading 1.1
|
||||
body line 1.1
|
||||
* Heading 2
|
||||
body line 2
|
||||
** Subheading 2.1
|
||||
longer body line 2.1
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
first_expected_entry = f"""
|
||||
* Path: {tmp_path}
|
||||
** Heading 1.
|
||||
body line 1
|
||||
|
||||
*** Subheading 1.1.
|
||||
body line 1.1
|
||||
|
||||
""".lstrip()
|
||||
second_expected_entry = f"""
|
||||
* Path: {tmp_path}
|
||||
** Heading 2.
|
||||
body line 2
|
||||
|
||||
""".lstrip()
|
||||
third_expected_entry = f"""
|
||||
* Path: {tmp_path} / Heading 2
|
||||
** Subheading 2.1.
|
||||
longer body line 2.1
|
||||
|
||||
""".lstrip()
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=12)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_entries) == 3
|
||||
assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
|
||||
assert extracted_entries[1].compiled == second_expected_entry, "Second entry does not include children headings"
|
||||
assert extracted_entries[2].compiled == third_expected_entry, "Third entry is second entries child heading"
|
||||
|
||||
|
||||
def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path):
|
||||
"Parse org sibling entries as separate entries only if it fits within the tokens limits."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
* Heading 1
|
||||
body line 1
|
||||
** Subheading 1.1
|
||||
body line 1.1
|
||||
* Heading 2
|
||||
body line 2
|
||||
** Subheading 2.1
|
||||
body line 2.1
|
||||
* Heading 3
|
||||
body line 3
|
||||
** Subheading 3.1
|
||||
body line 3.1
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
}
|
||||
first_expected_entry = f"""
|
||||
* Path: {tmp_path}
|
||||
** Heading 1.
|
||||
body line 1
|
||||
|
||||
*** Subheading 1.1.
|
||||
body line 1.1
|
||||
|
||||
""".lstrip()
|
||||
second_expected_entry = f"""
|
||||
* Path: {tmp_path}
|
||||
** Heading 2.
|
||||
body line 2
|
||||
|
||||
*** Subheading 2.1.
|
||||
body line 2.1
|
||||
|
||||
""".lstrip()
|
||||
third_expected_entry = f"""
|
||||
* Path: {tmp_path}
|
||||
** Heading 3.
|
||||
body line 3
|
||||
|
||||
*** Subheading 3.1.
|
||||
body line 3.1
|
||||
|
||||
""".lstrip()
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
# Max tokens = 30 is in the middle of 2 entry (24 tokens) and 3 entry (36 tokens) tokens boundary
|
||||
# Where each sibling entry contains 12 tokens per sibling entry * 3 entries = 36 tokens
|
||||
extracted_entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=30)
|
||||
|
||||
# Assert
|
||||
assert len(extracted_entries) == 3
|
||||
assert extracted_entries[0].compiled == first_expected_entry, "First entry includes children headings"
|
||||
assert extracted_entries[1].compiled == second_expected_entry, "Second entry includes children headings"
|
||||
assert extracted_entries[2].compiled == third_expected_entry, "Third entry includes children headings"
|
||||
|
||||
|
||||
def test_entry_with_body_to_entry(tmp_path):
|
||||
"Ensure entries with valid body text are loaded."
|
||||
# Arrange
|
||||
entry = f"""*** Heading
|
||||
@@ -107,19 +248,13 @@ def test_entry_with_body_to_jsonl(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, entry_to_file_map = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(
|
||||
OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
|
||||
|
||||
def test_file_with_entry_after_intro_text_to_jsonl(tmp_path):
|
||||
def test_file_with_entry_after_intro_text_to_entry(tmp_path):
|
||||
"Ensure intro text before any headings is indexed."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
@@ -134,18 +269,13 @@ Intro text
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 2
|
||||
assert len(entries) == 2
|
||||
|
||||
|
||||
def test_file_with_no_headings_to_jsonl(tmp_path):
|
||||
def test_file_with_no_headings_to_entry(tmp_path):
|
||||
"Ensure files with no heading, only body text are loaded."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
@@ -158,15 +288,10 @@ def test_file_with_no_headings_to_jsonl(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entry_nodes, file_to_entries = OrgToEntries.extract_org_entries(org_files=data)
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
entries = OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries)
|
||||
jsonl_string = OrgToEntries.convert_org_entries_to_jsonl(entries)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
|
||||
|
||||
def test_get_org_files(tmp_path):
|
||||
@@ -214,7 +339,8 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
# Arrange
|
||||
entry = f"""
|
||||
* Heading 1
|
||||
** Heading 2
|
||||
** Sub-Heading 1.1
|
||||
* Heading 2
|
||||
"""
|
||||
data = {
|
||||
f"{tmp_path}": entry,
|
||||
@@ -222,12 +348,14 @@ def test_extract_entries_with_different_level_headings(tmp_path):
|
||||
|
||||
# Act
|
||||
# Extract Entries from specified Org files
|
||||
entries, _ = OrgToEntries.extract_org_entries(org_files=data)
|
||||
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True, max_tokens=3)
|
||||
for entry in entries:
|
||||
entry.raw = clean(f"{entry.raw}")
|
||||
|
||||
# Assert
|
||||
assert len(entries) == 2
|
||||
assert f"{entries[0]}".startswith("* Heading 1")
|
||||
assert f"{entries[1]}".startswith("** Heading 2")
|
||||
assert entries[0].raw == "* Heading 1\n** Sub-Heading 1.1\n", "Ensure entry includes heading ancestory"
|
||||
assert entries[1].raw == "* Heading 2\n"
|
||||
|
||||
|
||||
# Helper Functions
|
||||
@@ -237,3 +365,8 @@ def create_file(tmp_path, entry=None, filename="test.org"):
|
||||
if entry:
|
||||
org_file.write_text(entry)
|
||||
return org_file
|
||||
|
||||
|
||||
def clean(entry):
|
||||
"Remove properties from entry for easier comparison."
|
||||
return re.sub(r"\n:PROPERTIES:(.*?):END:", "", entry, flags=re.DOTALL)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
|
||||
@@ -15,16 +14,10 @@ def test_single_page_pdf_to_jsonl():
|
||||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/singlepage.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
|
||||
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
|
||||
|
||||
def test_multi_page_pdf_to_jsonl():
|
||||
@@ -35,16 +28,10 @@ def test_multi_page_pdf_to_jsonl():
|
||||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/multipage.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
jsonl_string = PdfToEntries.convert_pdf_maps_to_jsonl(
|
||||
PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 6
|
||||
assert len(entries) == 6
|
||||
|
||||
|
||||
def test_ocr_page_pdf_to_jsonl():
|
||||
@@ -55,10 +42,7 @@ def test_ocr_page_pdf_to_jsonl():
|
||||
pdf_bytes = f.read()
|
||||
|
||||
data = {"tests/data/pdf/ocr_samples.pdf": pdf_bytes}
|
||||
entries, entry_to_file_map = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
# Process Each Entry from All Pdf Files
|
||||
entries = PdfToEntries.convert_pdf_entries_to_maps(entries, entry_to_file_map)
|
||||
entries = PdfToEntries.extract_pdf_entries(pdf_files=data)
|
||||
|
||||
assert len(entries) == 1
|
||||
assert "playing on a strip of marsh" in entries[0].raw
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -11,10 +10,10 @@ from khoj.utils.rawconfig import TextContentConfig
|
||||
def test_plaintext_file(tmp_path):
|
||||
"Convert files with no heading to jsonl."
|
||||
# Arrange
|
||||
entry = f"""
|
||||
raw_entry = f"""
|
||||
Hi, I am a plaintext file and I have some plaintext words.
|
||||
"""
|
||||
plaintextfile = create_file(tmp_path, entry)
|
||||
plaintextfile = create_file(tmp_path, raw_entry)
|
||||
|
||||
filename = plaintextfile.stem
|
||||
|
||||
@@ -22,25 +21,21 @@ def test_plaintext_file(tmp_path):
|
||||
# Extract Entries from specified plaintext files
|
||||
|
||||
data = {
|
||||
f"{plaintextfile}": entry,
|
||||
f"{plaintextfile}": raw_entry,
|
||||
}
|
||||
|
||||
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(entry_to_file_map=data)
|
||||
entries = PlaintextToEntries.extract_plaintext_entries(entry_to_file_map=data)
|
||||
|
||||
# Convert each entry.file to absolute path to make them JSON serializable
|
||||
for map in maps:
|
||||
map.file = str(Path(map.file).absolute())
|
||||
|
||||
# Process Each Entry from All Notes Files
|
||||
jsonl_string = PlaintextToEntries.convert_entries_to_jsonl(maps)
|
||||
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
|
||||
for entry in entries:
|
||||
entry.file = str(Path(entry.file).absolute())
|
||||
|
||||
# Assert
|
||||
assert len(jsonl_data) == 1
|
||||
assert len(entries) == 1
|
||||
# Ensure raw entry with no headings do not get heading prefix prepended
|
||||
assert not jsonl_data[0]["raw"].startswith("#")
|
||||
assert not entries[0].raw.startswith("#")
|
||||
# Ensure compiled entry has filename prepended as top level heading
|
||||
assert jsonl_data[0]["compiled"] == f"{filename}\n{entry}"
|
||||
assert entries[0].compiled == f"{filename}\n{raw_entry}"
|
||||
|
||||
|
||||
def test_get_plaintext_files(tmp_path):
|
||||
@@ -98,11 +93,11 @@ def test_parse_html_plaintext_file(content_config, default_user: KhojUser):
|
||||
extracted_plaintext_files = get_plaintext_files(config=config)
|
||||
|
||||
# Act
|
||||
maps = PlaintextToEntries.convert_plaintext_entries_to_maps(extracted_plaintext_files)
|
||||
entries = PlaintextToEntries.extract_plaintext_entries(extracted_plaintext_files)
|
||||
|
||||
# Assert
|
||||
assert len(maps) == 1
|
||||
assert "<div>" not in maps[0].raw
|
||||
assert len(entries) == 1
|
||||
assert "<div>" not in entries[0].raw
|
||||
|
||||
|
||||
# Helper Functions
|
||||
|
||||
@@ -57,18 +57,21 @@ def test_get_org_files_with_org_suffixed_dir_doesnt_raise_error(tmp_path, defaul
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_with_empty_file_creates_no_entries(
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
|
||||
):
|
||||
# Arrange
|
||||
existing_entries = Entry.objects.filter(user=default_user).count()
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert "Deleted 8 entries. Created 0 new entries for user " in caplog.records[-1].message
|
||||
updated_entries = Entry.objects.filter(user=default_user).count()
|
||||
|
||||
assert existing_entries == 2
|
||||
assert updated_entries == 0
|
||||
verify_embeddings(0, default_user)
|
||||
|
||||
|
||||
@@ -78,6 +81,7 @@ def test_text_indexer_deletes_embedding_before_regenerate(
|
||||
content_config: ContentConfig, default_user: KhojUser, caplog
|
||||
):
|
||||
# Arrange
|
||||
existing_entries = Entry.objects.filter(user=default_user).count()
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
@@ -87,30 +91,18 @@ def test_text_indexer_deletes_embedding_before_regenerate(
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
updated_entries = Entry.objects.filter(user=default_user).count()
|
||||
assert existing_entries == 2
|
||||
assert updated_entries == 2
|
||||
assert "Deleting all entries for file type org" in caplog.text
|
||||
assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_search_setup_batch_processes(content_config: ContentConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
# Generate notes embeddings during asymmetric setup
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# Assert
|
||||
assert "Deleted 8 entries. Created 13 new entries for user " in caplog.records[-1].message
|
||||
assert "Deleted 2 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_text_index_same_if_content_unchanged(content_config: ContentConfig, default_user: KhojUser, caplog):
|
||||
# Arrange
|
||||
existing_entries = Entry.objects.filter(user=default_user)
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
|
||||
@@ -127,6 +119,10 @@ def test_text_index_same_if_content_unchanged(content_config: ContentConfig, def
|
||||
final_logs = caplog.text
|
||||
|
||||
# Assert
|
||||
updated_entries = Entry.objects.filter(user=default_user)
|
||||
for entry in updated_entries:
|
||||
assert entry in existing_entries
|
||||
assert len(existing_entries) == len(updated_entries)
|
||||
assert "Deleting all entries for file type org" in initial_logs
|
||||
assert "Deleting all entries for file type org" not in final_logs
|
||||
|
||||
@@ -192,7 +188,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
|
||||
|
||||
# Assert
|
||||
assert (
|
||||
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
|
||||
), "new entry not split by max tokens"
|
||||
|
||||
|
||||
@@ -250,16 +246,15 @@ conda activate khoj
|
||||
|
||||
# Assert
|
||||
assert (
|
||||
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message
|
||||
"Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
|
||||
), "new entry not split by max tokens"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_regenerate_index_with_new_entry(
|
||||
content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog
|
||||
):
|
||||
def test_regenerate_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
|
||||
# Arrange
|
||||
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
initial_data = get_org_files(org_config)
|
||||
|
||||
@@ -271,28 +266,34 @@ def test_regenerate_index_with_new_entry(
|
||||
final_data = get_org_files(org_config)
|
||||
|
||||
# Act
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
# regenerate notes jsonl, model embeddings and model to include entry from new file
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
|
||||
final_logs = caplog.text
|
||||
text_search.setup(OrgToEntries, final_data, regenerate=True, user=default_user)
|
||||
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
# Assert
|
||||
assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs
|
||||
assert "Deleted 13 entries. Created 14 new entries for user " in final_logs
|
||||
verify_embeddings(14, default_user)
|
||||
for entry in updated_entries1:
|
||||
assert entry in updated_entries2
|
||||
|
||||
assert not any([new_org_file.name in entry for entry in updated_entries1])
|
||||
assert not any([new_org_file.name in entry for entry in existing_entries])
|
||||
assert any([new_org_file.name in entry for entry in updated_entries2])
|
||||
|
||||
assert any(
|
||||
["Saw a super cute video of a chihuahua doing the Tango on Youtube" in entry for entry in updated_entries2]
|
||||
)
|
||||
verify_embeddings(3, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_duplicate_entries_in_stable_order(
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog
|
||||
org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser
|
||||
):
|
||||
# Arrange
|
||||
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
||||
# Insert org-mode entries with same compiled form into new org file
|
||||
@@ -304,30 +305,33 @@ def test_update_index_with_duplicate_entries_in_stable_order(
|
||||
|
||||
# Act
|
||||
# generate embeddings, entries, notes model from scratch after adding new org-mode file
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
data = get_org_files(org_config_with_only_new_file)
|
||||
|
||||
# update embeddings, entries, notes model with no new changes
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert "Deleted 8 entries. Created 1 new entries for user " in initial_logs
|
||||
assert "Deleted 0 entries. Created 0 new entries for user " in final_logs
|
||||
for entry in existing_entries:
|
||||
assert entry not in updated_entries1
|
||||
|
||||
for entry in updated_entries1:
|
||||
assert entry in updated_entries2
|
||||
|
||||
assert len(existing_entries) == 2
|
||||
assert len(updated_entries1) == len(updated_entries2)
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser, caplog):
|
||||
def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrgConfig, default_user: KhojUser):
|
||||
# Arrange
|
||||
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
new_file_to_index = Path(org_config_with_only_new_file.input_files[0])
|
||||
|
||||
# Insert org-mode entries with same compiled form into new org file
|
||||
@@ -344,33 +348,34 @@ def test_update_index_with_deleted_entry(org_config_with_only_new_file: LocalOrg
|
||||
|
||||
# Act
|
||||
# load embeddings, entries, notes model after adding new org file with 2 entries
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
text_search.setup(OrgToEntries, initial_data, regenerate=True, user=default_user)
|
||||
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
text_search.setup(OrgToEntries, final_data, regenerate=False, user=default_user)
|
||||
updated_entries2 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
# Assert
|
||||
# verify only 1 entry added even if there are multiple duplicate entries
|
||||
assert "Deleted 8 entries. Created 2 new entries for user " in initial_logs
|
||||
assert "Deleted 1 entries. Created 0 new entries for user " in final_logs
|
||||
for entry in existing_entries:
|
||||
assert entry not in updated_entries1
|
||||
|
||||
# verify the entry in updated_entries2 is a subset of updated_entries1
|
||||
for entry in updated_entries1:
|
||||
assert entry not in updated_entries2
|
||||
|
||||
for entry in updated_entries2:
|
||||
assert entry in updated_entries1[0]
|
||||
|
||||
verify_embeddings(1, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser, caplog):
|
||||
def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file: Path, default_user: KhojUser):
|
||||
# Arrange
|
||||
existing_entries = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
org_config = LocalOrgConfig.objects.filter(user=default_user).first()
|
||||
data = get_org_files(org_config)
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
initial_logs = caplog.text
|
||||
caplog.clear() # Clear logs
|
||||
text_search.setup(OrgToEntries, data, regenerate=True, user=default_user)
|
||||
|
||||
# append org-mode entry to first org input file in config
|
||||
with open(new_org_file, "w") as f:
|
||||
@@ -381,14 +386,14 @@ def test_update_index_with_new_entry(content_config: ContentConfig, new_org_file
|
||||
|
||||
# Act
|
||||
# update embeddings, entries with the newly added note
|
||||
with caplog.at_level(logging.INFO):
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
final_logs = caplog.text
|
||||
text_search.setup(OrgToEntries, data, regenerate=False, user=default_user)
|
||||
updated_entries1 = list(Entry.objects.filter(user=default_user).values_list("compiled", flat=True))
|
||||
|
||||
# Assert
|
||||
assert "Deleted 8 entries. Created 13 new entries for user " in initial_logs
|
||||
assert "Deleted 0 entries. Created 1 new entries for user " in final_logs
|
||||
verify_embeddings(14, default_user)
|
||||
for entry in existing_entries:
|
||||
assert entry not in updated_entries1
|
||||
assert len(updated_entries1) == len(existing_entries) + 1
|
||||
verify_embeddings(3, default_user)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user