Mirror of https://github.com/khoaliber/khoj.git (synced 2026-03-03 05:29:12 +00:00)
Use new `Entry' class to track text entries in Intermediate Format

- Context
  - The app maintains all text content in a standard, intermediate format
  - The intermediate format was initially loaded and passed around as a
    dictionary, to allow easier, faster updates to its schema
  - The intermediate format is reasonably stable now, given its usage
    by all 3 text content types currently implemented
- Changes
  - Concretize text entries into an `Entry' class instead of using dictionaries
  - Code is updated to load and pass around entries as `Entry' objects
    instead of as dictionaries
  - `text_search' and `text_to_jsonl' methods are annotated with
    type hints for the new `Entry' type
  - Code and tests referencing entries are updated to use class-style
    access patterns instead of the previous dictionary access patterns
    (sketched below)
  - Move `mark_entries_for_update' method into the `TextToJsonl' base class
    - This is a more natural location for the method, as it is only
      (to be) used by `text_to_jsonl' classes
    - It also avoids circular reference issues when importing the `Entry' class
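For illustration, a minimal sketch of the access-pattern change described above, using the `Entry' class this commit adds to `src.utils.rawconfig'. The sample values are invented, and it assumes the repository root is on the Python path:

```python
# Previous intermediate format: one plain dictionary per text entry
entry_dict = {'compiled': 'Khoj via Emacs.', 'raw': '* Khoj via Emacs', 'file': 'readme.org'}
compiled = entry_dict['compiled']        # dictionary-style access

# New intermediate format: an Entry object with the same fields
from src.utils.rawconfig import Entry

entry = Entry(compiled='Khoj via Emacs.', raw='* Khoj via Emacs', file='readme.org')
compiled = entry.compiled                # class-style access
```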
@@ -1,5 +1,4 @@
 # Standard Packages
-import json
 import glob
 import re
 import logging
@@ -7,9 +6,10 @@ import time

 # Internal Packages
 from src.processor.text_to_jsonl import TextToJsonl
-from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
+from src.utils.helpers import get_absolute_path, is_none_or_empty
 from src.utils.constants import empty_escape_sequences
 from src.utils.jsonl import dump_jsonl, compress_jsonl_data
+from src.utils.rawconfig import Entry


 logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ class BeancountToJsonl(TextToJsonl):
         if not previous_entries:
             entries_with_ids = list(enumerate(current_entries))
         else:
-            entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
+            entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
         end = time.time()
         logger.debug(f"Identify new or updated transaction: {end - start} seconds")
@@ -111,17 +111,17 @@ class BeancountToJsonl(TextToJsonl):
         return entries, dict(transaction_to_file_map)

     @staticmethod
-    def convert_transactions_to_maps(entries: list[str], transaction_to_file_map) -> list[dict]:
-        "Convert each Beancount transaction into a dictionary"
-        entry_maps = []
-        for entry in entries:
-            entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{transaction_to_file_map[entry]}'})
+    def convert_transactions_to_maps(parsed_entries: list[str], transaction_to_file_map) -> list[Entry]:
+        "Convert each parsed Beancount transaction into a Entry"
+        entries = []
+        for parsed_entry in parsed_entries:
+            entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{transaction_to_file_map[parsed_entry]}'))

-        logger.info(f"Converted {len(entries)} transactions to dictionaries")
+        logger.info(f"Converted {len(parsed_entries)} transactions to dictionaries")

-        return entry_maps
+        return entries

     @staticmethod
-    def convert_transaction_maps_to_jsonl(entries: list[dict]) -> str:
-        "Convert each Beancount transaction dictionary to JSON and collate as JSONL"
-        return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
+    def convert_transaction_maps_to_jsonl(entries: list[Entry]) -> str:
+        "Convert each Beancount transaction entry to JSON and collate as JSONL"
+        return ''.join([f'{entry.to_json()}\n' for entry in entries])
@@ -1,5 +1,4 @@
 # Standard Packages
-import json
 import glob
 import re
 import logging
@@ -7,9 +6,10 @@ import time

 # Internal Packages
 from src.processor.text_to_jsonl import TextToJsonl
-from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
+from src.utils.helpers import get_absolute_path, is_none_or_empty
 from src.utils.constants import empty_escape_sequences
 from src.utils.jsonl import dump_jsonl, compress_jsonl_data
+from src.utils.rawconfig import Entry


 logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ class MarkdownToJsonl(TextToJsonl):
         if not previous_entries:
             entries_with_ids = list(enumerate(current_entries))
         else:
-            entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
+            entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
         end = time.time()
         logger.debug(f"Identify new or updated entries: {end - start} seconds")
@@ -110,17 +110,17 @@ class MarkdownToJsonl(TextToJsonl):
         return entries, dict(entry_to_file_map)

     @staticmethod
-    def convert_markdown_entries_to_maps(entries: list[str], entry_to_file_map) -> list[dict]:
+    def convert_markdown_entries_to_maps(parsed_entries: list[str], entry_to_file_map) -> list[Entry]:
         "Convert each Markdown entries into a dictionary"
-        entry_maps = []
-        for entry in entries:
-            entry_maps.append({'compiled': entry, 'raw': entry, 'file': f'{entry_to_file_map[entry]}'})
+        entries = []
+        for parsed_entry in parsed_entries:
+            entries.append(Entry(compiled=parsed_entry, raw=parsed_entry, file=f'{entry_to_file_map[parsed_entry]}'))

-        logger.info(f"Converted {len(entries)} markdown entries to dictionaries")
+        logger.info(f"Converted {len(parsed_entries)} markdown entries to dictionaries")

-        return entry_maps
+        return entries

     @staticmethod
-    def convert_markdown_maps_to_jsonl(entries):
-        "Convert each Markdown entries to JSON and collate as JSONL"
-        return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
+    def convert_markdown_maps_to_jsonl(entries: list[Entry]):
+        "Convert each Markdown entry to JSON and collate as JSONL"
+        return ''.join([f'{entry.to_json()}\n' for entry in entries])
@@ -1,5 +1,4 @@
 # Standard Packages
-import json
 import glob
 import logging
 import time
@@ -8,8 +7,9 @@ from typing import Iterable
 # Internal Packages
 from src.processor.org_mode import orgnode
 from src.processor.text_to_jsonl import TextToJsonl
-from src.utils.helpers import get_absolute_path, is_none_or_empty, mark_entries_for_update
+from src.utils.helpers import get_absolute_path, is_none_or_empty
 from src.utils.jsonl import dump_jsonl, compress_jsonl_data
+from src.utils.rawconfig import Entry
 from src.utils import state


@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)

 class OrgToJsonl(TextToJsonl):
     # Define Functions
-    def process(self, previous_entries=None):
+    def process(self, previous_entries: list[Entry]=None):
         # Extract required fields from config
         org_files, org_file_filter, output_file = self.config.input_files, self.config.input_filter, self.config.compressed_jsonl
         index_heading_entries = self.config.index_heading_entries
@@ -47,7 +47,7 @@ class OrgToJsonl(TextToJsonl):
         if not previous_entries:
             entries_with_ids = list(enumerate(current_entries))
         else:
-            entries_with_ids = mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
+            entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)

         # Process Each Entry from All Notes Files
         start = time.time()
@@ -104,51 +104,48 @@ class OrgToJsonl(TextToJsonl):
         return entries, dict(entry_to_file_map)

     @staticmethod
-    def convert_org_nodes_to_entries(entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[dict]:
-        "Convert Org-Mode entries into list of dictionary"
-        entry_maps = []
-        for entry in entries:
-            entry_dict = dict()
-
-            if not entry.hasBody and not index_heading_entries:
+    def convert_org_nodes_to_entries(parsed_entries: list[orgnode.Orgnode], entry_to_file_map, index_heading_entries=False) -> list[Entry]:
+        "Convert Org-Mode nodes into list of Entry objects"
+        entries: list[Entry] = []
+        for parsed_entry in parsed_entries:
+            if not parsed_entry.hasBody and not index_heading_entries:
                 # Ignore title notes i.e notes with just headings and empty body
                 continue

-            entry_dict["compiled"] = f'{entry.heading}.'
+            compiled = f'{parsed_entry.heading}.'
             if state.verbose > 2:
-                logger.debug(f"Title: {entry.heading}")
+                logger.debug(f"Title: {parsed_entry.heading}")

-            if entry.tags:
-                tags_str = " ".join(entry.tags)
-                entry_dict["compiled"] += f'\t {tags_str}.'
+            if parsed_entry.tags:
+                tags_str = " ".join(parsed_entry.tags)
+                compiled += f'\t {tags_str}.'
                 if state.verbose > 2:
                     logger.debug(f"Tags: {tags_str}")

-            if entry.closed:
-                entry_dict["compiled"] += f'\n Closed on {entry.closed.strftime("%Y-%m-%d")}.'
+            if parsed_entry.closed:
+                compiled += f'\n Closed on {parsed_entry.closed.strftime("%Y-%m-%d")}.'
                 if state.verbose > 2:
-                    logger.debug(f'Closed: {entry.closed.strftime("%Y-%m-%d")}')
+                    logger.debug(f'Closed: {parsed_entry.closed.strftime("%Y-%m-%d")}')

-            if entry.scheduled:
-                entry_dict["compiled"] += f'\n Scheduled for {entry.scheduled.strftime("%Y-%m-%d")}.'
+            if parsed_entry.scheduled:
+                compiled += f'\n Scheduled for {parsed_entry.scheduled.strftime("%Y-%m-%d")}.'
                 if state.verbose > 2:
-                    logger.debug(f'Scheduled: {entry.scheduled.strftime("%Y-%m-%d")}')
+                    logger.debug(f'Scheduled: {parsed_entry.scheduled.strftime("%Y-%m-%d")}')

-            if entry.hasBody:
-                entry_dict["compiled"] += f'\n {entry.body}'
+            if parsed_entry.hasBody:
+                compiled += f'\n {parsed_entry.body}'
                 if state.verbose > 2:
-                    logger.debug(f"Body: {entry.body}")
+                    logger.debug(f"Body: {parsed_entry.body}")

-            if entry_dict:
-                entry_dict["raw"] = f'{entry}'
-                entry_dict["file"] = f'{entry_to_file_map[entry]}'
-
-                # Convert Dictionary to JSON and Append to JSONL string
-                entry_maps.append(entry_dict)
-
-        return entry_maps
+            if compiled:
+                entries += [Entry(
+                    compiled=compiled,
+                    raw=f'{parsed_entry}',
+                    file=f'{entry_to_file_map[parsed_entry]}')]
+
+        return entries

     @staticmethod
-    def convert_org_entries_to_jsonl(entries: Iterable[dict]) -> str:
+    def convert_org_entries_to_jsonl(entries: Iterable[Entry]) -> str:
         "Convert each Org-Mode entry to JSON and collate as JSONL"
-        return ''.join([f'{json.dumps(entry_dict, ensure_ascii=False)}\n' for entry_dict in entries])
+        return ''.join([f'{entry_dict.to_json()}\n' for entry_dict in entries])
@@ -1,9 +1,14 @@
 # Standard Packages
 from abc import ABC, abstractmethod
 from typing import Iterable
+import hashlib
+import time
 import logging

 # Internal Packages
-from src.utils.rawconfig import TextContentConfig
+from src.utils.rawconfig import Entry, TextContentConfig


+logger = logging.getLogger(__name__)
+
+
 class TextToJsonl(ABC):
@@ -11,4 +16,39 @@ class TextToJsonl(ABC):
         self.config = config

     @abstractmethod
-    def process(self, previous_entries: Iterable[tuple[int, dict]]=None) -> list[tuple[int, dict]]: ...
+    def process(self, previous_entries: list[Entry]=None) -> list[tuple[int, Entry]]: ...
+
+    def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
+        # Hash all current and previous entries to identify new entries
+        start = time.time()
+        current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), current_entries))
+        previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(getattr(e, key), encoding='utf-8')).hexdigest(), previous_entries))
+        end = time.time()
+        logger.debug(f"Hash previous, current entries: {end - start} seconds")
+
+        start = time.time()
+        hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
+        hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
+
+        # All entries that did not exist in the previous set are to be added
+        new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
+        # All entries that exist in both current and previous sets are kept
+        existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
+
+        # Mark new entries with -1 id to flag for later embeddings generation
+        new_entries = [
+            (-1, hash_to_current_entries[entry_hash])
+            for entry_hash in new_entry_hashes
+        ]
+        # Set id of existing entries to their previous ids to reuse their existing encoded embeddings
+        existing_entries = [
+            (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
+            for entry_hash in existing_entry_hashes
+        ]
+
+        existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
+        entries_with_ids = existing_entries_sorted + new_entries
+        end = time.time()
+        logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
+
+        return entries_with_ids
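For illustration, a minimal sketch of what the relocated `mark_entries_for_update' method returns. `DummyToJsonl' is a throwaway subclass made up for this example, and it assumes the repository root is on the Python path:

```python
import logging

from src.processor.text_to_jsonl import TextToJsonl
from src.utils.rawconfig import Entry

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class DummyToJsonl(TextToJsonl):
    "Minimal concrete subclass, only needed to exercise the base class method"
    def process(self, previous_entries: list[Entry] = None) -> list[tuple[int, Entry]]:
        return []


processor = DummyToJsonl(None)

previous_entries = [Entry(compiled='Old note.', raw='Old note')]
current_entries = [Entry(compiled='Old note.', raw='Old note'),
                   Entry(compiled='New note.', raw='New note')]

entries_with_ids = processor.mark_entries_for_update(
    current_entries, previous_entries, key='compiled', logger=logger)

# Unchanged entries keep their previous id so their embeddings can be reused;
# new entries are marked with id -1 to flag them for embedding generation.
# Expected shape: [(0, <Old note entry>), (-1, <New note entry>)]
```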
@@ -37,7 +37,7 @@ class DateFilter(BaseFilter):
         start = time.time()
         for id, entry in enumerate(entries):
             # Extract dates from entry
-            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
+            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
                 # Convert date string in entry to unix timestamp
                 try:
                     date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
@@ -24,7 +24,7 @@ class FileFilter(BaseFilter):
     def load(self, entries, *args, **kwargs):
         start = time.time()
         for id, entry in enumerate(entries):
-            self.file_to_entry_map[entry[self.entry_key]].add(id)
+            self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
         end = time.time()
         logger.debug(f"Created file filter index: {end - start} seconds")
@@ -29,7 +29,7 @@ class WordFilter(BaseFilter):
         entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
         # Create map of words to entries they exist in
         for entry_index, entry in enumerate(entries):
-            for word in re.split(entry_splitter, entry[self.entry_key].lower()):
+            for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
                 if word == '':
                     continue
                 self.word_to_entry_index[word].add(entry_index)
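The three filter changes above all swap dictionary indexing for attribute access, since entries are now `Entry' objects rather than dictionaries. A minimal sketch of the equivalence, with invented sample values:

```python
from src.utils.rawconfig import Entry

entry_key = 'compiled'

# Before: entries were dictionaries, so filters indexed by key
entry_as_dict = {'compiled': 'Entry with date 1984-04-01', 'raw': 'Entry with date 1984-04-01'}
value_from_dict = entry_as_dict[entry_key]

# After: entries are Entry objects, so filters look up the same key with getattr
entry_as_object = Entry(compiled='Entry with date 1984-04-01', raw='Entry with date 1984-04-01')
value_from_object = getattr(entry_as_object, entry_key)

assert value_from_dict == value_from_object
```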
@@ -13,7 +13,7 @@ from src.search_filter.base_filter import BaseFilter
 from src.utils import state
 from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
 from src.utils.config import TextSearchModel
-from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig
+from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
 from src.utils.jsonl import load_jsonl


@@ -50,12 +50,12 @@ def initialize_model(search_config: TextSearchConfig):
     return bi_encoder, cross_encoder, top_k


-def extract_entries(jsonl_file):
+def extract_entries(jsonl_file) -> list[Entry]:
     "Load entries from compressed jsonl"
-    return load_jsonl(jsonl_file)
+    return list(map(Entry.from_dict, load_jsonl(jsonl_file)))


-def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
+def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, embeddings_file, regenerate=False):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
     new_entries = []
     # Load pre-computed embeddings from file if exists and update them if required
@@ -64,15 +64,15 @@ def compute_embeddings(entries_with_ids, bi_encoder, embeddings_file, regenerate=False):
         logger.info(f"Loaded embeddings from {embeddings_file}")

         # Encode any new entries in the corpus and update corpus embeddings
-        new_entries = [entry['compiled'] for id, entry in entries_with_ids if id is None]
+        new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
         if new_entries:
             new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
-            existing_entry_ids = [id for id, _ in entries_with_ids if id is not None]
+            existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
             existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
             corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
     # Else compute the corpus embeddings from scratch
     else:
-        new_entries = [entry['compiled'] for _, entry in entries_with_ids]
+        new_entries = [entry.compiled for _, entry in entries_with_ids]
         corpus_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)

     # Save regenerated or updated embeddings to file
@@ -133,7 +133,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
     # Score all retrieved entries using the cross-encoder
     if rank_results:
         start = time.time()
-        cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
+        cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
         cross_scores = model.cross_encoder.predict(cross_inp)
         end = time.time()
         logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
@@ -153,7 +153,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
     return hits, entries


-def render_results(hits, entries, count=5, display_biencoder_results=False):
+def render_results(hits, entries: list[Entry], count=5, display_biencoder_results=False):
     "Render the Results returned by Search for the Query"
     if display_biencoder_results:
         # Output of top hits from bi-encoder
@@ -161,20 +161,20 @@ def render_results(hits, entries, count=5, display_biencoder_results=False):
         print(f"Top-{count} Bi-Encoder Retrieval hits")
         hits = sorted(hits, key=lambda x: x['score'], reverse=True)
         for hit in hits[0:count]:
-            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']]['compiled']}")
+            print(f"Score: {hit['score']:.3f}\n------------\n{entries[hit['corpus_id']].compiled}")

     # Output of top hits from re-ranker
     print("\n-------------------------\n")
     print(f"Top-{count} Cross-Encoder Re-ranker hits")
     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
     for hit in hits[0:count]:
-        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']]['compiled']}")
+        print(f"CrossScore: {hit['cross-score']:.3f}\n-----------------\n{entries[hit['corpus_id']].compiled}")


-def collate_results(hits, entries, count=5) -> list[SearchResponse]:
+def collate_results(hits, entries: list[Entry], count=5) -> list[SearchResponse]:
     return [SearchResponse.parse_obj(
         {
-            "entry": entries[hit['corpus_id']]['raw'],
+            "entry": entries[hit['corpus_id']].raw,
             "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
         })
         for hit
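A small sketch of how the new `-1' sentinel (replacing `None') drives `compute_embeddings' above: entries carrying an existing id reuse their stored embedding row, while entries marked `-1' are re-encoded. The entry values are invented:

```python
from src.utils.rawconfig import Entry

# Shape returned by TextToJsonl.mark_entries_for_update and consumed by compute_embeddings
entries_with_ids = [
    (0, Entry(compiled='Existing note.', raw='Existing note')),      # reuse embedding at row 0
    (1, Entry(compiled='Another old note.', raw='Another old note')),
    (-1, Entry(compiled='Brand new note.', raw='Brand new note')),   # -1 flags it for encoding
]

# The same two comprehensions used in compute_embeddings above
new_entries = [entry.compiled for id, entry in entries_with_ids if id == -1]
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]

print(new_entries)         # ['Brand new note.']
print(existing_entry_ids)  # [0, 1]
```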
@@ -1,8 +1,6 @@
 # Standard Packages
 from pathlib import Path
 import sys
-import time
-import hashlib
 from os.path import join
 from collections import OrderedDict
 from typing import Optional, Union
@@ -83,38 +81,3 @@ class LRU(OrderedDict):
             oldest = next(iter(self))
             del self[oldest]
-
-
-def mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=None):
-    # Hash all current and previous entries to identify new entries
-    start = time.time()
-    current_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), current_entries))
-    previous_entry_hashes = list(map(lambda e: hashlib.md5(bytes(e[key], encoding='utf-8')).hexdigest(), previous_entries))
-    end = time.time()
-    logger.debug(f"Hash previous, current entries: {end - start} seconds")
-
-    start = time.time()
-    hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
-    hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
-
-    # All entries that did not exist in the previous set are to be added
-    new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
-    # All entries that exist in both current and previous sets are kept
-    existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
-
-    # Mark new entries with no ids for later embeddings generation
-    new_entries = [
-        (None, hash_to_current_entries[entry_hash])
-        for entry_hash in new_entry_hashes
-    ]
-    # Set id of existing entries to their previous ids to reuse their existing encoded embeddings
-    existing_entries = [
-        (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
-        for entry_hash in existing_entry_hashes
-    ]
-
-    existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
-    entries_with_ids = existing_entries_sorted + new_entries
-    end = time.time()
-    logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
-
-    return entries_with_ids
@@ -1,4 +1,5 @@
 # System Packages
+import json
 from pathlib import Path
 from typing import List, Optional

@@ -75,4 +76,28 @@ class FullConfig(ConfigBase):
 class SearchResponse(ConfigBase):
     entry: str
     score: str
-    additional: Optional[dict]
+    additional: Optional[dict]
+
+
+class Entry():
+    raw: str
+    compiled: str
+    file: Optional[str]
+
+    def __init__(self, raw: str = None, compiled: str = None, file: Optional[str] = None):
+        self.raw = raw
+        self.compiled = compiled
+        self.file = file
+
+    def to_json(self) -> str:
+        return json.dumps(self.__dict__, ensure_ascii=False)
+
+    def __repr__(self) -> str:
+        return self.__dict__.__repr__()
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        return cls(
+            raw=dictionary['raw'],
+            compiled=dictionary['compiled'],
+            file=dictionary.get('file', None)
+        )
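The new `Entry' class round-trips through the compressed JSONL cache via `to_json' and `from_dict', which is what lets `extract_entries' in `text_search' rebuild objects on load. A quick sketch with invented values, assuming the repository root is on the Python path:

```python
import json

from src.utils.rawconfig import Entry

entry = Entry(compiled='Khoj via Emacs.', raw='* Khoj via Emacs', file='readme.org')

# Serialize, as the *_to_jsonl converters now do for each entry
line = entry.to_json()
print(line)  # {"raw": "* Khoj via Emacs", "compiled": "Khoj via Emacs.", "file": "readme.org"}

# Deserialize on load, as text_search.extract_entries now does via Entry.from_dict
restored = Entry.from_dict(json.loads(line))
assert restored.raw == entry.raw and restored.compiled == entry.compiled and restored.file == entry.file
```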
@@ -8,14 +8,15 @@ import torch

 # Application Packages
 from src.search_filter.date_filter import DateFilter
+from src.utils.rawconfig import Entry


 def test_date_filter():
     embeddings = torch.randn(3, 10)
     entries = [
-        {'compiled': '', 'raw': 'Entry with no date'},
-        {'compiled': '', 'raw': 'April Fools entry: 1984-04-01'},
-        {'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
+        Entry(compiled='', raw='Entry with no date'),
+        Entry(compiled='', raw='April Fools entry: 1984-04-01'),
+        Entry(compiled='', raw='Entry with date:1984-04-02')
+    ]

     q_with_no_date_filter = 'head tail'
     ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
@@ -3,6 +3,7 @@ import torch

 # Application Packages
 from src.search_filter.file_filter import FileFilter
+from src.utils.rawconfig import Entry


 def test_no_file_filter():
@@ -104,9 +105,10 @@ def test_multiple_file_filter():
 def arrange_content():
     embeddings = torch.randn(4, 10)
     entries = [
-        {'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'},
-        {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'},
-        {'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'},
-        {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}]
+        Entry(compiled='', raw='First Entry', file= 'file 1.org'),
+        Entry(compiled='', raw='Second Entry', file= 'file2.org'),
+        Entry(compiled='', raw='Third Entry', file= 'file 1.org'),
+        Entry(compiled='', raw='Fourth Entry', file= 'file2.org')
+    ]

-    return embeddings, entries
+    return entries
@@ -70,7 +70,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
         image_files_url='/static/images',
         count=1)

-    actual_image_path = output_directory.joinpath(Path(results[0]["entry"]).name)
+    actual_image_path = output_directory.joinpath(Path(results[0].entry).name)
     actual_image = Image.open(actual_image_path)
     expected_image = Image.open(content_config.image.input_directories[0].joinpath(expected_image_name))
@@ -76,7 +76,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):

     # Assert
     # Actual_data should contain "Khoj via Emacs" entry
-    search_result = results[0]["entry"]
+    search_result = results[0].entry
     assert "git clone" in search_result
@@ -1,6 +1,7 @@
 # Application Packages
 from src.search_filter.word_filter import WordFilter
 from src.utils.config import SearchType
+from src.utils.rawconfig import Entry


 def test_no_word_filter():
@@ -69,9 +70,10 @@ def test_word_include_and_exclude_filter():

 def arrange_content():
     entries = [
-        {'compiled': '', 'raw': 'Minimal Entry'},
-        {'compiled': '', 'raw': 'Entry with exclude_word'},
-        {'compiled': '', 'raw': 'Entry with include_word'},
-        {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
+        Entry(compiled='', raw='Minimal Entry'),
+        Entry(compiled='', raw='Entry with exclude_word'),
+        Entry(compiled='', raw='Entry with include_word'),
+        Entry(compiled='', raw='Entry with include_word and exclude_word')
+    ]

     return entries