mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Update Text Chunking Strategy to Improve Search Context (#645)
## Major - Parse markdown, org parent entries as single entry if fit within max tokens - Parse a file as single entry if it fits with max token limits - Add parent heading ancestry to extracted markdown entries for context - Chunk text in preference order of para, sentence, word, character ## Minor - Create wrapper function to get entries from org, md, pdf & text files - Remove unused Entry to Jsonl converter from text to entry class, tests - Dedupe code by using single func to process an org file into entries Resolves #620
This commit is contained in:
@@ -1,14 +1,13 @@
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import urllib3
|
||||
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import Entry
|
||||
|
||||
@@ -31,15 +30,14 @@ class MarkdownToEntries(TextToEntries):
|
||||
else:
|
||||
deletion_file_names = None
|
||||
|
||||
max_tokens = 256
|
||||
# Extract Entries from specified Markdown files
|
||||
with timer("Parse entries from Markdown files into dictionaries", logger):
|
||||
current_entries = MarkdownToEntries.convert_markdown_entries_to_maps(
|
||||
*MarkdownToEntries.extract_markdown_entries(files)
|
||||
)
|
||||
with timer("Extract entries from specified Markdown files", logger):
|
||||
current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens)
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
@@ -57,48 +55,84 @@ class MarkdownToEntries(TextToEntries):
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_markdown_entries(markdown_files):
|
||||
def extract_markdown_entries(markdown_files, max_tokens=256) -> List[Entry]:
|
||||
"Extract entries by heading from specified Markdown files"
|
||||
|
||||
# Regex to extract Markdown Entries by Heading
|
||||
|
||||
entries = []
|
||||
entry_to_file_map = []
|
||||
entries: List[str] = []
|
||||
entry_to_file_map: List[Tuple[str, str]] = []
|
||||
for markdown_file in markdown_files:
|
||||
try:
|
||||
markdown_content = markdown_files[markdown_file]
|
||||
entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file(
|
||||
markdown_content, markdown_file, entries, entry_to_file_map
|
||||
markdown_content, markdown_file, entries, entry_to_file_map, max_tokens
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {markdown_file}. This file will not be indexed.")
|
||||
logger.warning(e, exc_info=True)
|
||||
logger.error(
|
||||
f"Unable to process file: {markdown_file}. This file will not be indexed.\n{e}", exc_info=True
|
||||
)
|
||||
|
||||
return entries, dict(entry_to_file_map)
|
||||
return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
|
||||
|
||||
@staticmethod
|
||||
def process_single_markdown_file(
|
||||
markdown_content: str, markdown_file: Path, entries: List, entry_to_file_map: List
|
||||
):
|
||||
markdown_heading_regex = r"^#"
|
||||
markdown_content: str,
|
||||
markdown_file: str,
|
||||
entries: List[str],
|
||||
entry_to_file_map: List[Tuple[str, str]],
|
||||
max_tokens=256,
|
||||
ancestry: Dict[int, str] = {},
|
||||
) -> Tuple[List[str], List[Tuple[str, str]]]:
|
||||
# Prepend the markdown section's heading ancestry
|
||||
ancestry_string = "\n".join([f"{'#' * key} {ancestry[key]}" for key in sorted(ancestry.keys())])
|
||||
markdown_content_with_ancestry = f"{ancestry_string}{markdown_content}"
|
||||
|
||||
markdown_entries_per_file = []
|
||||
any_headings = re.search(markdown_heading_regex, markdown_content, flags=re.MULTILINE)
|
||||
for entry in re.split(markdown_heading_regex, markdown_content, flags=re.MULTILINE):
|
||||
# Add heading level as the regex split removed it from entries with headings
|
||||
prefix = "#" if entry.startswith("#") else "# " if any_headings else ""
|
||||
stripped_entry = entry.strip(empty_escape_sequences)
|
||||
if stripped_entry != "":
|
||||
markdown_entries_per_file.append(f"{prefix}{stripped_entry}")
|
||||
# If content is small or content has no children headings, save it as a single entry
|
||||
if len(TextToEntries.tokenizer(markdown_content_with_ancestry)) <= max_tokens or not re.search(
|
||||
rf"^#{{{len(ancestry)+1},}}\s", markdown_content, flags=re.MULTILINE
|
||||
):
|
||||
entry_to_file_map += [(markdown_content_with_ancestry, markdown_file)]
|
||||
entries.extend([markdown_content_with_ancestry])
|
||||
return entries, entry_to_file_map
|
||||
|
||||
# Split by next heading level present in the entry
|
||||
next_heading_level = len(ancestry)
|
||||
sections: List[str] = []
|
||||
while len(sections) < 2:
|
||||
next_heading_level += 1
|
||||
sections = re.split(rf"(\n|^)(?=[#]{{{next_heading_level}}} .+\n?)", markdown_content, flags=re.MULTILINE)
|
||||
|
||||
for section in sections:
|
||||
# Skip empty sections
|
||||
if section.strip() == "":
|
||||
continue
|
||||
|
||||
# Extract the section body and (when present) the heading
|
||||
current_ancestry = ancestry.copy()
|
||||
first_line = [line for line in section.split("\n") if line.strip() != ""][0]
|
||||
if re.search(rf"^#{{{next_heading_level}}} ", first_line):
|
||||
# Extract the section body without the heading
|
||||
current_section_body = "\n".join(section.split(first_line)[1:])
|
||||
# Parse the section heading into current section ancestry
|
||||
current_section_title = first_line[next_heading_level:].strip()
|
||||
current_ancestry[next_heading_level] = current_section_title
|
||||
else:
|
||||
current_section_body = section
|
||||
|
||||
# Recurse down children of the current entry
|
||||
MarkdownToEntries.process_single_markdown_file(
|
||||
current_section_body,
|
||||
markdown_file,
|
||||
entries,
|
||||
entry_to_file_map,
|
||||
max_tokens,
|
||||
current_ancestry,
|
||||
)
|
||||
|
||||
entry_to_file_map += zip(markdown_entries_per_file, [markdown_file] * len(markdown_entries_per_file))
|
||||
entries.extend(markdown_entries_per_file)
|
||||
return entries, entry_to_file_map
|
||||
|
||||
@staticmethod
|
||||
def convert_markdown_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
|
||||
"Convert each Markdown entries into a dictionary"
|
||||
entries = []
|
||||
entries: List[Entry] = []
|
||||
for parsed_entry in parsed_entries:
|
||||
raw_filename = entry_to_file_map[parsed_entry]
|
||||
|
||||
@@ -108,13 +142,12 @@ class MarkdownToEntries(TextToEntries):
|
||||
entry_filename = urllib3.util.parse_url(raw_filename).url
|
||||
else:
|
||||
entry_filename = str(Path(raw_filename))
|
||||
stem = Path(raw_filename).stem
|
||||
|
||||
heading = parsed_entry.splitlines()[0] if re.search("^#+\s", parsed_entry) else ""
|
||||
# Append base filename to compiled entry for context to model
|
||||
# Increment heading level for heading entries and make filename as its top level heading
|
||||
prefix = f"# {stem}\n#" if heading else f"# {stem}\n"
|
||||
compiled_entry = f"{entry_filename}\n{prefix}{parsed_entry}"
|
||||
prefix = f"# {entry_filename}\n#" if heading else f"# {entry_filename}\n"
|
||||
compiled_entry = f"{prefix}{parsed_entry}"
|
||||
entries.append(
|
||||
Entry(
|
||||
compiled=compiled_entry,
|
||||
@@ -127,8 +160,3 @@ class MarkdownToEntries(TextToEntries):
|
||||
logger.debug(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def convert_markdown_maps_to_jsonl(entries: List[Entry]):
|
||||
"Convert each Markdown entry to JSON and collate as JSONL"
|
||||
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.content.org_mode import orgnode
|
||||
from khoj.processor.content.org_mode.orgnode import Orgnode
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import timer
|
||||
@@ -21,9 +23,6 @@ class OrgToEntries(TextToEntries):
|
||||
def process(
|
||||
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
|
||||
) -> Tuple[int, int]:
|
||||
# Extract required fields from config
|
||||
index_heading_entries = False
|
||||
|
||||
if not full_corpus:
|
||||
deletion_file_names = set([file for file in files if files[file] == ""])
|
||||
files_to_process = set(files) - deletion_file_names
|
||||
@@ -32,14 +31,12 @@ class OrgToEntries(TextToEntries):
|
||||
deletion_file_names = None
|
||||
|
||||
# Extract Entries from specified Org files
|
||||
with timer("Parse entries from org files into OrgNode objects", logger):
|
||||
entry_nodes, file_to_entries = self.extract_org_entries(files)
|
||||
|
||||
with timer("Convert OrgNodes into list of entries", logger):
|
||||
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
|
||||
max_tokens = 256
|
||||
with timer("Extract entries from specified Org files", logger):
|
||||
current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
|
||||
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=max_tokens)
|
||||
|
||||
# Identify, mark and merge any new entries with previous entries
|
||||
with timer("Identify new or updated entries", logger):
|
||||
@@ -57,93 +54,165 @@ class OrgToEntries(TextToEntries):
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_org_entries(org_files: dict[str, str]):
|
||||
def extract_org_entries(
|
||||
org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
|
||||
) -> List[Entry]:
|
||||
"Extract entries from specified Org files"
|
||||
entries = []
|
||||
entry_to_file_map: List[Tuple[orgnode.Orgnode, str]] = []
|
||||
entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
|
||||
return OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map, index_heading_entries)
|
||||
|
||||
@staticmethod
|
||||
def extract_org_nodes(org_files: dict[str, str], max_tokens) -> Tuple[List[List[Orgnode]], Dict[Orgnode, str]]:
|
||||
"Extract org nodes from specified org files"
|
||||
entries: List[List[Orgnode]] = []
|
||||
entry_to_file_map: List[Tuple[Orgnode, str]] = []
|
||||
for org_file in org_files:
|
||||
filename = org_file
|
||||
file = org_files[org_file]
|
||||
try:
|
||||
org_file_entries = orgnode.makelist(file, filename)
|
||||
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
|
||||
entries.extend(org_file_entries)
|
||||
org_content = org_files[org_file]
|
||||
entries, entry_to_file_map = OrgToEntries.process_single_org_file(
|
||||
org_content, org_file, entries, entry_to_file_map, max_tokens
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {org_file}. This file will not be indexed.")
|
||||
logger.warning(e, exc_info=True)
|
||||
logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError; {e}", exc_info=True)
|
||||
|
||||
return entries, dict(entry_to_file_map)
|
||||
|
||||
@staticmethod
|
||||
def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List):
|
||||
# Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer. We'll split the raw conetnt of this file by new line to mimic the same behavior.
|
||||
try:
|
||||
org_file_entries = orgnode.makelist(org_content, org_file)
|
||||
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
|
||||
entries.extend(org_file_entries)
|
||||
return entries, entry_to_file_map
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file: {org_file} with error: {e}", exc_info=True)
|
||||
def process_single_org_file(
|
||||
org_content: str,
|
||||
org_file: str,
|
||||
entries: List[List[Orgnode]],
|
||||
entry_to_file_map: List[Tuple[Orgnode, str]],
|
||||
max_tokens=256,
|
||||
ancestry: Dict[int, str] = {},
|
||||
) -> Tuple[List[List[Orgnode]], List[Tuple[Orgnode, str]]]:
|
||||
"""Parse org_content from org_file into OrgNode entries
|
||||
|
||||
Recurse down org file entries, one heading level at a time,
|
||||
until reach a leaf entry or the current entry tree fits max_tokens.
|
||||
|
||||
Parse recursion terminating entry (trees) into (a list of) OrgNode objects.
|
||||
"""
|
||||
# Prepend the org section's heading ancestry
|
||||
ancestry_string = "\n".join([f"{'*' * key} {ancestry[key]}" for key in sorted(ancestry.keys())])
|
||||
org_content_with_ancestry = f"{ancestry_string}{org_content}"
|
||||
|
||||
# If content is small or content has no children headings, save it as a single entry
|
||||
# Note: This is the terminating condition for this recursive function
|
||||
if len(TextToEntries.tokenizer(org_content_with_ancestry)) <= max_tokens or not re.search(
|
||||
rf"^\*{{{len(ancestry)+1},}}\s", org_content, re.MULTILINE
|
||||
):
|
||||
orgnode_content_with_ancestry = orgnode.makelist(org_content_with_ancestry, org_file)
|
||||
entry_to_file_map += zip(orgnode_content_with_ancestry, [org_file] * len(orgnode_content_with_ancestry))
|
||||
entries.extend([orgnode_content_with_ancestry])
|
||||
return entries, entry_to_file_map
|
||||
|
||||
# Split this entry tree into sections by the next heading level in it
|
||||
# Increment heading level until able to split entry into sections
|
||||
# A successful split will result in at least 2 sections
|
||||
next_heading_level = len(ancestry)
|
||||
sections: List[str] = []
|
||||
while len(sections) < 2:
|
||||
next_heading_level += 1
|
||||
sections = re.split(rf"(\n|^)(?=[*]{{{next_heading_level}}} .+\n?)", org_content, flags=re.MULTILINE)
|
||||
|
||||
# Recurse down each non-empty section after parsing its body, heading and ancestry
|
||||
for section in sections:
|
||||
# Skip empty sections
|
||||
if section.strip() == "":
|
||||
continue
|
||||
|
||||
# Extract the section body and (when present) the heading
|
||||
current_ancestry = ancestry.copy()
|
||||
first_non_empty_line = [line for line in section.split("\n") if line.strip() != ""][0]
|
||||
# If first non-empty line is a heading with expected heading level
|
||||
if re.search(rf"^\*{{{next_heading_level}}}\s", first_non_empty_line):
|
||||
# Extract the section body without the heading
|
||||
current_section_body = "\n".join(section.split(first_non_empty_line)[1:])
|
||||
# Parse the section heading into current section ancestry
|
||||
current_section_title = first_non_empty_line[next_heading_level:].strip()
|
||||
current_ancestry[next_heading_level] = current_section_title
|
||||
# Else process the section as just body text
|
||||
else:
|
||||
current_section_body = section
|
||||
|
||||
# Recurse down children of the current entry
|
||||
OrgToEntries.process_single_org_file(
|
||||
current_section_body,
|
||||
org_file,
|
||||
entries,
|
||||
entry_to_file_map,
|
||||
max_tokens,
|
||||
current_ancestry,
|
||||
)
|
||||
|
||||
return entries, entry_to_file_map
|
||||
|
||||
@staticmethod
|
||||
def convert_org_nodes_to_entries(
|
||||
parsed_entries: List[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False
|
||||
parsed_entries: List[List[Orgnode]],
|
||||
entry_to_file_map: Dict[Orgnode, str],
|
||||
index_heading_entries: bool = False,
|
||||
) -> List[Entry]:
|
||||
"Convert Org-Mode nodes into list of Entry objects"
|
||||
"""
|
||||
Convert OrgNode lists into list of Entry objects
|
||||
|
||||
Each list of OrgNodes is a parsed parent org tree or leaf node.
|
||||
Convert each list of these OrgNodes into a single Entry.
|
||||
"""
|
||||
entries: List[Entry] = []
|
||||
for parsed_entry in parsed_entries:
|
||||
if not parsed_entry.hasBody and not index_heading_entries:
|
||||
# Ignore title notes i.e notes with just headings and empty body
|
||||
continue
|
||||
for entry_group in parsed_entries:
|
||||
entry_heading, entry_compiled, entry_raw = "", "", ""
|
||||
for parsed_entry in entry_group:
|
||||
if not parsed_entry.hasBody and not index_heading_entries:
|
||||
# Ignore title notes i.e notes with just headings and empty body
|
||||
continue
|
||||
|
||||
todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
|
||||
todo_str = f"{parsed_entry.todo} " if parsed_entry.todo else ""
|
||||
|
||||
# Prepend ancestor headings, filename as top heading to entry for context
|
||||
ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry])
|
||||
if parsed_entry.heading:
|
||||
heading = f"* Path: {ancestors_trail}\n** {todo_str}{parsed_entry.heading}."
|
||||
else:
|
||||
heading = f"* Path: {ancestors_trail}."
|
||||
# Set base level to current org-node tree's root heading level
|
||||
if not entry_heading and parsed_entry.level > 0:
|
||||
base_level = parsed_entry.level
|
||||
# Indent entry by 1 heading level as ancestry is prepended as top level heading
|
||||
heading = f"{'*' * (parsed_entry.level-base_level+2)} {todo_str}" if parsed_entry.level > 0 else ""
|
||||
if parsed_entry.heading:
|
||||
heading += f"{parsed_entry.heading}."
|
||||
|
||||
compiled = heading
|
||||
if state.verbose > 2:
|
||||
logger.debug(f"Title: {heading}")
|
||||
# Prepend ancestor headings, filename as top heading to root parent entry for context
|
||||
# Children nodes do not need ancestors trail as root parent node will have it
|
||||
if not entry_heading:
|
||||
ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry])
|
||||
heading = f"* Path: {ancestors_trail}\n{heading}" if heading else f"* Path: {ancestors_trail}."
|
||||
|
||||
if parsed_entry.tags:
|
||||
tags_str = " ".join(parsed_entry.tags)
|
||||
compiled += f"\t {tags_str}."
|
||||
if state.verbose > 2:
|
||||
logger.debug(f"Tags: {tags_str}")
|
||||
compiled = heading
|
||||
|
||||
if parsed_entry.closed:
|
||||
compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
|
||||
if state.verbose > 2:
|
||||
logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}')
|
||||
if parsed_entry.tags:
|
||||
tags_str = " ".join(parsed_entry.tags)
|
||||
compiled += f"\t {tags_str}."
|
||||
|
||||
if parsed_entry.scheduled:
|
||||
compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
|
||||
if state.verbose > 2:
|
||||
logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
|
||||
if parsed_entry.closed:
|
||||
compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
|
||||
|
||||
if parsed_entry.hasBody:
|
||||
compiled += f"\n {parsed_entry.body}"
|
||||
if state.verbose > 2:
|
||||
logger.debug(f"Body: {parsed_entry.body}")
|
||||
if parsed_entry.scheduled:
|
||||
compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
|
||||
|
||||
if compiled:
|
||||
if parsed_entry.hasBody:
|
||||
compiled += f"\n {parsed_entry.body}"
|
||||
|
||||
# Add the sub-entry contents to the entry
|
||||
entry_compiled += f"{compiled}"
|
||||
entry_raw += f"{parsed_entry}"
|
||||
if not entry_heading:
|
||||
entry_heading = heading
|
||||
|
||||
if entry_compiled:
|
||||
entries.append(
|
||||
Entry(
|
||||
compiled=compiled,
|
||||
raw=f"{parsed_entry}",
|
||||
heading=f"{heading}",
|
||||
compiled=entry_compiled,
|
||||
raw=entry_raw,
|
||||
heading=f"{entry_heading}",
|
||||
file=f"{entry_to_file_map[parsed_entry]}",
|
||||
)
|
||||
)
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
|
||||
"Convert each Org-Mode entry to JSON and collate as JSONL"
|
||||
return "".join([f"{entry_dict.to_json()}\n" for entry_dict in entries])
|
||||
|
||||
@@ -37,7 +37,7 @@ import datetime
|
||||
import re
|
||||
from os.path import relpath
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
indent_regex = re.compile(r"^ *")
|
||||
|
||||
@@ -58,7 +58,7 @@ def makelist_with_filepath(filename):
|
||||
return makelist(f, filename)
|
||||
|
||||
|
||||
def makelist(file, filename):
|
||||
def makelist(file, filename) -> List["Orgnode"]:
|
||||
"""
|
||||
Read an org-mode file and return a list of Orgnode objects
|
||||
created from this file.
|
||||
@@ -80,16 +80,16 @@ def makelist(file, filename):
|
||||
} # populated from #+SEQ_TODO line
|
||||
level = ""
|
||||
heading = ""
|
||||
ancestor_headings = []
|
||||
ancestor_headings: List[str] = []
|
||||
bodytext = ""
|
||||
introtext = ""
|
||||
tags = list() # set of all tags in headline
|
||||
closed_date = ""
|
||||
sched_date = ""
|
||||
deadline_date = ""
|
||||
logbook = list()
|
||||
tags: List[str] = list() # set of all tags in headline
|
||||
closed_date: datetime.date = None
|
||||
sched_date: datetime.date = None
|
||||
deadline_date: datetime.date = None
|
||||
logbook: List[Tuple[datetime.datetime, datetime.datetime]] = list()
|
||||
nodelist: List[Orgnode] = list()
|
||||
property_map = dict()
|
||||
property_map: Dict[str, str] = dict()
|
||||
in_properties_drawer = False
|
||||
in_logbook_drawer = False
|
||||
file_title = f"{filename}"
|
||||
@@ -102,13 +102,13 @@ def makelist(file, filename):
|
||||
thisNode = Orgnode(level, heading, bodytext, tags, ancestor_headings)
|
||||
if closed_date:
|
||||
thisNode.closed = closed_date
|
||||
closed_date = ""
|
||||
closed_date = None
|
||||
if sched_date:
|
||||
thisNode.scheduled = sched_date
|
||||
sched_date = ""
|
||||
sched_date = None
|
||||
if deadline_date:
|
||||
thisNode.deadline = deadline_date
|
||||
deadline_date = ""
|
||||
deadline_date = None
|
||||
if logbook:
|
||||
thisNode.logbook = logbook
|
||||
logbook = list()
|
||||
@@ -116,7 +116,7 @@ def makelist(file, filename):
|
||||
nodelist.append(thisNode)
|
||||
property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"}
|
||||
previous_level = level
|
||||
previous_heading = heading
|
||||
previous_heading: str = heading
|
||||
level = heading_search.group(1)
|
||||
heading = heading_search.group(2)
|
||||
bodytext = ""
|
||||
@@ -495,12 +495,13 @@ class Orgnode(object):
|
||||
if self._priority:
|
||||
n = n + "[#" + self._priority + "] "
|
||||
n = n + self._heading
|
||||
n = "%-60s " % n # hack - tags will start in column 62
|
||||
closecolon = ""
|
||||
for t in self._tags:
|
||||
n = n + ":" + t
|
||||
closecolon = ":"
|
||||
n = n + closecolon
|
||||
if self._tags:
|
||||
n = "%-60s " % n # hack - tags will start in column 62
|
||||
closecolon = ""
|
||||
for t in self._tags:
|
||||
n = n + ":" + t
|
||||
closecolon = ":"
|
||||
n = n + closecolon
|
||||
n = n + "\n"
|
||||
|
||||
# Get body indentation from first line of body
|
||||
|
||||
@@ -32,8 +32,8 @@ class PdfToEntries(TextToEntries):
|
||||
deletion_file_names = None
|
||||
|
||||
# Extract Entries from specified Pdf files
|
||||
with timer("Parse entries from PDF files into dictionaries", logger):
|
||||
current_entries = PdfToEntries.convert_pdf_entries_to_maps(*PdfToEntries.extract_pdf_entries(files))
|
||||
with timer("Extract entries from specified PDF files", logger):
|
||||
current_entries = PdfToEntries.extract_pdf_entries(files)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
@@ -55,11 +55,11 @@ class PdfToEntries(TextToEntries):
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_pdf_entries(pdf_files):
|
||||
def extract_pdf_entries(pdf_files) -> List[Entry]:
|
||||
"""Extract entries by page from specified PDF files"""
|
||||
|
||||
entries = []
|
||||
entry_to_location_map = []
|
||||
entries: List[str] = []
|
||||
entry_to_location_map: List[Tuple[str, str]] = []
|
||||
for pdf_file in pdf_files:
|
||||
try:
|
||||
# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
|
||||
@@ -83,7 +83,7 @@ class PdfToEntries(TextToEntries):
|
||||
if os.path.exists(f"{tmp_file}"):
|
||||
os.remove(f"{tmp_file}")
|
||||
|
||||
return entries, dict(entry_to_location_map)
|
||||
return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
|
||||
|
||||
@staticmethod
|
||||
def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
|
||||
@@ -106,8 +106,3 @@ class PdfToEntries(TextToEntries):
|
||||
logger.debug(f"Converted {len(parsed_entries)} PDF entries to dictionaries")
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def convert_pdf_maps_to_jsonl(entries: List[Entry]):
|
||||
"Convert each PDF entry to JSON and collate as JSONL"
|
||||
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||
|
||||
@@ -42,8 +42,8 @@ class PlaintextToEntries(TextToEntries):
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
# Extract Entries from specified plaintext files
|
||||
with timer("Parse entries from plaintext files", logger):
|
||||
current_entries = PlaintextToEntries.convert_plaintext_entries_to_maps(files)
|
||||
with timer("Parse entries from specified Plaintext files", logger):
|
||||
current_entries = PlaintextToEntries.extract_plaintext_entries(files)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
@@ -74,7 +74,7 @@ class PlaintextToEntries(TextToEntries):
|
||||
return soup.get_text(strip=True, separator="\n")
|
||||
|
||||
@staticmethod
|
||||
def convert_plaintext_entries_to_maps(entry_to_file_map: dict) -> List[Entry]:
|
||||
def extract_plaintext_entries(entry_to_file_map: dict[str, str]) -> List[Entry]:
|
||||
"Convert each plaintext entries into a dictionary"
|
||||
entries = []
|
||||
for file, entry in entry_to_file_map.items():
|
||||
@@ -87,8 +87,3 @@ class PlaintextToEntries(TextToEntries):
|
||||
)
|
||||
)
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def convert_entries_to_jsonl(entries: List[Entry]):
|
||||
"Convert each entry to JSON and collate as JSONL"
|
||||
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import repeat
|
||||
from typing import Any, Callable, List, Set, Tuple
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||
@@ -34,6 +36,27 @@ class TextToEntries(ABC):
|
||||
def hash_func(key: str) -> Callable:
|
||||
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def remove_long_words(text: str, max_word_length: int = 500) -> str:
|
||||
"Remove words longer than max_word_length from text."
|
||||
# Split the string by words, keeping the delimiters
|
||||
splits = re.split(r"(\s+)", text) + [""]
|
||||
words_with_delimiters = list(zip(splits[::2], splits[1::2]))
|
||||
|
||||
# Filter out long words while preserving delimiters in text
|
||||
filtered_text = [
|
||||
f"{word}{delimiter}"
|
||||
for word, delimiter in words_with_delimiters
|
||||
if not word.strip() or len(word.strip()) <= max_word_length
|
||||
]
|
||||
|
||||
return "".join(filtered_text)
|
||||
|
||||
@staticmethod
|
||||
def tokenizer(text: str) -> List[str]:
|
||||
"Tokenize text into words."
|
||||
return text.split()
|
||||
|
||||
@staticmethod
|
||||
def split_entries_by_max_tokens(
|
||||
entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
|
||||
@@ -44,24 +67,30 @@ class TextToEntries(ABC):
|
||||
if is_none_or_empty(entry.compiled):
|
||||
continue
|
||||
|
||||
# Split entry into words
|
||||
compiled_entry_words = [word for word in entry.compiled.split(" ") if word != ""]
|
||||
|
||||
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
|
||||
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length]
|
||||
# Split entry into chunks of max_tokens
|
||||
# Use chunking preference order: paragraphs > sentences > words > characters
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=max_tokens,
|
||||
separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
|
||||
keep_separator=True,
|
||||
length_function=lambda chunk: len(TextToEntries.tokenizer(chunk)),
|
||||
chunk_overlap=0,
|
||||
)
|
||||
chunked_entry_chunks = text_splitter.split_text(entry.compiled)
|
||||
corpus_id = uuid.uuid4()
|
||||
|
||||
# Split entry into chunks of max tokens
|
||||
for chunk_index in range(0, len(compiled_entry_words), max_tokens):
|
||||
compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
|
||||
compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
|
||||
|
||||
# Create heading prefixed entry from each chunk
|
||||
for chunk_index, compiled_entry_chunk in enumerate(chunked_entry_chunks):
|
||||
# Prepend heading to all other chunks, the first chunk already has heading from original entry
|
||||
if chunk_index > 0:
|
||||
if chunk_index > 0 and entry.heading:
|
||||
# Snip heading to avoid crossing max_tokens limit
|
||||
# Keep last 100 characters of heading as entry heading more important than filename
|
||||
snipped_heading = entry.heading[-100:]
|
||||
compiled_entry_chunk = f"{snipped_heading}.\n{compiled_entry_chunk}"
|
||||
# Prepend snipped heading
|
||||
compiled_entry_chunk = f"{snipped_heading}\n{compiled_entry_chunk}"
|
||||
|
||||
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
|
||||
compiled_entry_chunk = TextToEntries.remove_long_words(compiled_entry_chunk, max_word_length)
|
||||
|
||||
# Clean entry of unwanted characters like \0 character
|
||||
compiled_entry_chunk = TextToEntries.clean_field(compiled_entry_chunk)
|
||||
@@ -160,7 +189,7 @@ class TextToEntries(ABC):
|
||||
new_dates = []
|
||||
with timer("Indexed dates from added entries in", logger):
|
||||
for added_entry in added_entries:
|
||||
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.raw), repeat(added_entry))
|
||||
dates_in_entries = zip(self.date_filter.extract_dates(added_entry.compiled), repeat(added_entry))
|
||||
dates_to_create = [
|
||||
EntryDates(date=date, entry=added_entry)
|
||||
for date, added_entry in dates_in_entries
|
||||
@@ -244,11 +273,6 @@ class TextToEntries(ABC):
|
||||
|
||||
return entries_with_ids
|
||||
|
||||
@staticmethod
|
||||
def convert_text_maps_to_jsonl(entries: List[Entry]) -> str:
|
||||
# Convert each entry to JSON and write to JSONL file
|
||||
return "".join([f"{entry.to_json()}\n" for entry in entries])
|
||||
|
||||
@staticmethod
|
||||
def clean_field(field: str) -> str:
|
||||
return field.replace("\0", "") if not is_none_or_empty(field) else ""
|
||||
|
||||
@@ -489,7 +489,7 @@ async def chat(
|
||||
common: CommonQueryParams,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.18,
|
||||
d: Optional[float] = 0.22,
|
||||
stream: Optional[bool] = False,
|
||||
title: Optional[str] = None,
|
||||
conversation_id: Optional[int] = None,
|
||||
|
||||
Reference in New Issue
Block a user