diff --git a/src/configure.py b/src/configure.py index 0e1e3333..938062eb 100644 --- a/src/configure.py +++ b/src/configure.py @@ -40,22 +40,22 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, # Initialize Org Notes Search if (t == SearchType.Org or t == None) and config.content_type.org: # Extract Entries, Generate Notes Embeddings - model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate) + model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate) # Initialize Org Music Search if (t == SearchType.Music or t == None) and config.content_type.music: # Extract Entries, Generate Music Embeddings - model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate) + model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate) # Initialize Markdown Search if (t == SearchType.Markdown or t == None) and config.content_type.markdown: # Extract Entries, Generate Markdown Embeddings - model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate) + model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate) # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: # Extract Entries, Generate Ledger Embeddings - model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate) + 
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate) # Initialize Image Search if (t == SearchType.Image or t == None) and config.content_type.image: diff --git a/src/router.py b/src/router.py index 412692ff..127623c6 100644 --- a/src/router.py +++ b/src/router.py @@ -65,7 +65,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Org or t == None) and state.model.orgmode_search: # query org-mode notes query_start = time.time() - hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r) query_end = time.time() # collate and return results @@ -76,7 +76,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Music or t == None) and state.model.music_search: # query music library query_start = time.time() - hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r) query_end = time.time() # collate and return results @@ -87,7 +87,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Markdown or t == None) and state.model.markdown_search: # query markdown files query_start = time.time() - hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r) query_end = time.time() # collate and return results @@ -98,7 +98,7 @@ def search(q: str, 
n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Ledger or t == None) and state.model.ledger_search: # query transactions query_start = time.time() - hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r) query_end = time.time() # collate and return results diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index dc70ca29..cab47cbb 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -3,6 +3,7 @@ import re from datetime import timedelta, datetime from dateutil.relativedelta import relativedelta, MO from math import inf +from copy import deepcopy # External Packages import torch @@ -17,29 +18,42 @@ class DateFilter: # - dt:"2 years ago" date_regex = r"dt([:><=]{1,2})\"(.*?)\"" + + def __init__(self, entry_key='raw'): + self.entry_key = entry_key + + + def load(self, *args, **kwargs): + pass + + def can_filter(self, raw_query): "Check if query contains date filters" return self.extract_date_range(raw_query) is not None - def filter(self, query, entries, embeddings, entry_key='raw'): + def apply(self, query, raw_entries, raw_embeddings): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query query_daterange = self.extract_date_range(query) # if no date in query, return all entries if query_daterange is None: - return query, entries, embeddings + return query, raw_entries, raw_embeddings # remove date range filter from query query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces + # deep copy original embeddings, entries before filtering + embeddings = deepcopy(raw_embeddings) + entries = deepcopy(raw_entries) + + # find entries containing any 
dates that fall with date range specified in query entries_to_include = set() for id, entry in enumerate(entries): # Extract dates from entry - for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]): + for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): # Convert date string in entry to unix timestamp try: date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index b7bb6754..7a26f830 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -1,57 +1,121 @@ # Standard Packages import re +import time +import pickle +import logging # External Packages import torch +# Internal Packages +from src.utils.helpers import LRU, resolve_absolute_path +from src.utils.config import SearchType + + +logger = logging.getLogger(__name__) + class ExplicitFilter: + # Filter Regex + required_regex = r'\+"(\w+)" ?' + blocked_regex = r'\-"(\w+)" ?' 
+ + def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'): + self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") + self.entry_key = entry_key + self.search_type = search_type + self.word_to_entry_index = dict() + self.cache = LRU() + + + def load(self, entries, regenerate=False): + if self.filter_file.exists() and not regenerate: + start = time.time() + with self.filter_file.open('rb') as f: + self.word_to_entry_index = pickle.load(f) + end = time.time() + logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds") + else: + start = time.time() + self.cache = {} # Clear cache on (re-)generating entries_by_word_set + entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' + # Create map of words to entries they exist in + for entry_index, entry in enumerate(entries): + for word in re.split(entry_splitter, entry[self.entry_key].lower()): + if word == '': + continue + if word not in self.word_to_entry_index: + self.word_to_entry_index[word] = set() + self.word_to_entry_index[word].add(entry_index) + + with self.filter_file.open('wb') as f: + pickle.dump(self.word_to_entry_index, f) + end = time.time() + logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds") + + return self.word_to_entry_index + + def can_filter(self, raw_query): "Check if query contains explicit filters" # Extract explicit query portion with required, blocked words to filter from natural query - required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) - blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) + required_words = re.findall(self.required_regex, raw_query) + blocked_words = re.findall(self.blocked_regex, raw_query) return len(required_words) != 0 or len(blocked_words) != 0 - def filter(self, raw_query, entries, embeddings, entry_key='raw'): + def 
apply(self, raw_query, raw_entries, raw_embeddings): "Find entries containing required and not blocked words specified in query" # Separate natural query from explicit required, blocked words filters - query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) - required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) - blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) + start = time.time() + + required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)]) + blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)]) + query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip() + + end = time.time() + logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") if len(required_words) == 0 and len(blocked_words) == 0: + return query, raw_entries, raw_embeddings + + # Return item from cache if exists + cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) + if cache_key in self.cache: + logger.info(f"Explicit filter results from cache") + entries, embeddings = self.cache[cache_key] return query, entries, embeddings - # convert each entry to a set of words - # split on fullstop, comma, colon, tab, newline or any brackets - entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' - entries_by_word_set = [set(word.lower() - for word - in re.split(entry_splitter, entry[entry_key]) - if word != "") - for entry in entries] + if not self.word_to_entry_index: + self.load(raw_entries, regenerate=False) - # track id of entries to exclude - entries_to_exclude = set() + start = time.time() - # mark entries that do not contain all required_words for exclusion + # mark entries that contain all required_words for inclusion + entries_with_all_required_words = set(range(len(raw_entries))) if len(required_words) > 0: - for id, 
words_in_entry in enumerate(entries_by_word_set): - if not required_words.issubset(words_in_entry): - entries_to_exclude.add(id) + entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) # mark entries that contain any blocked_words for exclusion + entries_with_any_blocked_words = set() if len(blocked_words) > 0: - for id, words_in_entry in enumerate(entries_by_word_set): - if words_in_entry.intersection(blocked_words): - entries_to_exclude.add(id) + entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words]) - # delete entries (and their embeddings) marked for exclusion - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) + end = time.time() + logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + + # get entries (and their embeddings) satisfying inclusion and exclusion filters + start = time.time() + + included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words + entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices] + embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) + + end = time.time() + logger.debug(f"Keep entries satisfying filter: {end - start} seconds") + + # Cache results + self.cache[cache_key] = entries, embeddings return query, entries, embeddings diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index fe066033..742ff5ed 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -3,16 +3,17 @@ import argparse import pathlib import logging import time -from copy import deepcopy # External Packages import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util +from src.search_filter.date_filter import DateFilter +from 
src.search_filter.explicit_filter import ExplicitFilter # Internal Packages from src.utils import state from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model -from src.utils.config import TextSearchModel +from src.utils.config import SearchType, TextSearchModel from src.utils.rawconfig import TextSearchConfig, TextContentConfig from src.utils.jsonl import load_jsonl @@ -73,26 +74,15 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False): return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []): +def query(raw_query: str, model: TextSearchModel, rank_results=False): "Search for entries that answer the query" - query = raw_query - - # Use deep copy of original embeddings, entries to filter if query contains filters - start = time.time() - filters_in_query = [filter for filter in filters if filter.can_filter(query)] - if filters_in_query: - corpus_embeddings = deepcopy(model.corpus_embeddings) - entries = deepcopy(model.entries) - else: - corpus_embeddings = model.corpus_embeddings - entries = model.entries - end = time.time() - logger.debug(f"Copy Time: {end - start:.3f} seconds") + query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings # Filter query, entries and embeddings before semantic search start = time.time() + filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] for filter in filters_in_query: - query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings) + query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings) end = time.time() logger.debug(f"Filter Time: {end - start:.3f} seconds") @@ -163,7 +153,7 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel: +def setup(text_to_jsonl, config: 
TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) @@ -180,7 +170,12 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon config.embeddings_file = resolve_absolute_path(config.embeddings_file) corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate) - return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k) + filter_directory = resolve_absolute_path(config.compressed_jsonl.parent) + filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)] + for filter in filters: + filter.load(entries, regenerate=regenerate) + + return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k) if __name__ == '__main__': diff --git a/src/utils/config.py b/src/utils/config.py index 6e69d8b4..a4de6b81 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -20,11 +20,12 @@ class ProcessorType(str, Enum): class TextSearchModel(): - def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k): + def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k): self.entries = entries self.corpus_embeddings = corpus_embeddings self.bi_encoder = bi_encoder self.cross_encoder = cross_encoder + self.filters = filters self.top_k = top_k diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 52ebc330..7ea6580c 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -2,6 +2,7 @@ import pathlib import sys from os.path import join +from collections import OrderedDict def is_none_or_empty(item): @@ -60,4 +61,21 @@ def load_model(model_name, model_dir, model_type, device:str=None): def is_pyinstaller_app(): "Returns true if the app is running from Native GUI created by PyInstaller" - return getattr(sys, 'frozen', False) and 
hasattr(sys, '_MEIPASS') \ No newline at end of file + return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') + + +class LRU(OrderedDict): + def __init__(self, *args, capacity=128, **kwargs): + self.capacity = capacity + super().__init__(*args, **kwargs) + + def __getitem__(self, key): + value = super().__getitem__(key) + self.move_to_end(key) + return value + + def __setitem__(self, key, value): + super().__setitem__(key, value) + if len(self) > self.capacity: + oldest = next(iter(self)) + del self[oldest] diff --git a/tests/conftest.py b/tests/conftest.py index b70deb87..930ec734 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,9 @@ import pytest # Internal Packages from src.search_type import image_search, text_search +from src.utils.config import SearchType from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig from src.processor.org_mode.org_to_jsonl import org_to_jsonl -from src.utils import state @pytest.fixture(scope='session') @@ -46,7 +46,7 @@ def model_dir(search_config): batch_size = 10, use_xmp_metadata = False) - image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True) + image_search.setup(content_config.image, search_config.image, regenerate=False) # Generate Notes Embeddings from Test Notes content_config.org = TextContentConfig( @@ -55,7 +55,7 @@ def model_dir(search_config): compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), embeddings_file = model_dir.joinpath('note_embeddings.pt')) - text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True) + text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) return model_dir diff --git a/tests/test_client.py b/tests/test_client.py index 38b98c1f..e9b632be 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,6 +8,7 @@ import pytest # 
Internal Packages from src.main import app +from src.utils.config import SearchType from src.utils.state import model, config from src.search_type import text_search, image_search from src.utils.rawconfig import ContentConfig, SearchConfig @@ -115,7 +116,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) user_query = "How to git install application?" # Act @@ -131,8 +132,8 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) - user_query = "How to git install application? +Emacs" + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + user_query = 'How to git install application? 
+"Emacs"' # Act response = client.get(f"/search?q={user_query}&n=1&t=org") @@ -147,8 +148,8 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) - user_query = "How to git install application? -clone" + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) + user_query = 'How to git install application? -"clone"' # Act response = client.get(f"/search?q={user_query}&n=1&t=org") diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 88e31c86..ddb1fcf0 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -18,37 +18,37 @@ def test_date_filter(): {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] q_with_no_date_filter = 'head tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings) assert ret_query == 'head tail' assert len(ret_emb) == 3 assert ret_entries == entries q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) assert ret_query == 'head tail' assert len(ret_emb) == 0 assert ret_entries == [] query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' - ret_query, ret_entries, ret_emb = 
DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) assert ret_query == 'head tail' assert ret_entries == [entries[2]] assert len(ret_emb) == 1 query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) assert ret_query == 'head tail' assert ret_entries == [entries[1]] assert len(ret_emb) == 1 query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) assert ret_query == 'head tail' assert ret_entries == [entries[2]] assert len(ret_emb) == 1 query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) assert ret_query == 'head tail' assert ret_entries == [entries[1], entries[2]] assert len(ret_emb) == 2 diff --git a/tests/test_explicit_filter.py b/tests/test_explicit_filter.py new file mode 100644 index 00000000..5f34b0ac --- /dev/null +++ b/tests/test_explicit_filter.py @@ -0,0 +1,85 @@ +# External Packages +import torch + +# Application Packages +from src.search_filter.explicit_filter import ExplicitFilter +from src.utils.config import SearchType + + +def test_no_explicit_filter(tmp_path): + # Arrange + explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + embeddings, entries = 
arrange_content() + q_with_no_filter = 'head tail' + + # Act + can_filter = explicit_filter.can_filter(q_with_no_filter) + ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == False + assert ret_query == 'head tail' + assert len(ret_emb) == 4 + assert ret_entries == entries + + +def test_explicit_exclude_filter(tmp_path): + # Arrange + explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + q_with_exclude_filter = 'head -"exclude_word" tail' + + # Act + can_filter = explicit_filter.can_filter(q_with_exclude_filter) + ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert len(ret_emb) == 2 + assert ret_entries == [entries[0], entries[2]] + + +def test_explicit_include_filter(tmp_path): + # Arrange + explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + query_with_include_filter = 'head +"include_word" tail' + + # Act + can_filter = explicit_filter.can_filter(query_with_include_filter) + ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + assert ret_query == 'head tail' + assert len(ret_emb) == 2 + assert ret_entries == [entries[2], entries[3]] + + +def test_explicit_include_and_exclude_filter(tmp_path): + # Arrange + explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail' + + # Act + can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter) + ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) + + # Assert + assert can_filter == True + 
assert ret_query == 'head tail' + assert len(ret_emb) == 1 + assert ret_entries == [entries[2]] + + +def arrange_content(): + embeddings = torch.randn(4, 10) + entries = [ + {'compiled': '', 'raw': 'Minimal Entry'}, + {'compiled': '', 'raw': 'Entry with exclude_word'}, + {'compiled': '', 'raw': 'Entry with include_word'}, + {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}] + + return embeddings, entries diff --git a/tests/test_helpers.py b/tests/test_helpers.py index d4f06e6d..c9b1cd75 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -28,3 +28,18 @@ def test_merge_dicts(): # do not override existing key in priority_dict with default dict assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1} + + +def test_lru_cache(): + # Test initializing cache + cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2) + assert cache == {'a': 1, 'b': 2} + + # Test capacity overflow + cache['c'] = 3 + assert cache == {'b': 2, 'c': 3} + + # Test delete least recently used item from LRU cache on capacity overflow + cache['b'] # accessing 'b' makes it the most recently used item + cache['d'] = 4 # so 'c' is deleted from the cache instead of 'b' + assert cache == {'b': 2, 'd': 4} diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 39fed92e..84f16df5 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -1,5 +1,6 @@ # System Packages from pathlib import Path +from src.utils.config import SearchType # Internal Packages from src.utils.state import model @@ -13,7 +14,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): # Act # Regenerate notes embeddings during asymmetric setup - notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) + notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, 
SearchType.Org, regenerate=True) # Assert assert len(notes_model.entries) == 10 @@ -23,7 +24,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo # ---------------------------------------------------------------------------------------------------- def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) query = "How to git install application?" # Act @@ -46,7 +47,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC # ---------------------------------------------------------------------------------------------------- def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig): # Arrange - initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -59,11 +60,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") # regenerate notes jsonl, model embeddings and model to include entry from new file - regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) + regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True) # Act # reload embeddings, entries, notes model from previously generated notes jsonl and model 
embeddings files - initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) # Assert assert len(regenerated_notes_model.entries) == 11