Chunk text in preference order of para, sentence, word, character

- Previous simplistic chunking strategy of splitting text by space
  didn't handle notes containing newlines or no spaces, e.g. in #620

- New strategy will try to chunk the text at more natural points like
  paragraph, sentence and word boundaries first. If none of those work
  it'll split at the character level to fit within the max token limit

- Drop long words while preserving original delimiters

Resolves #620
This commit is contained in:
Debanjum Singh Solanky
2024-01-29 05:03:29 +05:30
parent a627f56a64
commit 86575b2946
3 changed files with 46 additions and 17 deletions

View File

@@ -1,10 +1,12 @@
import hashlib import hashlib
import logging import logging
import re
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import repeat from itertools import repeat
from typing import Any, Callable, List, Set, Tuple from typing import Any, Callable, List, Set, Tuple
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tqdm import tqdm from tqdm import tqdm
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
@@ -34,6 +36,27 @@ class TextToEntries(ABC):
def hash_func(key: str) -> Callable: def hash_func(key: str) -> Callable:
return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest() return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding="utf-8")).hexdigest()
@staticmethod
def remove_long_words(text: str, max_word_length: int = 500) -> str:
"Remove words longer than max_word_length from text."
# Split the string by words, keeping the delimiters
splits = re.split(r"(\s+)", text) + [""]
words_with_delimiters = list(zip(splits[::2], splits[1::2]))
# Filter out long words while preserving delimiters in text
filtered_text = [
f"{word}{delimiter}"
for word, delimiter in words_with_delimiters
if not word.strip() or len(word.strip()) <= max_word_length
]
return "".join(filtered_text)
@staticmethod
def tokenizer(text: str) -> List[str]:
"Tokenize text into words."
return text.split()
@staticmethod @staticmethod
def split_entries_by_max_tokens( def split_entries_by_max_tokens(
entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500 entries: List[Entry], max_tokens: int = 256, max_word_length: int = 500
@@ -44,24 +67,30 @@ class TextToEntries(ABC):
if is_none_or_empty(entry.compiled): if is_none_or_empty(entry.compiled):
continue continue
# Split entry into words # Split entry into chunks of max_tokens
compiled_entry_words = [word for word in entry.compiled.split(" ") if word != ""] # Use chunking preference order: paragraphs > sentences > words > characters
text_splitter = RecursiveCharacterTextSplitter(
# Drop long words instead of having entry truncated to maintain quality of entry processed by models chunk_size=max_tokens,
compiled_entry_words = [word for word in compiled_entry_words if len(word) <= max_word_length] separators=["\n\n", "\n", "!", "?", ".", " ", "\t", ""],
keep_separator=True,
length_function=lambda chunk: len(TextToEntries.tokenizer(chunk)),
chunk_overlap=0,
)
chunked_entry_chunks = text_splitter.split_text(entry.compiled)
corpus_id = uuid.uuid4() corpus_id = uuid.uuid4()
# Split entry into chunks of max tokens # Create heading prefixed entry from each chunk
for chunk_index in range(0, len(compiled_entry_words), max_tokens): for chunk_index, compiled_entry_chunk in enumerate(chunked_entry_chunks):
compiled_entry_words_chunk = compiled_entry_words[chunk_index : chunk_index + max_tokens]
compiled_entry_chunk = " ".join(compiled_entry_words_chunk)
# Prepend heading to all other chunks, the first chunk already has heading from original entry # Prepend heading to all other chunks, the first chunk already has heading from original entry
if chunk_index > 0: if chunk_index > 0 and entry.heading:
# Snip heading to avoid crossing max_tokens limit # Snip heading to avoid crossing max_tokens limit
# Keep last 100 characters of heading as entry heading more important than filename # Keep last 100 characters of heading as entry heading more important than filename
snipped_heading = entry.heading[-100:] snipped_heading = entry.heading[-100:]
compiled_entry_chunk = f"{snipped_heading}.\n{compiled_entry_chunk}" # Prepend snipped heading
compiled_entry_chunk = f"{snipped_heading}\n{compiled_entry_chunk}"
# Drop long words instead of having entry truncated to maintain quality of entry processed by models
compiled_entry_chunk = TextToEntries.remove_long_words(compiled_entry_chunk, max_word_length)
# Clean entry of unwanted characters like \0 character # Clean entry of unwanted characters like \0 character
compiled_entry_chunk = TextToEntries.clean_field(compiled_entry_chunk) compiled_entry_chunk = TextToEntries.clean_field(compiled_entry_chunk)

View File

@@ -54,12 +54,12 @@ def test_entry_split_when_exceeds_max_words():
# Extract Entries from specified Org files # Extract Entries from specified Org files
entries = OrgToEntries.extract_org_entries(org_files=data) entries = OrgToEntries.extract_org_entries(org_files=data)
# Split each entry from specified Org files by max words # Split each entry from specified Org files by max tokens
entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=4) entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=6)
# Assert # Assert
assert len(entries) == 2 assert len(entries) == 2
# Ensure compiled entries split by max_words start with entry heading (for search context) # Ensure compiled entries split by max tokens start with entry heading (for search context)
assert all([entry.compiled.startswith(expected_heading) for entry in entries]) assert all([entry.compiled.startswith(expected_heading) for entry in entries])

View File

@@ -192,7 +192,7 @@ def test_entry_chunking_by_max_tokens(org_config_with_only_new_file: LocalOrgCon
# Assert # Assert
assert ( assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens" ), "new entry not split by max tokens"
@@ -250,7 +250,7 @@ conda activate khoj
# Assert # Assert
assert ( assert (
"Deleted 0 entries. Created 2 new entries for user " in caplog.records[-1].message "Deleted 0 entries. Created 3 new entries for user " in caplog.records[-1].message
), "new entry not split by max tokens" ), "new entry not split by max tokens"