diff --git a/src/configure.py b/src/configure.py
index 0e1e3333..938062eb 100644
--- a/src/configure.py
+++ b/src/configure.py
@@ -40,22 +40,22 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     # Initialize Org Notes Search
     if (t == SearchType.Org or t == None) and config.content_type.org:
         # Extract Entries, Generate Notes Embeddings
-        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate)
+        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate)
 
     # Initialize Org Music Search
     if (t == SearchType.Music or t == None) and config.content_type.music:
         # Extract Entries, Generate Music Embeddings
-        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)
+        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate)
 
     # Initialize Markdown Search
     if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
         # Extract Entries, Generate Markdown Embeddings
-        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)
+        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate)
 
     # Initialize Ledger Search
     if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
         # Extract Entries, Generate Ledger Embeddings
-        model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)
+        model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate)
 
     # Initialize Image Search
     if (t == SearchType.Image or t == None) and config.content_type.image:
diff --git a/src/router.py b/src/router.py
index 412692ff..127623c6 100644
--- a/src/router.py
+++ b/src/router.py
@@ -65,7 +65,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Org or t == None) and state.model.orgmode_search:
         # query org-mode notes
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
         query_end = time.time()
 
         # collate and return results
@@ -76,7 +76,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Music or t == None) and state.model.music_search:
         # query music library
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
         query_end = time.time()
 
         # collate and return results
@@ -87,7 +87,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
         # query markdown files
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
         query_end = time.time()
 
         # collate and return results
@@ -98,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
         # query transactions
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
         query_end = time.time()
 
         # collate and return results
diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py
index dc70ca29..d91ebd83 100644
--- a/src/search_filter/date_filter.py
+++ b/src/search_filter/date_filter.py
@@ -17,12 +17,21 @@ class DateFilter:
     # - dt:"2 years ago"
     date_regex = r"dt([:><=]{1,2})\"(.*?)\""
 
+
+    def __init__(self, entry_key='raw'):
+        self.entry_key = entry_key
+
+
+    def load(*args, **kwargs):
+        pass
+
+
     def can_filter(self, raw_query):
         "Check if query contains date filters"
         return self.extract_date_range(raw_query) is not None
 
 
-    def filter(self, query, entries, embeddings, entry_key='raw'):
+    def apply(self, query, entries, embeddings):
         "Find entries containing any dates that fall within date range specified in query"
         # extract date range specified in date filter of query
         query_daterange = self.extract_date_range(query)
@@ -39,7 +48,7 @@ class DateFilter:
         entries_to_include = set()
         for id, entry in enumerate(entries):
             # Extract dates from entry
-            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
+            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
                 # Convert date string in entry to unix timestamp
                 try:
                     date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py
index b7bb6754..2cf82d70 100644
--- a/src/search_filter/explicit_filter.py
+++ b/src/search_filter/explicit_filter.py
@@ -1,11 +1,46 @@
 # Standard Packages
 import re
+import time
+import pickle
 
 # External Packages
 import torch
 
+# Internal Packages
+from src.utils.helpers import resolve_absolute_path
+from src.utils.config import SearchType
+
 
 class ExplicitFilter:
