Rename TextEmbeddings to TextEntries for improved readability

Improves readability, as the name more closely matches the
underlying constructs
This commit is contained in:
Debanjum Singh Solanky
2023-10-31 18:55:59 -07:00
parent bcbee05a9e
commit 87e6b1eab9
9 changed files with 24 additions and 24 deletions

View File

@@ -12,14 +12,14 @@ from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEntries
from database.models import Entry as DbEntry, GithubConfig, KhojUser from database.models import Entry as DbEntry, GithubConfig, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GithubToJsonl(TextEmbeddings): class GithubToJsonl(TextEntries):
def __init__(self, config: GithubConfig): def __init__(self, config: GithubConfig):
super().__init__(config) super().__init__(config)
raw_repos = config.githubrepoconfig.all() raw_repos = config.githubrepoconfig.all()
@@ -94,7 +94,7 @@ class GithubToJsonl(TextEmbeddings):
current_entries += issue_entries current_entries += issue_entries
with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger): with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
current_entries = TextEmbeddings.split_entries_by_max_tokens(current_entries, max_tokens=256) current_entries = TextEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)
return current_entries return current_entries

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from typing import Tuple, List from typing import Tuple, List
# Internal Packages # Internal Packages
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEntries
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.constants import empty_escape_sequences from khoj.utils.constants import empty_escape_sequences
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
@@ -16,7 +16,7 @@ from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MarkdownToJsonl(TextEmbeddings): class MarkdownToJsonl(TextEntries):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@@ -8,7 +8,7 @@ import requests
# Internal Packages # Internal Packages
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, NotionContentConfig from khoj.utils.rawconfig import Entry, NotionContentConfig
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEntries
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from database.models import Entry as DbEntry, KhojUser, NotionConfig from database.models import Entry as DbEntry, KhojUser, NotionConfig
@@ -50,7 +50,7 @@ class NotionBlockType(Enum):
CALLOUT = "callout" CALLOUT = "callout"
class NotionToJsonl(TextEmbeddings): class NotionToJsonl(TextEntries):
def __init__(self, config: NotionConfig): def __init__(self, config: NotionConfig):
super().__init__(config) super().__init__(config)
self.config = NotionContentConfig( self.config = NotionContentConfig(

View File

@@ -5,7 +5,7 @@ from typing import Iterable, List, Tuple
# Internal Packages # Internal Packages
from khoj.processor.org_mode import orgnode from khoj.processor.org_mode import orgnode
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEntries
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from khoj.utils import state from khoj.utils import state
@@ -15,7 +15,7 @@ from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OrgToJsonl(TextEmbeddings): class OrgToJsonl(TextEntries):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@@ -8,7 +8,7 @@ import base64
from langchain.document_loaders import PyMuPDFLoader from langchain.document_loaders import PyMuPDFLoader
# Internal Packages # Internal Packages
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEntries
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from database.models import Entry as DbEntry, KhojUser from database.models import Entry as DbEntry, KhojUser
@@ -17,7 +17,7 @@ from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PdfToJsonl(TextEmbeddings): class PdfToJsonl(TextEntries):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@@ -6,7 +6,7 @@ from bs4 import BeautifulSoup
# Internal Packages # Internal Packages
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEntries
from khoj.utils.helpers import timer from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from database.models import Entry as DbEntry, KhojUser from database.models import Entry as DbEntry, KhojUser
@@ -15,7 +15,7 @@ from database.models import Entry as DbEntry, KhojUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PlaintextToJsonl(TextEmbeddings): class PlaintextToJsonl(TextEntries):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@@ -19,7 +19,7 @@ from database.adapters import EntryAdapters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TextEmbeddings(ABC): class TextEntries(ABC):
def __init__(self, config: Any = None): def __init__(self, config: Any = None):
self.embeddings_model = EmbeddingsModel() self.embeddings_model = EmbeddingsModel()
self.config = config self.config = config
@@ -85,10 +85,10 @@ class TextEmbeddings(ABC):
): ):
with timer("Construct current entry hashes", logger): with timer("Construct current entry hashes", logger):
hashes_by_file = dict[str, set[str]]() hashes_by_file = dict[str, set[str]]()
current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries)) current_entry_hashes = list(map(TextEntries.hash_func(key), current_entries))
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
for entry in tqdm(current_entries, desc="Hashing Entries"): for entry in tqdm(current_entries, desc="Hashing Entries"):
hashes_by_file.setdefault(entry.file, set()).add(TextEmbeddings.hash_func(key)(entry)) hashes_by_file.setdefault(entry.file, set()).add(TextEntries.hash_func(key)(entry))
num_deleted_embeddings = 0 num_deleted_embeddings = 0
with timer("Preparing dataset for regeneration", logger): with timer("Preparing dataset for regeneration", logger):
@@ -180,11 +180,11 @@ class TextEmbeddings(ABC):
): ):
# Hash all current and previous entries to identify new entries # Hash all current and previous entries to identify new entries
with timer("Hash previous, current entries", logger): with timer("Hash previous, current entries", logger):
current_entry_hashes = list(map(TextEmbeddings.hash_func(key), current_entries)) current_entry_hashes = list(map(TextEntries.hash_func(key), current_entries))
previous_entry_hashes = list(map(TextEmbeddings.hash_func(key), previous_entries)) previous_entry_hashes = list(map(TextEntries.hash_func(key), previous_entries))
if deletion_filenames is not None: if deletion_filenames is not None:
deletion_entries = [entry for entry in previous_entries if entry.file in deletion_filenames] deletion_entries = [entry for entry in previous_entries if entry.file in deletion_filenames]
deletion_entry_hashes = list(map(TextEmbeddings.hash_func(key), deletion_entries)) deletion_entry_hashes = list(map(TextEntries.hash_func(key), deletion_entries))
else: else:
deletion_entry_hashes = [] deletion_entry_hashes = []

