From c7de57b8ea0cf6562f746c670cf23b5ca0fe7b93 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 3 Sep 2022 16:01:54 +0300 Subject: [PATCH 01/13] Pre-compute entry word sets to improve explicit filter query performance --- src/configure.py | 8 ++-- src/router.py | 8 ++-- src/search_filter/date_filter.py | 13 ++++++- src/search_filter/explicit_filter.py | 56 +++++++++++++++++++++++----- src/search_type/text_search.py | 19 +++++++--- src/utils/config.py | 3 +- 6 files changed, 81 insertions(+), 26 deletions(-) diff --git a/src/configure.py b/src/configure.py index 0e1e3333..938062eb 100644 --- a/src/configure.py +++ b/src/configure.py @@ -40,22 +40,22 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, # Initialize Org Notes Search if (t == SearchType.Org or t == None) and config.content_type.org: # Extract Entries, Generate Notes Embeddings - model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate) + model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate) # Initialize Org Music Search if (t == SearchType.Music or t == None) and config.content_type.music: # Extract Entries, Generate Music Embeddings - model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate) + model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate) # Initialize Markdown Search if (t == SearchType.Markdown or t == None) and config.content_type.markdown: # Extract Entries, Generate Markdown Embeddings - model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate) + model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate) # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: # Extract Entries, Generate Ledger Embeddings - model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate) + model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate) # Initialize Image Search if (t == SearchType.Image or t == None) and config.content_type.image: diff --git a/src/router.py b/src/router.py index 412692ff..127623c6 100644 --- a/src/router.py +++ b/src/router.py @@ -65,7 +65,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Org or t == None) and state.model.orgmode_search: # query org-mode notes query_start = time.time() - hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r) query_end = time.time() # collate and return results @@ -76,7 +76,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Music or t == None) and state.model.music_search: # query music library query_start = time.time() - hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r) query_end = time.time() # collate and return results @@ -87,7 +87,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Markdown or t == None) and state.model.markdown_search: # query markdown files query_start = time.time() - hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r) query_end = time.time() # collate and return results @@ -98,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Ledger or t == None) and state.model.ledger_search: # query transactions query_start = time.time() - hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose) + hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r) query_end = time.time() # collate and return results diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index dc70ca29..d91ebd83 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -17,12 +17,21 @@ class DateFilter: # - dt:"2 years ago" date_regex = r"dt([:><=]{1,2})\"(.*?)\"" + + def __init__(self, entry_key='raw'): + self.entry_key = entry_key + + + def load(*args, **kwargs): + pass + + def can_filter(self, raw_query): "Check if query contains date filters" return self.extract_date_range(raw_query) is not None - def filter(self, query, entries, embeddings, entry_key='raw'): + def apply(self, query, entries, embeddings): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query query_daterange = self.extract_date_range(query) @@ -39,7 +48,7 @@ class DateFilter: entries_to_include = set() for id, entry in enumerate(entries): # Extract dates from entry - for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]): + for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]): # Convert date string in entry to unix timestamp try: date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index b7bb6754..2cf82d70 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -1,11 +1,46 @@ # Standard Packages import re +import time +import pickle # External Packages import torch +# Internal Packages +from src.utils.helpers import resolve_absolute_path +from src.utils.config import SearchType + class ExplicitFilter: + def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'): + self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") + self.entry_key = entry_key + self.search_type = search_type + + + def load(self, entries, regenerate=False): + if self.filter_file.exists() and not regenerate: + start = time.time() + with self.filter_file.open('rb') as f: + entries_by_word_set = pickle.load(f) + end = time.time() + print(f"Load {self.search_type} entries by word set from file: {end - start} seconds") + else: + start = time.time() + entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' + entries_by_word_set = [set(word.lower() + for word + in re.split(entry_splitter, entry[self.entry_key]) + if word != "") + for entry in entries] + with self.filter_file.open('wb') as f: + pickle.dump(entries_by_word_set, f) + end = time.time() + print(f"Convert all {self.search_type} entries to word sets: {end - start} seconds") + + return entries_by_word_set + + def can_filter(self, raw_query): "Check if query contains explicit filters" # Extract explicit query portion with required, blocked words to filter from natural query @@ -15,26 +50,24 @@ class ExplicitFilter: return len(required_words) != 0 or len(blocked_words) != 0 - def filter(self, raw_query, entries, embeddings, entry_key='raw'): + def apply(self, raw_query, entries, embeddings): "Find entries containing required and not blocked words specified in query" # Separate natural query from explicit required, blocked words filters + start = time.time() query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) + end = time.time() + print(f"Time to extract required, blocked words: {end - start} seconds") if len(required_words) == 0 and len(blocked_words) == 0: return query, entries, embeddings - # convert each entry to a set of words - # split on fullstop, comma, colon, tab, newline or any brackets - entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' - entries_by_word_set = [set(word.lower() - for word - in re.split(entry_splitter, entry[entry_key]) - if word != "") - for entry in entries] + # load or generate word set for each entry + entries_by_word_set = self.load(entries, regenerate=False) # track id of entries to exclude + start = time.time() entries_to_exclude = set() # mark entries that do not contain all required_words for exclusion @@ -48,10 +81,15 @@ class ExplicitFilter: for id, words_in_entry in enumerate(entries_by_word_set): if words_in_entry.intersection(blocked_words): entries_to_exclude.add(id) + end = time.time() + print(f"Mark entries to filter: {end - start} seconds") # delete entries (and their embeddings) marked for exclusion + start = time.time() for id in sorted(list(entries_to_exclude), reverse=True): del entries[id] embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) + end = time.time() + print(f"Remove entries to filter from embeddings: {end - start} seconds") return query, entries, embeddings diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index fe066033..4f83236e 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -8,11 +8,13 @@ from copy import deepcopy # External Packages import torch from sentence_transformers import SentenceTransformer, CrossEncoder, util +from src.search_filter.date_filter import DateFilter +from src.search_filter.explicit_filter import ExplicitFilter # Internal Packages from src.utils import state from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model -from src.utils.config import TextSearchModel +from src.utils.config import SearchType, TextSearchModel from src.utils.rawconfig import TextSearchConfig, TextContentConfig from src.utils.jsonl import load_jsonl @@ -73,13 +75,13 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False): return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []): +def query(raw_query: str, model: TextSearchModel, rank_results=False): "Search for entries that answer the query" query = raw_query # Use deep copy of original embeddings, entries to filter if query contains filters start = time.time() - filters_in_query = [filter for filter in filters if filter.can_filter(query)] + filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] if filters_in_query: corpus_embeddings = deepcopy(model.corpus_embeddings) entries = deepcopy(model.entries) @@ -92,7 +94,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l # Filter query, entries and embeddings before semantic search start = time.time() for filter in filters_in_query: - query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings) + query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings) end = time.time() logger.debug(f"Filter Time: {end - start:.3f} seconds") @@ -163,7 +165,7 @@ def collate_results(hits, entries, count=5): in hits[0:count]] -def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel: +def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel: # Initialize Model bi_encoder, cross_encoder, top_k = initialize_model(search_config) @@ -180,7 +182,12 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon config.embeddings_file = resolve_absolute_path(config.embeddings_file) corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate) - return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k) + filter_directory = resolve_absolute_path(config.compressed_jsonl.parent) + filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)] + for filter in filters: + filter.load(entries, regenerate=regenerate) + + return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k) if __name__ == '__main__': diff --git a/src/utils/config.py b/src/utils/config.py index 6e69d8b4..a4de6b81 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -20,11 +20,12 @@ class ProcessorType(str, Enum): class TextSearchModel(): - def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k): + def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k): self.entries = entries self.corpus_embeddings = corpus_embeddings self.bi_encoder = bi_encoder self.cross_encoder = cross_encoder + self.filters = filters self.top_k = top_k From 30c3eb372a573de616eccc8acfbf1a3fade5841f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 3 Sep 2022 22:13:25 +0300 Subject: [PATCH 02/13] Update Tests to Configure Filters and Setup Text Search --- tests/conftest.py | 6 +++--- tests/test_client.py | 7 ++++--- tests/test_date_filter.py | 12 ++++++------ tests/test_text_search.py | 11 ++++++----- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b70deb87..930ec734 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,9 @@ import pytest # Internal Packages from src.search_type import image_search, text_search +from src.utils.config import SearchType from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig from src.processor.org_mode.org_to_jsonl import org_to_jsonl -from src.utils import state @pytest.fixture(scope='session') @@ -46,7 +46,7 @@ def model_dir(search_config): batch_size = 10, use_xmp_metadata = False) - image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True) + image_search.setup(content_config.image, search_config.image, regenerate=False) # Generate Notes Embeddings from Test Notes content_config.org = TextContentConfig( @@ -55,7 +55,7 @@ def model_dir(search_config): compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), embeddings_file = model_dir.joinpath('note_embeddings.pt')) - text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True) + text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) return model_dir diff --git a/tests/test_client.py b/tests/test_client.py index 38b98c1f..a80b2fa1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,6 +8,7 @@ import pytest # Internal Packages from src.main import app +from src.utils.config import SearchType from src.utils.state import model, config from src.search_type import text_search, image_search from src.utils.rawconfig import ContentConfig, SearchConfig @@ -115,7 +116,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) user_query = "How to git install application?" # Act @@ -131,7 +132,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) user_query = "How to git install application? +Emacs" # Act @@ -147,7 +148,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) user_query = "How to git install application? -clone" # Act diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 88e31c86..ddb1fcf0 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -18,37 +18,37 @@ def test_date_filter(): {'compiled': '', 'raw': 'Entry with date:1984-04-02'}] q_with_no_date_filter = 'head tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings) assert ret_query == 'head tail' assert len(ret_emb) == 3 assert ret_entries == entries q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings) assert ret_query == 'head tail' assert len(ret_emb) == 0 assert ret_entries == [] query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) assert ret_query == 'head tail' assert ret_entries == [entries[2]] assert len(ret_emb) == 1 query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) assert ret_query == 'head tail' assert ret_entries == [entries[1]] assert len(ret_emb) == 1 query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) assert ret_query == 'head tail' assert ret_entries == [entries[2]] assert len(ret_emb) == 1 query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail' - ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings) + ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings) assert ret_query == 'head tail' assert ret_entries == [entries[1], entries[2]] assert len(ret_emb) == 2 diff --git a/tests/test_text_search.py b/tests/test_text_search.py index 39fed92e..84f16df5 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -1,5 +1,6 @@ # System Packages from pathlib import Path +from src.utils.config import SearchType # Internal Packages from src.utils.state import model @@ -13,7 +14,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig): # Act # Regenerate notes embeddings during asymmetric setup - notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) + notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True) # Assert assert len(notes_model.entries) == 10 @@ -23,7 +24,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo # ---------------------------------------------------------------------------------------------------- def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig): # Arrange - model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) query = "How to git install application?" # Act @@ -46,7 +47,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC # ---------------------------------------------------------------------------------------------------- def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig): # Arrange - initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) assert len(initial_notes_model.entries) == 10 assert len(initial_notes_model.corpus_embeddings) == 10 @@ -59,11 +60,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n") # regenerate notes jsonl, model embeddings and model to include entry from new file - regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True) + regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True) # Act # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files - initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False) + initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) # Assert assert len(regenerated_notes_model.entries) == 11 From ffb8e3988e2a8a185c9032ae826b3c94aaf906ea Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 3 Sep 2022 22:14:37 +0300 Subject: [PATCH 03/13] Use Python Logging Framework to Time Performance of Explicit Filter --- src/search_filter/explicit_filter.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index 2cf82d70..09580e4a 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -2,6 +2,7 @@ import re import time import pickle +import logging # External Packages import torch @@ -11,6 +12,9 @@ from src.utils.helpers import resolve_absolute_path from src.utils.config import SearchType +logger = logging.getLogger(__name__) + + class ExplicitFilter: def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'): self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") @@ -24,7 +28,7 @@ class ExplicitFilter: with self.filter_file.open('rb') as f: entries_by_word_set = pickle.load(f) end = time.time() - print(f"Load {self.search_type} entries by word set from file: {end - start} seconds") + logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds") else: start = time.time() entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' @@ -36,7 +40,7 @@ class ExplicitFilter: with self.filter_file.open('wb') as f: pickle.dump(entries_by_word_set, f) end = time.time() - print(f"Convert all {self.search_type} entries to word sets: {end - start} seconds") + logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds") return entries_by_word_set @@ -58,7 +62,7 @@ class ExplicitFilter: required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) end = time.time() - print(f"Time to extract required, blocked words: {end - start} seconds") + logger.debug(f"Time to extract required, blocked words: {end - start} seconds") if len(required_words) == 0 and len(blocked_words) == 0: return query, entries, embeddings @@ -82,7 +86,7 @@ class ExplicitFilter: if words_in_entry.intersection(blocked_words): entries_to_exclude.add(id) end = time.time() - print(f"Mark entries to filter: {end - start} seconds") + logger.debug(f"Mark entries to filter: {end - start} seconds") # delete entries (and their embeddings) marked for exclusion start = time.time() @@ -90,6 +94,6 @@ class ExplicitFilter: del entries[id] embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) end = time.time() - print(f"Remove entries to filter from embeddings: {end - start} seconds") + logger.debug(f"Remove entries to filter from embeddings: {end - start} seconds") return query, entries, embeddings From b7d259b1ec8d770a58fc1076af9af92d62511b93 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 3 Sep 2022 23:00:09 +0300 Subject: [PATCH 04/13] Test Explicit Include, Exclude Filters --- tests/test_explicit_filter.py | 77 +++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 tests/test_explicit_filter.py diff --git a/tests/test_explicit_filter.py b/tests/test_explicit_filter.py new file mode 100644 index 00000000..f3b88659 --- /dev/null +++ b/tests/test_explicit_filter.py @@ -0,0 +1,77 @@ +# External Packages +import torch + +# Application Packages +from src.search_filter.explicit_filter import ExplicitFilter +from src.utils.config import SearchType + + +def test_no_explicit_filter(tmp_path): + # Arrange + explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + q_with_no_filter = 'head tail' + + # Act + ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings) + + # Assert + assert ret_query == 'head tail' + assert len(ret_emb) == 4 + assert ret_entries == entries + + +def test_explicit_exclude_filter(tmp_path): + # Arrange + explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + q_with_exclude_filter = 'head -exclude_word tail' + + # Act + ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) + + # Assert + assert ret_query == 'head tail' + assert len(ret_emb) == 2 + assert ret_entries == [entries[0], entries[2]] + + +def test_explicit_include_filter(tmp_path): + # Arrange + explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + query_with_include_filter = 'head +include_word tail' + + # Act + ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings) + + # Assert + assert ret_query == 'head tail' + assert len(ret_emb) == 2 + assert ret_entries == [entries[2], entries[3]] + + +def test_explicit_include_and_exclude_filter(tmp_path): + # Arrange + explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) + embeddings, entries = arrange_content() + query_with_include_and_exclude_filter = 'head +include_word -exclude_word tail' + + # Act + ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) + + # Assert + assert ret_query == 'head tail' + assert len(ret_emb) == 1 + assert ret_entries == [entries[2]] + + +def arrange_content(): + embeddings = torch.randn(4, 10) + entries = [ + {'compiled': '', 'raw': 'Minimal Entry'}, + {'compiled': '', 'raw': 'Entry with exclude_word'}, + {'compiled': '', 'raw': 'Entry with include_word'}, + {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}] + + return embeddings, entries From 546fad570d97638b64a198a06b3fa7457e3ddac0 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 3 Sep 2022 23:33:52 +0300 Subject: [PATCH 05/13] Use regex to extract include, exclude filter words from query --- src/search_filter/explicit_filter.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index 09580e4a..797c007d 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -16,6 +16,10 @@ logger = logging.getLogger(__name__) class ExplicitFilter: + # Filter Regex + required_regex = r'\+([^\s]+) ?' + blocked_regex = r'\-([^\s]+) ?' + def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'): self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") self.entry_key = entry_key @@ -58,11 +62,13 @@ class ExplicitFilter: "Find entries containing required and not blocked words specified in query" # Separate natural query from explicit required, blocked words filters start = time.time() - query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")]) - required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) - blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) + + required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)]) + blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)]) + query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)) + end = time.time() - logger.debug(f"Time to extract required, blocked words: {end - start} seconds") + logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") if len(required_words) == 0 and len(blocked_words) == 0: return query, entries, embeddings @@ -86,7 +92,7 @@ class ExplicitFilter: if words_in_entry.intersection(blocked_words): entries_to_exclude.add(id) end = time.time() - logger.debug(f"Mark entries to filter: {end - start} seconds") + logger.debug(f"Mark entries not satisfying filter: {end - start} seconds") # delete entries (and their embeddings) marked for exclusion start = time.time() @@ -94,6 +100,6 @@ class ExplicitFilter: del entries[id] embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) end = time.time() - logger.debug(f"Remove entries to filter from embeddings: {end - start} seconds") + logger.debug(f"Delete entries not satisfying filter: {end - start} seconds") return query, entries, embeddings From 858d86075b8126fe26b4f8de725ed462ddc6fdee Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 3 Sep 2022 23:47:28 +0300 Subject: [PATCH 06/13] Use regexes to check if any explicit filters in query. Test can_filter --- src/search_filter/explicit_filter.py | 4 ++-- tests/test_explicit_filter.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index 797c007d..9d043a4d 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -52,8 +52,8 @@ class ExplicitFilter: def can_filter(self, raw_query): "Check if query contains explicit filters" # Extract explicit query portion with required, blocked words to filter from natural query - required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")]) - blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) + required_words = re.findall(self.required_regex, raw_query) + blocked_words = re.findall(self.blocked_regex, raw_query) return len(required_words) != 0 or len(blocked_words) != 0 diff --git a/tests/test_explicit_filter.py b/tests/test_explicit_filter.py index f3b88659..9d4c022a 100644 --- a/tests/test_explicit_filter.py +++ b/tests/test_explicit_filter.py @@ -13,9 +13,11 @@ def test_no_explicit_filter(tmp_path): q_with_no_filter = 'head tail' # Act + can_filter = explicit_filter.can_filter(q_with_no_filter) ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings) # Assert + assert can_filter == False assert ret_query == 'head tail' assert len(ret_emb) == 4 assert ret_entries == entries @@ -28,9 +30,11 @@ def test_explicit_exclude_filter(tmp_path): q_with_exclude_filter = 'head -exclude_word tail' # Act + can_filter = explicit_filter.can_filter(q_with_exclude_filter) ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) # Assert + assert can_filter == True assert ret_query == 'head tail' assert len(ret_emb) == 2 assert ret_entries == [entries[0], entries[2]] @@ -43,9 +47,11 @@ def test_explicit_include_filter(tmp_path): query_with_include_filter = 'head +include_word tail' # Act + can_filter = explicit_filter.can_filter(query_with_include_filter) ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings) # Assert + assert can_filter == True assert ret_query == 'head tail' assert len(ret_emb) == 2 assert ret_entries == [entries[2], entries[3]] @@ -58,9 +64,11 @@ def test_explicit_include_and_exclude_filter(tmp_path): query_with_include_and_exclude_filter = 'head +include_word -exclude_word tail' # Act + can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter) ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) # Assert + assert can_filter == True assert ret_query == 'head tail' assert len(ret_emb) == 1 assert ret_entries == [entries[2]] From 8d9f507df38a3ef6b6a08d5e63f7f8a4d8cc1d80 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 00:37:37 +0300 Subject: [PATCH 07/13] Load entries_by_word_set from file only once on first load of explicit filter --- src/search_filter/explicit_filter.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index 9d043a4d..a719bcdd 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -24,29 +24,30 @@ class ExplicitFilter: self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") self.entry_key = entry_key self.search_type = search_type + self.entries_by_word_set = None def load(self, entries, regenerate=False): if self.filter_file.exists() and not regenerate: start = time.time() with self.filter_file.open('rb') as f: - entries_by_word_set = pickle.load(f) + self.entries_by_word_set = pickle.load(f) end = time.time() logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds") else: start = time.time() entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' - entries_by_word_set = [set(word.lower() + self.entries_by_word_set = [set(word.lower() for word in re.split(entry_splitter, entry[self.entry_key]) if word != "") for entry in entries] with self.filter_file.open('wb') as f: - pickle.dump(entries_by_word_set, f) + pickle.dump(self.entries_by_word_set, f) end = time.time() logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds") - return entries_by_word_set + return self.entries_by_word_set def can_filter(self, raw_query): @@ -73,8 +74,8 @@ class ExplicitFilter: if len(required_words) == 0 and len(blocked_words) == 0: return query, entries, embeddings - # load or generate word set for each entry - entries_by_word_set = self.load(entries, regenerate=False) + if not self.entries_by_word_set: + self.load(entries, regenerate=False) # track id of entries to exclude start = time.time() @@ -82,13 +83,13 @@ class ExplicitFilter: # mark entries that do not contain all required_words for exclusion if len(required_words) > 0: - for id, words_in_entry in enumerate(entries_by_word_set): + for id, words_in_entry in enumerate(self.entries_by_word_set): if not required_words.issubset(words_in_entry): entries_to_exclude.add(id) # mark entries that contain any blocked_words for exclusion if len(blocked_words) > 0: - for id, words_in_entry in enumerate(entries_by_word_set): + for id, words_in_entry in enumerate(self.entries_by_word_set): if words_in_entry.intersection(blocked_words): entries_to_exclude.add(id) end = time.time() From cdcee89ae52347d35678ddb072d023fbfcb91b68 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 02:12:56 +0300 Subject: [PATCH 08/13] Wrap words in quotes to trigger explicit filter from query - Do not run the more expensive explicit filter until the word to be filtered is completed by user. This requires an end sequence marker to identify end of explicit word filter to trigger filtering - Space isn't a good enough delimiter as the explicit filter could be at the end of the query in which case no space --- src/search_filter/explicit_filter.py | 4 ++-- tests/test_client.py | 4 ++-- tests/test_explicit_filter.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index a719bcdd..e3b5bb9f 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -17,8 +17,8 @@ logger = logging.getLogger(__name__) class ExplicitFilter: # Filter Regex - required_regex = r'\+([^\s]+) ?' - blocked_regex = r'\-([^\s]+) ?' + required_regex = r'\+"(\w+)" ?' + blocked_regex = r'\-"(\w+)" ?' def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'): self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") diff --git a/tests/test_client.py b/tests/test_client.py index a80b2fa1..e9b632be 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -133,7 +133,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) - user_query = "How to git install application? +Emacs" + user_query = 'How to git install application? +"Emacs"' # Act response = client.get(f"/search?q={user_query}&n=1&t=org") @@ -149,7 +149,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False) - user_query = "How to git install application? -clone" + user_query = 'How to git install application? -"clone"' # Act response = client.get(f"/search?q={user_query}&n=1&t=org") diff --git a/tests/test_explicit_filter.py b/tests/test_explicit_filter.py index 9d4c022a..5f34b0ac 100644 --- a/tests/test_explicit_filter.py +++ b/tests/test_explicit_filter.py @@ -27,7 +27,7 @@ def test_explicit_exclude_filter(tmp_path): # Arrange explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) embeddings, entries = arrange_content() - q_with_exclude_filter = 'head -exclude_word tail' + q_with_exclude_filter = 'head -"exclude_word" tail' # Act can_filter = explicit_filter.can_filter(q_with_exclude_filter) @@ -44,7 +44,7 @@ def test_explicit_include_filter(tmp_path): # Arrange explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) embeddings, entries = arrange_content() - query_with_include_filter = 'head +include_word tail' + query_with_include_filter = 'head +"include_word" tail' # Act can_filter = explicit_filter.can_filter(query_with_include_filter) @@ -61,7 +61,7 @@ def test_explicit_include_and_exclude_filter(tmp_path): # Arrange explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) embeddings, entries = arrange_content() - query_with_include_and_exclude_filter = 'head +include_word -exclude_word tail' + query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail' # Act can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter) From 3308e68edf9992318ec309b210de91830a9f2ad3 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 02:21:10 +0300 Subject: [PATCH 09/13] Cache explicitly filtered entries, embeddings by required, blocked words --- src/search_filter/explicit_filter.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index e3b5bb9f..2707155b 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -25,6 +25,7 @@ class ExplicitFilter: self.entry_key = entry_key self.search_type = search_type self.entries_by_word_set = None + self.cache = {} def load(self, entries, regenerate=False): @@ -36,6 +37,7 @@ class ExplicitFilter: logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds") else: start = time.time() + self.cache = {} # Clear cache on (re-)generating entries_by_word_set entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' self.entries_by_word_set = [set(word.lower() for word @@ -72,6 +74,13 @@ class ExplicitFilter: logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") if len(required_words) == 0 and len(blocked_words) == 0: + return query, raw_entries, raw_embeddings + + # Return item from cache if exists + cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) + if cache_key in self.cache: + logger.info(f"Explicit filter results from cache") + entries, embeddings = self.cache[cache_key] return query, entries, embeddings if not self.entries_by_word_set: @@ -103,4 +112,7 @@ class ExplicitFilter: end = time.time() logger.debug(f"Delete entries not satisfying filter: {end - start} seconds") + # Cache results + self.cache[cache_key] = entries, embeddings + return query, entries, embeddings From 28d3dc1434db07b70beac206595e7cfadab56adc Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 02:22:42 +0300 Subject: [PATCH 10/13] Deep copy entries, embeddings in filters. Defer till actual filtering - Only the filter knows when entries, embeddings are to be manipulated. So move the responsibility to deep copy before manipulating entries, embeddings to the filters - Create deep copy in filters. Avoids creating deep copy of entries, embeddings when filter results are being loaded from cache etc --- src/search_filter/date_filter.py | 9 +++++++-- src/search_filter/explicit_filter.py | 10 +++++++++- src/search_type/text_search.py | 16 ++-------------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index d91ebd83..cab47cbb 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -3,6 +3,7 @@ import re from datetime import timedelta, datetime from dateutil.relativedelta import relativedelta, MO from math import inf +from copy import deepcopy # External Packages import torch @@ -31,19 +32,23 @@ class DateFilter: return self.extract_date_range(raw_query) is not None - def apply(self, query, entries, embeddings): + def apply(self, query, raw_entries, raw_embeddings): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query query_daterange = self.extract_date_range(query) # if no date in query, return all entries if query_daterange is None: - return query, entries, embeddings + return query, raw_entries, raw_embeddings # remove date range filter from query query = re.sub(rf'\s+{self.date_regex}', ' ', query) query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces + # deep copy original embeddings, entries before filtering + embeddings= deepcopy(raw_embeddings) + entries = deepcopy(raw_entries) + # find entries containing any dates that fall with date range specified in query entries_to_include = set() for id, entry in enumerate(entries): diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index 2707155b..e715e8b6 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -3,6 +3,7 @@ import re import time import pickle import logging +from copy import deepcopy # External Packages import torch @@ -61,7 +62,7 @@ class ExplicitFilter: return len(required_words) != 0 or len(blocked_words) != 0 - def apply(self, raw_query, entries, embeddings): + def apply(self, raw_query, raw_entries, raw_embeddings): "Find entries containing required and not blocked words specified in query" # Separate natural query from explicit required, blocked words filters start = time.time() @@ -83,6 +84,13 @@ class ExplicitFilter: entries, embeddings = self.cache[cache_key] return query, entries, embeddings + # deep copy original embeddings, entries before filtering + start = time.time() + embeddings= deepcopy(raw_embeddings) + entries = deepcopy(raw_entries) + end = time.time() + logger.debug(f"Create copy of embeddings, entries for manipulation: {end - start:.3f} seconds") + if not self.entries_by_word_set: self.load(entries, regenerate=False) diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 4f83236e..742ff5ed 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -3,7 +3,6 @@ import argparse import pathlib import logging import time -from copy import deepcopy # External Packages import torch @@ -77,22 +76,11 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False): def query(raw_query: str, model: TextSearchModel, rank_results=False): "Search for entries that answer the query" - query = raw_query - - # Use deep copy of original embeddings, entries to filter if query contains filters - start = time.time() - filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] - if filters_in_query: - corpus_embeddings = deepcopy(model.corpus_embeddings) - entries = deepcopy(model.entries) - else: - corpus_embeddings = model.corpus_embeddings - entries = model.entries - end = time.time() - logger.debug(f"Copy Time: {end - start:.3f} seconds") + query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings # Filter query, entries and embeddings before semantic search start = time.time() + filters_in_query = [filter for filter in model.filters if filter.can_filter(query)] for filter in filters_in_query: query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings) end = time.time() From 191a656ed7c0c8442f208abe1d30520c305e947f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 15:09:09 +0300 Subject: [PATCH 11/13] Use word to entry map, list comprehension to speed up explicit filter - Code Changes - Use list comprehension and `torch.index_select' methods - to speed selection of entries, embedding tensors satisfying filter - avoid deep copy of entries, embeddings - avoid updating existing lists (of entries, embeddings) - Use word to entry map and set operations to mark entries satisfying inclusion, exclusion filters - Results - Speed up explicit filtering by two orders of magnitude - Improve consistency of speed up across inclusion and exclusion filtering --- src/search_filter/explicit_filter.py | 70 +++++++++++++--------------- 1 file changed, 33 insertions(+), 37 deletions(-) diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index e715e8b6..6f64ede5 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -25,7 +25,7 @@ class ExplicitFilter: self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") self.entry_key = entry_key self.search_type = search_type - self.entries_by_word_set = None + self.word_to_entry_index = dict() self.cache = {} @@ -33,24 +33,28 @@ class ExplicitFilter: if self.filter_file.exists() and not regenerate: start = time.time() with self.filter_file.open('rb') as f: - self.entries_by_word_set = pickle.load(f) + self.word_to_entry_index = pickle.load(f) end = time.time() logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds") else: start = time.time() self.cache = {} # Clear cache on (re-)generating entries_by_word_set entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' - self.entries_by_word_set = [set(word.lower() - for word - in re.split(entry_splitter, entry[self.entry_key]) - if word != "") - for entry in entries] + # Create map of words to entries they exist in + for entry_index, entry in enumerate(entries): + for word in re.split(entry_splitter, entry[self.entry_key].lower()): + if word == '': + continue + if word not in self.word_to_entry_index: + self.word_to_entry_index[word] = set() + self.word_to_entry_index[word].add(entry_index) + with self.filter_file.open('wb') as f: - pickle.dump(self.entries_by_word_set, f) + pickle.dump(self.word_to_entry_index, f) end = time.time() logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds") - return self.entries_by_word_set + return self.word_to_entry_index def can_filter(self, raw_query): @@ -69,7 +73,7 @@ class ExplicitFilter: required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)]) blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)]) - query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)) + query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip() end = time.time() logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") @@ -84,41 +88,33 @@ class ExplicitFilter: entries, embeddings = self.cache[cache_key] return query, entries, embeddings - # deep copy original embeddings, entries before filtering + if not self.word_to_entry_index: + self.load(raw_entries, regenerate=False) + start = time.time() - embeddings= deepcopy(raw_embeddings) - entries = deepcopy(raw_entries) - end = time.time() - logger.debug(f"Create copy of embeddings, entries for manipulation: {end - start:.3f} seconds") - if not self.entries_by_word_set: - self.load(entries, regenerate=False) - - # track id of entries to exclude - start = time.time() - entries_to_exclude = set() - - # mark entries that do not contain all required_words for exclusion + # mark entries that contain all required_words for inclusion + entries_with_all_required_words = set(range(len(raw_entries))) if len(required_words) > 0: - for id, words_in_entry in enumerate(self.entries_by_word_set): - if not required_words.issubset(words_in_entry): - entries_to_exclude.add(id) + entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) # mark entries that contain any blocked_words for exclusion + entries_with_any_blocked_words = set() if len(blocked_words) > 0: - for id, words_in_entry in enumerate(self.entries_by_word_set): - if words_in_entry.intersection(blocked_words): - entries_to_exclude.add(id) - end = time.time() - logger.debug(f"Mark entries not satisfying filter: {end - start} seconds") + entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words]) - # delete entries (and their embeddings) marked for exclusion - start = time.time() - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) end = time.time() - logger.debug(f"Delete entries not satisfying filter: {end - start} seconds") + logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + + # get entries (and their embeddings) satisfying inclusion and exclusion filters + start = time.time() + + included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words + entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices] + embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) + + end = time.time() + logger.debug(f"Keep entries satisfying filter: {end - start} seconds") # Cache results self.cache[cache_key] = entries, embeddings From 8f3326c8d4e058d44c0af4beaa965ce048971b8b Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 16:31:46 +0300 Subject: [PATCH 12/13] Create LRU helper class for caching --- src/utils/helpers.py | 20 +++++++++++++++++++- tests/test_helpers.py | 15 +++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 52ebc330..7ea6580c 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -2,6 +2,7 @@ import pathlib import sys from os.path import join +from collections import OrderedDict def is_none_or_empty(item): @@ -60,4 +61,21 @@ def load_model(model_name, model_dir, model_type, device:str=None): def is_pyinstaller_app(): "Returns true if the app is running from Native GUI created by PyInstaller" - return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') \ No newline at end of file + return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') + + +class LRU(OrderedDict): + def __init__(self, *args, capacity=128, **kwargs): + self.capacity = capacity + super().__init__(*args, **kwargs) + + def __getitem__(self, key): + value = super().__getitem__(key) + self.move_to_end(key) + return value + + def __setitem__(self, key, value): + super().__setitem__(key, value) + if len(self) > self.capacity: + oldest = next(iter(self)) + del self[oldest] diff --git a/tests/test_helpers.py b/tests/test_helpers.py index d4f06e6d..c9b1cd75 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -28,3 +28,18 @@ def test_merge_dicts(): # do not override existing key in priority_dict with default dict assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1} + + +def test_lru_cache(): + # Test initializing cache + cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2) + assert cache == {'a': 1, 'b': 2} + + # Test capacity overflow + cache['c'] = 3 + assert cache == {'b': 2, 'c': 3} + + # Test delete least recently used item from LRU cache on capacity overflow + cache['b'] # accessing 'b' makes it the most recently used item + cache['d'] = 4 # so 'c' is deleted from the cache instead of 'b' + assert cache == {'b': 2, 'd': 4} From 60878625214b96114cb5a4f747071b6079e017e0 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 4 Sep 2022 16:42:28 +0300 Subject: [PATCH 13/13] Use LRU helper class for explicit filter cache --- src/search_filter/explicit_filter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py index 6f64ede5..7a26f830 100644 --- a/src/search_filter/explicit_filter.py +++ b/src/search_filter/explicit_filter.py @@ -3,13 +3,12 @@ import re import time import pickle import logging -from copy import deepcopy # External Packages import torch # Internal Packages -from src.utils.helpers import resolve_absolute_path +from src.utils.helpers import LRU, resolve_absolute_path from src.utils.config import SearchType @@ -26,7 +25,7 @@ class ExplicitFilter: self.entry_key = entry_key self.search_type = search_type self.word_to_entry_index = dict() - self.cache = {} + self.cache = LRU() def load(self, entries, regenerate=False):