Use new Text Entry class to track text entries in Intermediate Format

- Context
  - The app maintains all text content in a standard, intermediate format
  - The intermediate format was loaded, passed around as a dictionary
    for easier, faster updates to the intermediate format schema initially
  - The intermediate format is reasonably stable now, given its usage
    by all 3 text content types currently implemented

- Changes
  - Concretize text entries into `Entries' class instead of using dictionaries
    - Code is updated to load, pass around entries as `Entries' objects
      instead of as dictionaries
    - `text_search' and `text_to_jsonl' methods are annotated with
       type hints for the new `Entries' type
    - Code and Tests referencing entries are updated to use class style
      access patterns instead of the previous dictionary access patterns

  - Move `mark_entries_for_update' method into `TextToJsonl' base class
    - This is a more natural location for the method as it is only
      (to be) used by `text_to_jsonl' classes
    - Avoid circular reference issues on importing `Entries' class
This commit is contained in:
Debanjum Singh Solanky
2022-09-15 23:34:43 +03:00
parent 99754970ab
commit 7e9298f315
15 changed files with 161 additions and 131 deletions

View File

@@ -1,5 +1,4 @@
# Standard Packages # Standard Packages
import json
import glob import glob
import re import re
import logging import logging
@@ -7,9 +6,10 @@ import time
# Internal Packages # Internal Packages
from src.processor.text_to_jsonl import TextToJsonl from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.constants import empty_escape_sequences from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import Entry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ class BeancountToJsonl(TextToJsonl):
if not previous_entries: if not previous_entries:
entries_with_ids = list(enumerate(current_entries)) entries_with_ids = list(enumerate(current_entries))
else: else:
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time() end = time.time()
logger.debug(f"Identify new or updated transaction: {end - start} seconds") logger.debug(f"Identify new or updated transaction: {end - start} seconds")
@@ -111,17 +111,17 @@ class BeancountToJsonl(TextToJsonl):
return entries, dict(transaction_to_file_map) return entries, dict(transaction_to_file_map)
@staticmethod @staticmethod
def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]: def convert_transactions_to_maps(parsed_entries: list[str], transaction_to_file_map) -> list[Entry]:
"Convert each Beancount transaction into a dictionary" "Convert each parsed Beancount transaction into a Entry"
entry_maps = [] entries = []
for entry in entries: for parsed_entry in parsed_entries:
entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'}) entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}'))
logger.info(f"Converted {len(entries)} transactions to dictionaries") logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries")
return entry_maps return entries
@staticmethod @staticmethod
def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str: def convert_transaction_maps_to_jsonl(entries: list[Entry]) -> str:
"Convert each Beancount transaction dictionary to JSON and collate as JSONL" "Convert each Beancount transaction entry to JSON and collate as JSONL"
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) return ''.join([f'{entry.to_json()}\n' for entry in entries])

View File

@@ -1,5 +1,4 @@
# Standard Packages # Standard Packages
import json
import glob import glob
import re import re
import logging import logging
@@ -7,9 +6,10 @@ import time
# Internal Packages # Internal Packages
from src.processor.text_to_jsonl import TextToJsonl from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.constants import empty_escape_sequences from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import Entry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ class MarkdownToJsonl(TextToJsonl):
if not previous_entries: if not previous_entries:
entries_with_ids = list(enumerate(current_entries)) entries_with_ids = list(enumerate(current_entries))
else: else:
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time() end = time.time()
logger.debug(f"Identify new or updated entries: {end - start} seconds") logger.debug(f"Identify new or updated entries: {end - start} seconds")
@@ -110,17 +110,17 @@ class MarkdownToJsonl(TextToJsonl):
return entries, dict(entry_to_file_map) return entries, dict(entry_to_file_map)
@staticmethod @staticmethod
def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> list[dict]: def convert_markdown_entries_to_maps(parsed_entries: list[str], entry_to_file_map) -> list[Entry]:
"Convert each Markdown entries into a dictionary" "Convert each Markdown entries into a dictionary"
entry_maps = [] entries = []
for entry in entries: for parsed_entry in parsed_entries:
entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{entry_to_file_map[entry]}'}) entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}'))
logger.info(f"Converted {len(entries)} markdown entries to dictionaries") logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries")
return entry_maps return entries
@staticmethod @staticmethod
def convert_markdown_maps_to_jsonl(entries): def convert_markdown_maps_to_jsonl(entries: list[Entry]):
"Convert each Markdown entries to JSON and collate as JSONL" "Convert each Markdown entry to JSON and collate as JSONL"
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) return ''.join([f'{entry.to_json()}\n' for entry in entries])