View File

@@ -18,7 +18,7 @@ from khoj.utils.models import BaseEncoder
from khoj.utils.state import SearchType from khoj.utils.state import SearchType
from khoj.utils.rawconfig import SearchResponse, Entry from khoj.utils.rawconfig import SearchResponse, Entry
from khoj.utils.jsonl import load_jsonl from khoj.utils.jsonl import load_jsonl
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEntries
from database.adapters import EntryAdapters from database.adapters import EntryAdapters
from database.models import KhojUser, Entry as DbEntry from database.models import KhojUser, Entry as DbEntry
@@ -188,7 +188,7 @@ def rerank_and_sort_results(hits, query):
def setup( def setup(
text_to_jsonl: Type[TextEmbeddings], text_to_jsonl: Type[TextEntries],
files: dict[str, str], files: dict[str, str],
regenerate: bool, regenerate: bool,
full_corpus: bool = True, full_corpus: bool = True,

View File

@@ -4,7 +4,7 @@ import os
# Internal Packages # Internal Packages
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.text_to_jsonl import TextEmbeddings from khoj.processor.text_to_jsonl import TextEntries
from khoj.utils.helpers import is_none_or_empty from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import Entry from khoj.utils.rawconfig import Entry
from khoj.utils.fs_syncer import get_org_files from khoj.utils.fs_syncer import get_org_files
@@ -63,7 +63,7 @@ def test_entry_split_when_exceeds_max_words(tmp_path):
# Split each entry from specified Org files by max words # Split each entry from specified Org files by max words
jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl( jsonl_string = OrgToJsonl.convert_org_entries_to_jsonl(
TextEmbeddings.split_entries_by_max_tokens( TextEntries.split_entries_by_max_tokens(
OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4 OrgToJsonl.convert_org_nodes_to_entries(entries, entry_to_file_map), max_tokens=4
) )
) )
@@ -86,7 +86,7 @@ def test_entry_split_drops_large_words():
# Act # Act
# Split entry by max words and drop words larger than max word length # Split entry by max words and drop words larger than max word length
processed_entry = TextEmbeddings.split_entries_by_max_tokens([entry], max_word_length=5)[0] processed_entry = TextEntries.split_entries_by_max_tokens([entry], max_word_length=5)[0]
# Assert # Assert
# "Heading" dropped from compiled version because its over the set max word limit # "Heading" dropped from compiled version because its over the set max word limit