Deduplicate code by using a single function to process an org file into entries

Add type hints to the orgnode and org-to-entries modules
This commit is contained in:
Debanjum Singh Solanky
2024-02-11 00:34:04 +05:30
parent db2581459f
commit 44eab74888
3 changed files with 42 additions and 39 deletions

View File

@@ -1,10 +1,11 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Iterable, List, Tuple from typing import Dict, List, Tuple
from khoj.database.models import Entry as DbEntry from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
from khoj.processor.content.org_mode import orgnode from khoj.processor.content.org_mode import orgnode
from khoj.processor.content.org_mode.orgnode import Orgnode
from khoj.processor.content.text_to_entries import TextToEntries from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
@@ -51,7 +52,7 @@ class OrgToEntries(TextToEntries):
return num_new_embeddings, num_deleted_embeddings return num_new_embeddings, num_deleted_embeddings
@staticmethod @staticmethod
def extract_org_entries(org_files: dict[str, str], index_heading_entries: bool = False): def extract_org_entries(org_files: dict[str, str], index_heading_entries: bool = False) -> List[Entry]:
"Extract entries from specified Org files" "Extract entries from specified Org files"
with timer("Parse entries from org files into OrgNode objects", logger): with timer("Parse entries from org files into OrgNode objects", logger):
entry_nodes, file_to_entries = OrgToEntries.extract_org_nodes(org_files) entry_nodes, file_to_entries = OrgToEntries.extract_org_nodes(org_files)
@@ -60,35 +61,35 @@ class OrgToEntries(TextToEntries):
return OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries) return OrgToEntries.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
@staticmethod @staticmethod
def extract_org_nodes(org_files: dict[str, str]): def extract_org_nodes(org_files: dict[str, str]) -> Tuple[List[Orgnode], Dict[Orgnode, str]]:
"Extract org nodes from specified org files" "Extract org nodes from specified org files"
entry_nodes = [] entry_nodes: List[Orgnode] = []
entry_to_file_map: List[Tuple[orgnode.Orgnode, str]] = [] entry_to_file_map: List[Tuple[Orgnode, str]] = []
for org_file in org_files: for org_file in org_files:
filename = org_file org_content = org_files[org_file]
file = org_files[org_file] entry_nodes, entry_to_file_map = OrgToEntries.process_single_org_file(
try: org_content, org_file, entry_nodes, entry_to_file_map
org_file_entries = orgnode.makelist(file, filename) )
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entry_nodes.extend(org_file_entries)
except Exception as e:
logger.warning(f"Unable to process file: {org_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
return entry_nodes, dict(entry_to_file_map) return entry_nodes, dict(entry_to_file_map)
@staticmethod @staticmethod
def process_single_org_file(org_content: str, org_file: str, entries: List, entry_to_file_map: List): def process_single_org_file(
org_content: str,
org_file: str,
entries: List[Orgnode],
entry_to_file_map: List[Tuple[Orgnode, str]],
) -> Tuple[List[Orgnode], List[Tuple[Orgnode, str]]]:
# Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer. # Process single org file. The org parser assumes that the file is a single org file and reads it from a buffer.
# We'll split the raw content of this file by new line to mimic the same behavior. # We'll split the raw content of this file by new line to mimic the same behavior.
try: try:
org_file_entries = orgnode.makelist(org_content, org_file) org_file_entries = orgnode.makelist(org_content, org_file)
entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries)) entry_to_file_map += zip(org_file_entries, [org_file] * len(org_file_entries))
entries.extend(org_file_entries) entries.extend(org_file_entries)
return entries, entry_to_file_map
except Exception as e: except Exception as e:
logger.error(f"Error processing file: {org_file} with error: {e}", exc_info=True) logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError; {e}", exc_info=True)
return entries, entry_to_file_map
return entries, entry_to_file_map
@staticmethod @staticmethod
def convert_org_nodes_to_entries( def convert_org_nodes_to_entries(

View File

@@ -37,7 +37,7 @@ import datetime
import re import re
from os.path import relpath from os.path import relpath
from pathlib import Path from pathlib import Path
from typing import List from typing import Dict, List, Tuple
indent_regex = re.compile(r"^ *") indent_regex = re.compile(r"^ *")
@@ -58,7 +58,7 @@ def makelist_with_filepath(filename):
return makelist(f, filename) return makelist(f, filename)
def makelist(file, filename): def makelist(file, filename) -> List["Orgnode"]:
""" """
Read an org-mode file and return a list of Orgnode objects Read an org-mode file and return a list of Orgnode objects
created from this file. created from this file.
@@ -80,16 +80,16 @@ def makelist(file, filename):
} # populated from #+SEQ_TODO line } # populated from #+SEQ_TODO line
level = "" level = ""
heading = "" heading = ""
ancestor_headings = [] ancestor_headings: List[str] = []
bodytext = "" bodytext = ""
introtext = "" introtext = ""
tags = list() # set of all tags in headline tags: List[str] = list() # set of all tags in headline
closed_date = "" closed_date: datetime.date = None
sched_date = "" sched_date: datetime.date = None
deadline_date = "" deadline_date: datetime.date = None
logbook = list() logbook: List[Tuple[datetime.datetime, datetime.datetime]] = list()
nodelist: List[Orgnode] = list() nodelist: List[Orgnode] = list()
property_map = dict() property_map: Dict[str, str] = dict()
in_properties_drawer = False in_properties_drawer = False
in_logbook_drawer = False in_logbook_drawer = False
file_title = f"{filename}" file_title = f"{filename}"
@@ -102,13 +102,13 @@ def makelist(file, filename):
thisNode = Orgnode(level, heading, bodytext, tags, ancestor_headings) thisNode = Orgnode(level, heading, bodytext, tags, ancestor_headings)
if closed_date: if closed_date:
thisNode.closed = closed_date thisNode.closed = closed_date
closed_date = "" closed_date = None
if sched_date: if sched_date:
thisNode.scheduled = sched_date thisNode.scheduled = sched_date
sched_date = "" sched_date = None
if deadline_date: if deadline_date:
thisNode.deadline = deadline_date thisNode.deadline = deadline_date
deadline_date = "" deadline_date = None
if logbook: if logbook:
thisNode.logbook = logbook thisNode.logbook = logbook
logbook = list() logbook = list()
@@ -116,7 +116,7 @@ def makelist(file, filename):
nodelist.append(thisNode) nodelist.append(thisNode)
property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"} property_map = {"LINE": f"file:{normalize_filename(filename)}::{ctr}"}
previous_level = level previous_level = level
previous_heading = heading previous_heading: str = heading
level = heading_search.group(1) level = heading_search.group(1)
heading = heading_search.group(2) heading = heading_search.group(2)
bodytext = "" bodytext = ""

View File

@@ -37,8 +37,8 @@ def test_configure_heading_entry_to_jsonl(tmp_path):
assert is_none_or_empty(entries) assert is_none_or_empty(entries)
def test_entry_split_when_exceeds_max_words(): def test_entry_split_when_exceeds_max_tokens():
"Ensure entries with compiled words exceeding max_words are split." "Ensure entries with compiled words exceeding max_tokens are split."
# Arrange # Arrange
tmp_path = "/tmp/test.org" tmp_path = "/tmp/test.org"
entry = f"""*** Heading entry = f"""*** Heading
@@ -81,7 +81,7 @@ def test_entry_split_drops_large_words():
assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 1 assert len(processed_entry.compiled.split()) == len(entry_text.split()) - 1
def test_entry_with_body_to_jsonl(tmp_path): def test_entry_with_body_to_entry(tmp_path):
"Ensure entries with valid body text are loaded." "Ensure entries with valid body text are loaded."
# Arrange # Arrange
entry = f"""*** Heading entry = f"""*** Heading
@@ -97,13 +97,13 @@ def test_entry_with_body_to_jsonl(tmp_path):
# Act # Act
# Extract Entries from specified Org files # Extract Entries from specified Org files
entries = OrgToEntries.extract_org_entries(org_files=data) entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=3)
# Assert # Assert
assert len(entries) == 1 assert len(entries) == 1
def test_file_with_entry_after_intro_text_to_jsonl(tmp_path): def test_file_with_entry_after_intro_text_to_entry(tmp_path):
"Ensure intro text before any headings is indexed." "Ensure intro text before any headings is indexed."
# Arrange # Arrange
entry = f""" entry = f"""
@@ -188,7 +188,8 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Arrange # Arrange
entry = f""" entry = f"""
* Heading 1 * Heading 1
** Heading 2 ** Sub-Heading 1.1
* Heading 2
""" """
data = { data = {
f"{tmp_path}": entry, f"{tmp_path}": entry,
@@ -199,9 +200,10 @@ def test_extract_entries_with_different_level_headings(tmp_path):
entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True) entries = OrgToEntries.extract_org_entries(org_files=data, index_heading_entries=True)
# Assert # Assert
assert len(entries) == 2 assert len(entries) == 3
assert f"{entries[0].raw}".startswith("* Heading 1") assert f"{entries[0].raw}".startswith("* Heading 1")
assert f"{entries[1].raw}".startswith("** Heading 2") assert f"{entries[1].raw}".startswith("** Sub-Heading 1.1")
assert f"{entries[2].raw}".startswith("* Heading 2")
# Helper Functions # Helper Functions