View File

@@ -1,5 +1,4 @@
# Standard Packages # Standard Packages
import json
import glob import glob
import logging import logging
import time import time
@@ -8,8 +7,9 @@ from typing import Iterable
# Internal Packages # Internal Packages
from src.processor.org_mode import orgnode from src.processor.org_mode import orgnode
from src.processor.text_to_jsonl import TextToJsonl from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import Entry
from src.utils import state from src.utils import state
@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
class OrgToJsonl(TextToJsonl): class OrgToJsonl(TextToJsonl):
# Define Functions # Define Functions
def process(self, previous_entries=None): def process(self, previous_entries: list[Entry]=None):
# Extract required fields from config # Extract required fields from config
org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
index_heading_entries = self.config.index_heading_entries index_heading_entries = self.config.index_heading_entries
@@ -47,7 +47,7 @@ class OrgToJsonl(TextToJsonl):
if not previous_entries: if not previous_entries:
entries_with_ids = list(enumerate(current_entries)) entries_with_ids = list(enumerate(current_entries))
else: else:
entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
# Process Each Entry from All Notes Files # Process Each Entry from All Notes Files
start = time.time() start = time.time()
@@ -104,51 +104,48 @@ class OrgToJsonl(TextToJsonl):
return entries, dict(entry_to_file_map) return entries, dict(entry_to_file_map)
@staticmethod @staticmethod
def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]: def convert_org_nodes_to_entries(parsed_entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[Entry]:
"Convert Org-Mode entries into list of dictionary" "Convert Org-Mode nodes into list of Entry objects"
entry_maps = [] entries: list[Entry] = []
for entry in entries: for parsed_entry in parsed_entries:
entry_dict = dict() if not parsed_entry.hasBody and not index_heading_entries:
if not entry.hasBody and not index_heading_entries:
# Ignore title notes i.e notes with just headings and empty body # Ignore title notes i.e notes with just headings and empty body
continue continue
entry_dict["compiled"] = f'{entry.heading}.' compiled = f'{parsed_entry.heading}.'
if state.verbose > 2: if state.verbose > 2:
logger.debug(f"Title: {entry.heading}") logger.debug(f"Title: {parsed_entry.heading}")
if entry.tags: if parsed_entry.tags:
tags_str = " ".join(entry.tags) tags_str = " ".join(parsed_entry.tags)
entry_dict["compiled"] += f'\t {tags_str}.' compiled += f'\t {tags_str}.'
if state.verbose > 2: if state.verbose > 2:
logger.debug(f"Tags: {tags_str}") logger.debug(f"Tags: {tags_str}")
if entry.closed: if parsed_entry.closed:
entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.' compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
if state.verbose > 2: if state.verbose > 2:
logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}') logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}')
if entry.scheduled: if parsed_entry.scheduled:
entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.' compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
if state.verbose > 2: if state.verbose > 2:
logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}') logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')
if entry.hasBody: if parsed_entry.hasBody:
entry_dict["compiled"] += f'\n {entry.body}' compiled += f'\n {parsed_entry.body}'
if state.verbose > 2: if state.verbose > 2:
logger.debug(f"Body: {entry.body}") logger.debug(f"Body: {parsed_entry.body}")
if entry_dict: if compiled:
entry_dict["raw"] = f'{entry}' entries += [Entry(
entry_dict["file"] = f'{entry_to_file_map[entry]}' compiled=compiled,
raw=f'{parsed_entry}',
file=f'{entry_to_file_map[parsed_entry]}')]
# Convert Dictionary to JSON and Append to JSONL string return entries
entry_maps.append(entry_dict)
return entry_maps
@staticmethod @staticmethod
def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str: def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
"Convert each Org-Mode entry to JSON and collate as JSONL" "Convert each Org-Mode entry to JSON and collate as JSONL"
return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries]) return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries])

