From 7e9298f31576ebb2b18f85369c8b0e8d88a2f0c7 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 15 Sep 2022 23:34:43 +0300 Subject: [PATCH] Use new Text Entry class to track text entries in Intermediate Format - Context - The app maintains all text content in a standard, intermediate format - The intermediate format was loaded, passed around as a dictionary for easier, faster updates to the intermediate format schema initially - The intermediate format is reasonably stable now, given its usage by all 3 text content types currently implemented - Changes - Concretize text entries into `Entry' class instead of using dictionaries - Code is updated to load, pass around entries as `Entry' objects instead of as dictionaries - `text_search' and `text_to_jsonl' methods are annotated with type hints for the new `Entry' type - Code and Tests referencing entries are updated to use class style access patterns instead of the previous dictionary access patterns - Move `mark_entries_for_update' method into `TextToJsonl' base class - This is a more natural location for the method as it is only (to be) used by `text_to_jsonl' classes - Avoid circular reference issues on importing `Entry' class --- src/processor/ledger/beancount_to_jsonl.py | 26 ++++----- src/processor/markdown/markdown_to_jsonl.py | 24 ++++---- src/processor/org_mode/org_to_jsonl.py | 65 ++++++++++----------- src/processor/text_to_jsonl.py | 46 ++++++++++++++- src/search_filter/date_filter.py | 2 +- src/search_filter/file_filter.py | 2 +- src/search_filter/word_filter.py | 2 +- src/search_type/text_search.py | 26 ++++----- src/utils/helpers.py | 37 ------------ src/utils/rawconfig.py | 27 ++++++++- tests/test_date_filter.py | 9 +-- tests/test_file_filter.py | 12 ++-- tests/test_image_search.py | 2 +- tests/test_text_search.py | 2 +- tests/test_word_filter.py | 10 ++-- 15 files changed, 161 insertions(+), 131 deletions(-) diff --git a/src/processor/ledger/beancount_to_jsonl.py 
b/src/processor/ledger/beancount_to_jsonl.py index d54b7e1b..ccad97da 100644 --- a/src/processor/ledger/beancount_to_jsonl.py +++ b/src/processor/ledger/beancount_to_jsonl.py @@ -1,5 +1,4 @@ # Standard Packages -import json import glob import re import logging @@ -7,9 +6,10 @@ import time # Internal Packages from src.processor.text_to_jsonl import TextToJsonl -from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update +from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.constants import empty_escape_sequences from src.utils.jsonl import dump_jsonl, compress_jsonl_data +from src.utils.rawconfig import Entry logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class BeancountToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) end = time.time() logger.debug(f"Identify new or updated transaction: {end - start} seconds") @@ -111,17 +111,17 @@ class BeancountToJsonl(TextToJsonl): return entries, dict(transaction_to_file_map) @staticmethod - def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]: - "Convert each Beancount transaction into a dictionary" - entry_maps = [] - for entry in entries: - entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'}) + def convert_transactions_to_maps(parsed_entries: list[str], transaction_to_file_map) -> list[Entry]: + "Convert each parsed Beancount transaction into a Entry" + entries = [] + for parsed_entry in parsed_entries: + entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}')) - logger.info(f"Converted {len(entries)} transactions to dictionaries") + 
logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries") - return entry_maps + return entries @staticmethod - def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str: - "Convert each Beancount transaction dictionary to JSON and collate as JSONL" - return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) + def convert_transaction_maps_to_jsonl(entries: list[Entry]) -> str: + "Convert each Beancount transaction entry to JSON and collate as JSONL" + return ''.join([f'{entry.to_json()}\n' for entry in entries]) diff --git a/src/processor/markdown/markdown_to_jsonl.py b/src/processor/markdown/markdown_to_jsonl.py index 48fbbdf9..5c4d660d 100644 --- a/src/processor/markdown/markdown_to_jsonl.py +++ b/src/processor/markdown/markdown_to_jsonl.py @@ -1,5 +1,4 @@ # Standard Packages -import json import glob import re import logging @@ -7,9 +6,10 @@ import time # Internal Packages from src.processor.text_to_jsonl import TextToJsonl -from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update +from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.constants import empty_escape_sequences from src.utils.jsonl import dump_jsonl, compress_jsonl_data +from src.utils.rawconfig import Entry logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class MarkdownToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) end = time.time() logger.debug(f"Identify new or updated entries: {end - start} seconds") @@ -110,17 +110,17 @@ class MarkdownToJsonl(TextToJsonl): return entries, dict(entry_to_file_map) @staticmethod - def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> 
list[dict]: + def convert_markdown_entries_to_maps(parsed_entries: list[str], entry_to_file_map) -> list[Entry]: "Convert each Markdown entries into a dictionary" - entry_maps = [] - for entry in entries: - entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{entry_to_file_map[entry]}'}) + entries = [] + for parsed_entry in parsed_entries: + entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}')) - logger.info(f"Converted {len(entries)} markdown entries to dictionaries") + logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries") - return entry_maps + return entries @staticmethod - def convert_markdown_maps_to_jsonl(entries): - "Convert each Markdown entries to JSON and collate as JSONL" - return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) + def convert_markdown_maps_to_jsonl(entries: list[Entry]): + "Convert each Markdown entry to JSON and collate as JSONL" + return ''.join([f'{entry.to_json()}\n' for entry in entries]) diff --git a/src/processor/org_mode/org_to_jsonl.py b/src/processor/org_mode/org_to_jsonl.py index c4c18ce9..52441a99 100644 --- a/src/processor/org_mode/org_to_jsonl.py +++ b/src/processor/org_mode/org_to_jsonl.py @@ -1,5 +1,4 @@ # Standard Packages -import json import glob import logging import time @@ -8,8 +7,9 @@ from typing import Iterable # Internal Packages from src.processor.org_mode import orgnode from src.processor.text_to_jsonl import TextToJsonl -from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update +from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.jsonl import dump_jsonl, compress_jsonl_data +from src.utils.rawconfig import Entry from src.utils import state @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) class OrgToJsonl(TextToJsonl): # Define Functions - def process(self, previous_entries=None): + def process(self, previous_entries: 
list[Entry]=None): # Extract required fields from config org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl index_heading_entries = self.config.index_heading_entries @@ -47,7 +47,7 @@ class OrgToJsonl(TextToJsonl): if not previous_entries: entries_with_ids = list(enumerate(current_entries)) else: - entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) + entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) # Process Each Entry from All Notes Files start = time.time() @@ -104,51 +104,48 @@ class OrgToJsonl(TextToJsonl): return entries, dict(entry_to_file_map) @staticmethod - def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]: - "Convert Org-Mode entries into list of dictionary" - entry_maps = [] - for entry in entries: - entry_dict = dict() - - if not entry.hasBody and not index_heading_entries: + def convert_org_nodes_to_entries(parsed_entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[Entry]: + "Convert Org-Mode nodes into list of Entry objects" + entries: list[Entry] = [] + for parsed_entry in parsed_entries: + if not parsed_entry.hasBody and not index_heading_entries: # Ignore title notes i.e notes with just headings and empty body continue - entry_dict["compiled"] = f'{entry.heading}.' + compiled = f'{parsed_entry.heading}.' if state.verbose > 2: - logger.debug(f"Title: {entry.heading}") + logger.debug(f"Title: {parsed_entry.heading}") - if entry.tags: - tags_str = " ".join(entry.tags) - entry_dict["compiled"] += f'\t {tags_str}.' + if parsed_entry.tags: + tags_str = " ".join(parsed_entry.tags) + compiled += f'\t {tags_str}.' 
if state.verbose > 2: logger.debug(f"Tags: {tags_str}") - if entry.closed: - entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.' + if parsed_entry.closed: + compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.' if state.verbose > 2: - logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}') + logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}') - if entry.scheduled: - entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.' + if parsed_entry.scheduled: + compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.' if state.verbose > 2: - logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}') + logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}') - if entry.hasBody: - entry_dict["compiled"] += f'\n {entry.body}' + if parsed_entry.hasBody: + compiled += f'\n {parsed_entry.body}' if state.verbose > 2: - logger.debug(f"Body: {entry.body}") + logger.debug(f"Body: {parsed_entry.body}") - if entry_dict: - entry_dict["raw"] = f'{entry}' - entry_dict["file"] = f'{entry_to_file_map[entry]}' + if compiled: + entries += [Entry( + compiled=compiled, + raw=f'{parsed_entry}', + file=f'{entry_to_file_map[parsed_entry]}')] - # Convert Dictionary to JSON and Append to JSONL string - entry_maps.append(entry_dict) - - return entry_maps + return entries @staticmethod - def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str: + def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str: "Convert each Org-Mode entry to JSON and collate as JSONL" - return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) + return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries]) diff --git a/src/processor/text_to_jsonl.py b/src/processor/text_to_jsonl.py index e59c5fb1..a8153f52 100644 --- a/src/processor/text_to_jsonl.py +++ b/src/processor/text_to_jsonl.py @@ -1,9 +1,14 @@ # Standard Packages 
from abc import ABC, abstractmethod -from typing import Iterable +import hashlib +import time +import logging # Internal Packages -from src.utils.rawconfig import TextContentConfig +from src.utils.rawconfig import Entry, TextContentConfig + + +logger = logging.getLogger(__name__) class TextToJsonl(ABC): @@ -11,4 +16,39 @@ class TextToJsonl(ABC): self.config = config @abstractmethod - def process(self, previous_entries: Iterable[tuple[int, dict]]=None) -> list[tuple[int, dict]]: ... + def process(self, previous_entries: list[Entry]=None) -> list[tuple[int, Entry]]: ... + + def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]: + # Hash all current and previous entries to identify new entries + start = time.time() + current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), current_entries)) + previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), previous_entries)) + end = time.time() + logger.debug(f"Hash previous, current entries: {end - start} seconds") + + start = time.time() + hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) + hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries)) + + # All entries that did not exist in the previous set are to be added + new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes) + # All entries that exist in both current and previous sets are kept + existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) + + # Mark new entries with -1 id to flag for later embeddings generation + new_entries = [ + (-1, hash_to_current_entries[entry_hash]) + for entry_hash in new_entry_hashes + ] + # Set id of existing entries to their previous ids to reuse their existing encoded embeddings + existing_entries = [ + (previous_entry_hashes.index(entry_hash), 
hash_to_previous_entries[entry_hash]) + for entry_hash in existing_entry_hashes + ] + + existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) + entries_with_ids = existing_entries_sorted + new_entries + end = time.time() + logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds") + + return entries_with_ids \ No newline at end of file diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 22a66068..00b829ac 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -37,7 +37,7 @@ class DateFilter(BaseFilter): start = time.time() for id, entry in enumerate(entries): # Extract dates from entry - for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): + for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)): # Convert date string in entry to unix timestamp try: date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 41f80274..84b520c0 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -24,7 +24,7 @@ class FileFilter(BaseFilter): def load(self, entries, *args, **kwargs): start = time.time() for id, entry in enumerate(entries): - self.file_to_entry_map[entry[self.entry_key]].add(id) + self.file_to_entry_map[getattr(entry, self.entry_key)].add(id) end = time.time() logger.debug(f"Created file filter index: {end - start} seconds") diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index e040ceee..ff9f9ee5 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -29,7 +29,7 @@ class WordFilter(BaseFilter): entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' # Create map of words to entries they exist in for entry_index, entry in 
enumerate(entries): - for word in re.split(entry_splitter, entry[self.entry_key].lower()): + for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()): if word == '': continue self.word_to_entry_index[word].add(entry_index) diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 009f39b9..8b29c517 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter from src.utils import state from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model from src.utils.config import TextSearchModel -from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig +from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry from src.utils.jsonl import load_jsonl @@ -50,12 +50,12 @@ def initialize_model(search_config: TextSearchConfig): return bi_encoder, cross_encoder, top_k -def extract_entries(jsonl_file): +def extract_entries(jsonl_file) -> list[Entry]: "Load entries from compressed jsonl" - return load_jsonl(jsonl_file) + return list(map(Entry.from_dict, load_jsonl(jsonl_file))) -def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False): +def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" new_entries = [] # Load pre-computed embeddings from file if exists and update them if required @@ -64,15 +64,15 @@ def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate logger.info(f"Loaded embeddings from {embeddings_file}") # Encode any new entries in the corpus and update corpus embeddings - new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None] + new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1] if new_entries: new_embeddings 
= bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) - existing_entry_ids = [id for id, _ in entries_with_ids if id is not None] + existing_entry_ids = [id for id, _ in entries_with_ids if id != -1] existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor() corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) # Else compute the corpus embeddings from scratch else: - new_entries = [entry['compiled'] for _, entry in entries_with_ids] + new_entries = [entry.compiled for _, entry in entries_with_ids] corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) # Save regenerated or updated embeddings to file @@ -133,7 +133,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): # Score all retrieved entries using the cross-encoder if rank_results: start = time.time() - cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits] + cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits] cross_scores = model.cross_encoder.predict(cross_inp) end = time.time() logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}") @@ -153,7 +153,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False): return hits, entries -def render_results(hits, entries, count=5, display_biencoder_results=False): +def render_results(hits, entries: list[Entry], count=5, display_biencoder_results=False): "Render the Results returned by Search for the Query" if display_biencoder_results: # Output of top hits from bi-encoder @@ -161,20 +161,20 @@ def render_results(hits, entries, count=5, display_biencoder_results=False): print(f"Top-{count} Bi-Encoder Retrieval hits") hits = sorted(hits, key=lambda x: x['score'], reverse=True) for hit in hits[0:count]: - print(f"Score: 
{hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}") + print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']].compiled}") # Output of top hits from re-ranker print("\n-------------------------\n") print(f"Top-{count} Cross-Encoder Re-ranker hits") hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) for hit in hits[0:count]: - print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}") + print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']].compiled}") -def collate_results(hits, entries, count=5) -> list[SearchResponse]: +def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]: return [SearchResponse.parse_obj( { - "entry": entries[hit['corpus_id']]['raw'], + "entry": entries[hit['corpus_id']].raw, "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}" }) for hit diff --git a/src/utils/helpers.py b/src/utils/helpers.py index df1899f9..8425a8fa 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -1,8 +1,6 @@ # Standard Packages from pathlib import Path import sys -import time -import hashlib from os.path import join from collections import OrderedDict from typing import Optional, Union @@ -83,38 +81,3 @@ class LRU(OrderedDict): oldest = next(iter(self)) del self[oldest] - -def mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=None): - # Hash all current and previous entries to identify new entries - start = time.time() - current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), current_entries)) - previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), previous_entries)) - end = time.time() - logger.debug(f"Hash previous, current entries: {end - start} seconds") - - start = time.time() - hash_to_current_entries = dict(zip(current_entry_hashes, 
current_entries)) - hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries)) - - # All entries that did not exist in the previous set are to be added - new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes) - # All entries that exist in both current and previous sets are kept - existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) - - # Mark new entries with no ids for later embeddings generation - new_entries = [ - (None, hash_to_current_entries[entry_hash]) - for entry_hash in new_entry_hashes - ] - # Set id of existing entries to their previous ids to reuse their existing encoded embeddings - existing_entries = [ - (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash]) - for entry_hash in existing_entry_hashes - ] - - existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) - entries_with_ids = existing_entries_sorted + new_entries - end = time.time() - logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds") - - return entries_with_ids \ No newline at end of file diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index 84aadc0a..165be0d1 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -1,4 +1,5 @@ # System Packages +import json from pathlib import Path from typing import List, Optional @@ -75,4 +76,28 @@ class FullConfig(ConfigBase): class SearchResponse(ConfigBase): entry: str score: str - additional: Optional[dict] \ No newline at end of file + additional: Optional[dict] + +class Entry(): + raw: str + compiled: str + file: Optional[str] + + def __init__(self, raw: str = None, compiled: str = None, file: Optional[str] = None): + self.raw = raw + self.compiled = compiled + self.file = file + + def to_json(self) -> str: + return json.dumps(self.__dict__, ensure_ascii=False) + + def __repr__(self) -> str: + return self.__dict__.__repr__() + + @classmethod + def from_dict(cls, dictionary: dict): + return 
cls( + raw=dictionary['raw'], + compiled=dictionary['compiled'], + file=dictionary.get('file', None) + ) \ No newline at end of file diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 345c5c4f..59ef697c 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -8,14 +8,15 @@ import torch # Application Packages from src.search_filter.date_filter import DateFilter +from src.utils.rawconfig import Entry def test_date_filter(): - embeddings = torch.randn(3, 10) entries = [ - {'compiled': '', 'raw': 'Entry with no date'}, - {'compiled': '', 'raw': 'April Fools entry: 1984-04-01'}, - {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] + Entry(compiled='', raw='Entry with no date'), + Entry(compiled='', raw='April Fools entry: 1984-04-01'), + Entry(compiled='', raw='Entry with date:1984-04-02') + ] q_with_no_date_filter = 'head tail' ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries) diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index 3f9c22b3..e6c17299 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -3,6 +3,7 @@ import torch # Application Packages from src.search_filter.file_filter import FileFilter +from src.utils.rawconfig import Entry def test_no_file_filter(): @@ -104,9 +105,10 @@ def test_multiple_file_filter(): def arrange_content(): embeddings = torch.randn(4, 10) entries = [ - {'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'}, - {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'}, - {'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'}, - {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}] + Entry(compiled='', raw='First Entry', file= 'file 1.org'), + Entry(compiled='', raw='Second Entry', file= 'file2.org'), + Entry(compiled='', raw='Third Entry', file= 'file 1.org'), + Entry(compiled='', raw='Fourth Entry', file= 'file2.org') + ] - return embeddings, entries + return entries diff --git 
a/tests/test_image_search.py b/tests/test_image_search.py index e1a56b44..97911164 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -70,7 +70,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig image_files_url='/static/images', count=1) - actual_image_path = output_directory.joinpath(Path(results[0]["entry"]).name) + actual_image_path = output_directory.joinpath(Path(results[0].entry).name) actual_image = Image.open(actual_image_path) expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name)) diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 584c07b9..e05831a1 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -76,7 +76,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC # Assert # Actual_data should contain "Khoj via Emacs" entry - search_result = results[0]["entry"] + search_result = results[0].entry assert "git clone" in search_result diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index db23c2c6..58069b24 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -1,6 +1,7 @@ # Application Packages from src.search_filter.word_filter import WordFilter from src.utils.config import SearchType +from src.utils.rawconfig import Entry def test_no_word_filter(): @@ -69,9 +70,10 @@ def test_word_include_and_exclude_filter(): def arrange_content(): entries = [ - {'compiled': '', 'raw': 'Minimal Entry'}, - {'compiled': '', 'raw': 'Entry with exclude_word'}, - {'compiled': '', 'raw': 'Entry with include_word'}, - {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}] + Entry(compiled='', raw='Minimal Entry'), + Entry(compiled='', raw='Entry with exclude_word'), + Entry(compiled='', raw='Entry with include_word'), + Entry(compiled='', raw='Entry with include_word and exclude_word') + ] return entries