Split text entries by max tokens supported by ML models

### Background
There is a limit on the maximum number of input tokens (word pieces) that an ML model can encode into an embedding vector.
For the models used for text search in khoj, a max token size of 256 is appropriate [1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1#:~:text=model%20was%20just%20trained%20on%20input%20text%20up%20to%20250%20word%20pieces), [2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2#:~:text=input%20text%20longer%20than%20256%20word%20pieces%20is%20truncated)
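
A minimal sketch of the truncation behaviour (assuming the `sentence-transformers` package is installed; values in the comments follow the cited model cards):

```python
# The model silently truncates input beyond its max sequence length, so only
# the head of a long entry contributes to its embedding.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
print(model.max_seq_length)              # hard cap in word pieces; trained on ~250 per the model card

long_entry = " ".join(["word"] * 1000)   # far longer than the model's limit
embedding = model.encode(long_entry)     # the tail of the entry is ignored
print(embedding.shape)                   # (384,) regardless of input length
```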

### Issue
Until now, entries exceeding the max token size would silently get truncated during embedding generation.
The truncated portion of an entry would therefore be ignored when matching queries with entries.
This would degrade the quality of the search results.

### Fix
- e057c8e Add method to split entries by specified max tokens limit
- Split entries by max tokens while converting [Org](https://github.com/debanjum/khoj/commit/c79919b), [Markdown](https://github.com/debanjum/khoj/commit/f209e30) and [Beancount](https://github.com/debanjum/khoj/commit/17fa123) entries to JSONL
- b283650 Deduplicate results for user query by raw text before returning results

### Results
- The quality of the search results should improve
- Relevant, long entries should show up in results more often
Authored by Debanjum, committed by GitHub on 2022-12-26 18:23:43 +00:00. 8 changed files with 102 additions and 3 deletions.


@@ -35,6 +35,12 @@ class BeancountToJsonl(TextToJsonl):
        end = time.time()
        logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")

        # Split entries by max tokens supported by model
        start = time.time()
        current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
        end = time.time()
        logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")

        # Identify, mark and merge any new entries with previous entries
        start = time.time()
        if not previous_entries:


@@ -35,6 +35,12 @@ class MarkdownToJsonl(TextToJsonl):
        end = time.time()
        logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")

        # Split entries by max tokens supported by model
        start = time.time()
        current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
        end = time.time()
        logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")

        # Identify, mark and merge any new entries with previous entries
        start = time.time()
        if not previous_entries:


@@ -41,7 +41,12 @@ class OrgToJsonl(TextToJsonl):
        start = time.time()
        current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
        end = time.time()
        logger.debug(f"Convert OrgNodes into list of entries: {end - start} seconds")

        start = time.time()
        current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
        end = time.time()
        logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")

        # Identify, mark and merge any new entries with previous entries
        if not previous_entries:


@@ -23,6 +23,19 @@ class TextToJsonl(ABC):
    def hash_func(key: str) -> Callable:
        return lambda entry: hashlib.md5(bytes(getattr(entry, key), encoding='utf-8')).hexdigest()

    @staticmethod
    def split_entries_by_max_tokens(entries: list[Entry], max_tokens: int=256) -> list[Entry]:
        "Split entries if compiled entry length exceeds the max tokens supported by the ML model."
        chunked_entries: list[Entry] = []
        for entry in entries:
            compiled_entry_words = entry.compiled.split()
            for chunk_index in range(0, len(compiled_entry_words), max_tokens):
                compiled_entry_words_chunk = compiled_entry_words[chunk_index:chunk_index + max_tokens]
                compiled_entry_chunk = ' '.join(compiled_entry_words_chunk)
                entry_chunk = Entry(compiled=compiled_entry_chunk, raw=entry.raw, file=entry.file)
                chunked_entries.append(entry_chunk)
        return chunked_entries

    def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
        # Hash all current and previous entries to identify new entries
        start = time.time()
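
As a usage sketch, the chunking arithmetic looks like this (the `Entry` dataclass below is a stand-in with just the fields the method touches, and `split_entries_by_max_tokens` mirrors the static method above; import paths are omitted):

```python
from dataclasses import dataclass

@dataclass
class Entry:
    compiled: str
    raw: str
    file: str

def split_entries_by_max_tokens(entries, max_tokens=256):
    # Same word-level chunking as TextToJsonl.split_entries_by_max_tokens above
    chunked_entries = []
    for entry in entries:
        words = entry.compiled.split()
        for chunk_index in range(0, len(words), max_tokens):
            chunk = ' '.join(words[chunk_index:chunk_index + max_tokens])
            chunked_entries.append(Entry(compiled=chunk, raw=entry.raw, file=entry.file))
    return chunked_entries

# A 600-word compiled entry splits into three chunks; every chunk keeps the
# original raw text and file, which is what the deduplication step relies on.
long_entry = Entry(compiled=" ".join(f"w{i}" for i in range(600)), raw="* Original note", file="notes.org")
chunks = split_entries_by_max_tokens([long_entry])
print([len(c.compiled.split()) for c in chunks])  # [256, 256, 88]
print({c.raw for c in chunks})                    # {'* Original note'}
```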


@@ -150,6 +150,17 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
    end = time.time()
    logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")

    # Deduplicate entries by raw entry text before showing to users.
    # Compiled entries are split by max tokens supported by ML models,
    # which can result in duplicate hits and entries being shown to the user.
    start = time.time()
    seen, original_hits_count = set(), len(hits)
    hits = [hit for hit in hits
            if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
    duplicate_hits = original_hits_count - len(hits)
    end = time.time()
    logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")

    return hits, entries
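
A minimal illustration of the order-preserving dedup idiom used above (hypothetical hits and raw texts; it assumes hits arrive sorted by score, so the highest-ranked chunk of each raw entry is the one kept):

```python
# set.add() returns None, so "not seen.add(raw)" always evaluates to True and
# only runs when the raw text has not been seen yet, keeping the first (best) hit.
hits = [
    {"corpus_id": 0, "score": 0.91},
    {"corpus_id": 1, "score": 0.87},   # chunk of the same raw entry as corpus_id 0
    {"corpus_id": 2, "score": 0.55},
]
raw_text = {0: "* Note A", 1: "* Note A", 2: "* Note B"}

seen = set()
deduped = [hit for hit in hits
           if raw_text[hit["corpus_id"]] not in seen
           and not seen.add(raw_text[hit["corpus_id"]])]
print([hit["corpus_id"] for hit in deduped])  # [0, 2]
```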