View File

@@ -1,9 +1,14 @@
# Standard Packages # Standard Packages
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Iterable import hashlib
import time
import logging
# Internal Packages # Internal Packages
from src.utils.rawconfig import TextContentConfig from src.utils.rawconfig import Entry, TextContentConfig
logger = logging.getLogger(__name__)
class TextToJsonl(ABC): class TextToJsonl(ABC):
@@ -11,4 +16,39 @@ class TextToJsonl(ABC):
self.config = config self.config = config
@abstractmethod @abstractmethod
def process(self, previous_entries: Iterable[tuple[int, dict]]=None) -> list[tuple[int, dict]]: ... def process(self, previous_entries: list[Entry]=None) -> list[tuple[int, Entry]]: ...
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
# Hash all current and previous entries to identify new entries
start = time.time()
current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), current_entries))
previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), previous_entries))
end = time.time()
logger.debug(f"Hash previous, current entries: {end - start} seconds")
start = time.time()
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
# All entries that did not exist in the previous set are to be added
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
# All entries that exist in both current and previous sets are kept
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# Mark new entries with -1 id to flag for later embeddings generation
new_entries = [
(-1, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
for entry_hash in existing_entry_hashes
]
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries
end = time.time()
logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
return entries_with_ids

View File

@@ -37,7 +37,7 @@ class DateFilter(BaseFilter):
start = time.time() start = time.time()
for id, entry in enumerate(entries): for id, entry in enumerate(entries):
# Extract dates from entry # Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
# Convert date string in entry to unix timestamp # Convert date string in entry to unix timestamp
try: try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()

View File

@@ -24,7 +24,7 @@ class FileFilter(BaseFilter):
def load(self, entries, *args, **kwargs): def load(self, entries, *args, **kwargs):
start = time.time() start = time.time()
for id, entry in enumerate(entries): for id, entry in enumerate(entries):
self.file_to_entry_map[entry[self.entry_key]].add(id) self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
end = time.time() end = time.time()
logger.debug(f"Created file filter index: {end - start} seconds") logger.debug(f"Created file filter index: {end - start} seconds")

View File

@@ -29,7 +29,7 @@ class WordFilter(BaseFilter):
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
# Create map of words to entries they exist in # Create map of words to entries they exist in
for entry_index, entry in enumerate(entries): for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, entry[self.entry_key].lower()): for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
if word == '': if word == '':
continue continue
self.word_to_entry_index[word].add(entry_index) self.word_to_entry_index[word].add(entry_index)

View File