+    def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
+        self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
+        self.entry_key = entry_key
+        self.search_type = search_type
+
+
+    def load(self, entries, regenerate=False):
+        if self.filter_file.exists() and not regenerate:
+            start = time.time()
+            with self.filter_file.open('rb') as f:
+                entries_by_word_set = pickle.load(f)
+            end = time.time()
+            print(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
+        else:
+            start = time.time()
+            entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
+            entries_by_word_set = [set(word.lower()
+                                       for word
+                                       in re.split(entry_splitter, entry[self.entry_key])
+                                       if word != "")
+                                   for entry in entries]
+            with self.filter_file.open('wb') as f:
+                pickle.dump(entries_by_word_set, f)
+            end = time.time()
+            print(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
+
+        return entries_by_word_set
+
+
     def can_filter(self, raw_query):
         "Check if query contains explicit filters"
         # Extract explicit query portion with required, blocked words to filter from natural query
@@ -15,26 +50,24 @@ class ExplicitFilter:
         return len(required_words) != 0 or len(blocked_words) != 0
 
 
-    def filter(self, raw_query, entries, embeddings, entry_key='raw'):
+    def apply(self, raw_query, entries, embeddings):
         "Find entries containing required and not blocked words specified in query"
         # Separate natural query from explicit required, blocked words filters
+        start = time.time()
         query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
         required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
         blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+        end = time.time()
+        print(f"Time to extract required, blocked words: {end - start} seconds")
 
         if len(required_words) == 0 and len(blocked_words) == 0:
             return query, entries, embeddings
 
-        # convert each entry to a set of words
-        # split on fullstop, comma, colon, tab, newline or any brackets
-        entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
-        entries_by_word_set = [set(word.lower()
-                                   for word
-                                   in re.split(entry_splitter, entry[entry_key])
-                                   if word != "")
-                               for entry in entries]
+        # load or generate word set for each entry
+        entries_by_word_set = self.load(entries, regenerate=False)
 
         # track id of entries to exclude
+        start = time.time()
         entries_to_exclude = set()
 
         # mark entries that do not contain all required_words for exclusion
@@ -48,10 +81,15 @@ class ExplicitFilter:
         for id, words_in_entry in enumerate(entries_by_word_set):
             if words_in_entry.intersection(blocked_words):
                 entries_to_exclude.add(id)
+        end = time.time()
+        print(f"Mark entries to filter: {end - start} seconds")
 
         # delete entries (and their embeddings) marked for exclusion
+        start = time.time()
         for id in sorted(list(entries_to_exclude), reverse=True):
             del entries[id]
             embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
+        end = time.time()
+        print(f"Remove entries to filter from embeddings: {end - start} seconds")
 
         return query, entries, embeddings
diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py
index fe066033..4f83236e 100644
--- a/src/search_type/text_search.py
+++ b/src/search_type/text_search.py
@@ -8,11 +8,13 @@ from copy import deepcopy
 # External Packages
 import torch
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
+from src.search_filter.date_filter import DateFilter
+from src.search_filter.explicit_filter import ExplicitFilter
 
 # Internal Packages
 from src.utils import state
 from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
-from src.utils.config import TextSearchModel
+from src.utils.config import SearchType, TextSearchModel
 from src.utils.rawconfig import TextSearchConfig, TextContentConfig
 from src.utils.jsonl import load_jsonl
 
@@ -73,13 +75,13 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
     return corpus_embeddings
 
 
-def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
+def query(raw_query: str, model: TextSearchModel, rank_results=False):
     "Search for entries that answer the query"
     query = raw_query
 
     # Use deep copy of original embeddings, entries to filter if query contains filters
     start = time.time()
-    filters_in_query = [filter for filter in filters if filter.can_filter(query)]
+    filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
     if filters_in_query:
         corpus_embeddings = deepcopy(model.corpus_embeddings)
         entries = deepcopy(model.entries)
@@ -92,7 +94,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: l
     # Filter query, entries and embeddings before semantic search
     start = time.time()
     for filter in filters_in_query:
-        query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
+        query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
     end = time.time()
     logger.debug(f"Filter Time: {end - start:.3f} seconds")
 
@@ -163,7 +165,7 @@ def collate_results(hits, entries, count=5):
             in hits[0:count]]
 
 
-def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
+def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel:
     # Initialize Model
     bi_encoder, cross_encoder, top_k = initialize_model(search_config)
 
@@ -180,7 +182,12 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
     config.embeddings_file = resolve_absolute_path(config.embeddings_file)
     corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
 
-    return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
+    filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
+    filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
+    for filter in filters:
+        filter.load(entries, regenerate=regenerate)
+
+    return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
 
 
 if __name__ == '__main__':
diff --git a/src/utils/config.py b/src/utils/config.py
index 6e69d8b4..a4de6b81 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -20,11 +20,12 @@ class ProcessorType(str, Enum):
 
 
 class TextSearchModel():
-    def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
+    def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
         self.entries = entries
         self.corpus_embeddings = corpus_embeddings
         self.bi_encoder = bi_encoder
         self.cross_encoder = cross_encoder
+        self.filters = filters
         self.top_k = top_k
 
 