@@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter
from src.utils import state from src.utils import state
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
from src.utils.config import TextSearchModel from src.utils.config import TextSearchModel
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
from src.utils.jsonl import load_jsonl from src.utils.jsonl import load_jsonl
@@ -50,12 +50,12 @@ def initialize_model(search_config: TextSearchConfig):
return bi_encoder, cross_encoder, top_k return bi_encoder, cross_encoder, top_k
def extract_entries(jsonl_file): def extract_entries(jsonl_file) -> list[Entry]:
"Load entries from compressed jsonl" "Load entries from compressed jsonl"
return load_jsonl(jsonl_file) return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False): def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings" "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
new_entries = [] new_entries = []
# Load pre-computed embeddings from file if exists and update them if required # Load pre-computed embeddings from file if exists and update them if required
@@ -64,15 +64,15 @@ def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate
logger.info(f"Loaded embeddings from {embeddings_file}") logger.info(f"Loaded embeddings from {embeddings_file}")
# Encode any new entries in the corpus and update corpus embeddings # Encode any new entries in the corpus and update corpus embeddings
new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None] new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
if new_entries: if new_entries:
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
existing_entry_ids = [id for id, _ in entries_with_ids if id is not None] existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor() existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
# Else compute the corpus embeddings from scratch # Else compute the corpus embeddings from scratch
else: else:
new_entries = [entry['compiled'] for _, entry in entries_with_ids] new_entries = [entry.compiled for _, entry in entries_with_ids]
corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
# Save regenerated or updated embeddings to file # Save regenerated or updated embeddings to file
@@ -133,7 +133,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
if rank_results: if rank_results:
start = time.time() start = time.time()
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits] cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp) cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time() end = time.time()
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}") logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
@@ -153,7 +153,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
return hits, entries return hits, entries
def render_results(hits, entries, count=5, display_biencoder_results=False): def render_results(hits, entries: list[Entry], count=5, display_biencoder_results=False):
"Render the Results returned by Search for the Query" "Render the Results returned by Search for the Query"
if display_biencoder_results: if display_biencoder_results:
# Output of top hits from bi-encoder # Output of top hits from bi-encoder
@@ -161,20 +161,20 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
print(f"Top-{count} Bi-Encoder Retrieval hits") print(f"Top-{count} Bi-Encoder Retrieval hits")
hits = sorted(hits, key=lambda x: x['score'], reverse=True) hits = sorted(hits, key=lambda x: x['score'], reverse=True)
for hit in hits[0:count]: for hit in hits[0:count]:
print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}") print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']].compiled}")
# Output of top hits from re-ranker # Output of top hits from re-ranker
print("\n-------------------------\n") print("\n-------------------------\n")
print(f"Top-{count} Cross-Encoder Re-ranker hits") print(f"Top-{count} Cross-Encoder Re-ranker hits")
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
for hit in hits[0:count]: for hit in hits[0:count]:
print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}") print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']].compiled}")
def collate_results(hits, entries, count=5) -> list[SearchResponse]: def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]:
return [SearchResponse.parse_obj( return [SearchResponse.parse_obj(
{ {
"entry": entries[hit['corpus_id']]['raw'], "entry": entries[hit['corpus_id']].raw,
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}" "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
}) })
for hit for hit

View File

@@ -1,8 +1,6 @@
# Standard Packages # Standard Packages
from pathlib import Path from pathlib import Path
import sys import sys
import time
import hashlib
from os.path import join from os.path import join
from collections import OrderedDict from collections import OrderedDict
from typing import Optional, Union from typing import Optional, Union
@@ -83,38 +81,3 @@ class LRU(OrderedDict):
oldest = next(iter(self)) oldest = next(iter(self))
del self[oldest] del self[oldest]
def mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=None):
# Hash all current and previous entries to identify new entries
start = time.time()
current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), current_entries))
previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), previous_entries))
end = time.time()
logger.debug(f"Hash previous, current entries: {end - start} seconds")
start = time.time()
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
# All entries that did not exist in the previous set are to be added
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
# All entries that exist in both current and previous sets are kept
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# Mark new entries with no ids for later embeddings generation
new_entries = [
(None, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
for entry_hash in existing_entry_hashes
]
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries
end = time.time()
logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
return entries_with_ids

View File

@@ -1,4 +1,5 @@
# System Packages # System Packages
import json
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
@@ -75,4 +76,28 @@ class FullConfig(ConfigBase):
class SearchResponse(ConfigBase): class SearchResponse(ConfigBase):
entry: str entry: str
score: str score: str
additional: Optional[dict] additional: Optional[dict]
class Entry():
raw: str
compiled: str
file: Optional[str]
def __init__(self, raw: str = None, compiled: str = None, file: Optional[str] = None):
self.raw = raw
self.compiled = compiled
self.file = file
def to_json(self) -> str:
return json.dumps(self.__dict__, ensure_ascii=False)
def __repr__(self) -> str:
return self.__dict__.__repr__()
@classmethod
def from_dict(cls, dictionary: dict):
return cls(
raw=dictionary['raw'],
compiled=dictionary['compiled'],
file=dictionary.get('file', None)
)

View File

@@ -8,14 +8,15 @@ import torch
# Application Packages # Application Packages
from src.search_filter.date_filter import DateFilter from src.search_filter.date_filter import DateFilter
from src.utils.rawconfig import Entry
def test_date_filter(): def test_date_filter():
embeddings = torch.randn(3, 10)
entries = [ entries = [
{'compiled': '', 'raw': 'Entry with no date'}, Entry(compiled='', raw='Entry with no date'),
{'compiled': '', 'raw': 'April Fools entry: 1984-04-01'}, Entry(compiled='', raw='April Fools entry: 1984-04-01'),
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}] Entry(compiled='', raw='Entry with date:1984-04-02')
]
q_with_no_date_filter = 'head tail' q_with_no_date_filter = 'head tail'
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries) ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)

View File

@@ -3,6 +3,7 @@ import torch
# Application Packages # Application Packages
from src.search_filter.file_filter import FileFilter from src.search_filter.file_filter import FileFilter
from src.utils.rawconfig import Entry
def test_no_file_filter(): def test_no_file_filter():
@@ -104,9 +105,10 @@ def test_multiple_file_filter():
def arrange_content(): def arrange_content():
embeddings = torch.randn(4, 10) embeddings = torch.randn(4, 10)
entries = [ entries = [
{'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'}, Entry(compiled='', raw='First Entry', file= 'file 1.org'),
{'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'}, Entry(compiled='', raw='Second Entry', file= 'file2.org'),
{'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'}, Entry(compiled='', raw='Third Entry', file= 'file 1.org'),
{'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}] Entry(compiled='', raw='Fourth Entry', file= 'file2.org')
]
return embeddings, entries return entries

View File

@@ -70,7 +70,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
image_files_url='/static/images', image_files_url='/static/images',
count=1) count=1)
actual_image_path = output_directory.joinpath(Path(results[0]["entry"]).name) actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
actual_image = Image.open(actual_image_path) actual_image = Image.open(actual_image_path)
expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name)) expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name))

View File

@@ -76,7 +76,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
# Assert # Assert
# Actual_data should contain "Khoj via Emacs" entry # Actual_data should contain "Khoj via Emacs" entry
search_result = results[0]["entry"] search_result = results[0].entry
assert "git clone" in search_result assert "git clone" in search_result

View File

@@ -1,6 +1,7 @@
# Application Packages # Application Packages
from src.search_filter.word_filter import WordFilter from src.search_filter.word_filter import WordFilter
from src.utils.config import SearchType from src.utils.config import SearchType
from src.utils.rawconfig import Entry
def test_no_word_filter(): def test_no_word_filter():
@@ -69,9 +70,10 @@ def test_word_include_and_exclude_filter():
def arrange_content(): def arrange_content():
entries = [ entries = [
{'compiled': '', 'raw': 'Minimal Entry'}, Entry(compiled='', raw='Minimal Entry'),
{'compiled': '', 'raw': 'Entry with exclude_word'}, Entry(compiled='', raw='Entry with exclude_word'),
{'compiled': '', 'raw': 'Entry with include_word'}, Entry(compiled='', raw='Entry with include_word'),
{'compiled': '', 'raw': 'Entry with include_word and exclude_word'}] Entry(compiled='', raw='Entry with include_word and exclude_word')
]
return entries return entries