diff --git a/src/processor/ledger/beancount_to_jsonl.py b/src/processor/ledger/beancount_to_jsonl.py index 856dab57..74b1ed33 100644 --- a/src/processor/ledger/beancount_to_jsonl.py +++ b/src/processor/ledger/beancount_to_jsonl.py @@ -6,7 +6,7 @@ import time # Internal Packages from src.processor.text_to_jsonl import TextToJsonl -from src.utils.helpers import get_absolute_path, is_none_or_empty +from src.utils.helpers import get_absolute_path, is_none_or_empty, timer from src.utils.constants import empty_escape_sequences from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.rawconfig import Entry @@ -30,38 +30,30 @@ class BeancountToJsonl(TextToJsonl): beancount_files = BeancountToJsonl.get_beancount_files(beancount_files, beancount_file_filter) # Extract Entries from specified Beancount files - start = time.time() - current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files)) - end = time.time() - logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds") + with timer("Parse transactions from Beancount files into dictionaries", logger): + current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files)) # Split entries by max tokens supported by model - start = time.time() - current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) - end = time.time() - logger.debug(f"Split entries by max token size supported by model: {end - start} seconds") + with timer("Split entries by max token size supported by model", logger): + current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) # Identify, mark and merge any new entries with previous entries - start = time.time() - if not previous_entries: - entries_with_ids = list(enumerate(current_entries)) - else: - entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) - end = time.time() - logger.debug(f"Identify new or updated transaction: {end - start} seconds") + with timer("Identify new or updated transaction", logger): + if not previous_entries: + entries_with_ids = list(enumerate(current_entries)) + else: + entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) - # Process Each Entry from All Notes Files - start = time.time() - entries = list(map(lambda entry: entry[1], entries_with_ids)) - jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries) + with timer("Write transactions to JSONL file", logger): + # Process Each Entry from All Notes Files + entries = list(map(lambda entry: entry[1], entries_with_ids)) + jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries) - # Compress JSONL formatted Data - if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file) - elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file) - end = time.time() - logger.debug(f"Write transactions to JSONL file: {end - start} seconds") + # Compress JSONL formatted Data + if output_file.suffix == ".gz": + compress_jsonl_data(jsonl_data, output_file) + elif output_file.suffix == ".jsonl": + dump_jsonl(jsonl_data, output_file) return entries_with_ids diff --git a/src/processor/markdown/markdown_to_jsonl.py b/src/processor/markdown/markdown_to_jsonl.py index 82d860b8..189de84e 100644 --- a/src/processor/markdown/markdown_to_jsonl.py +++ b/src/processor/markdown/markdown_to_jsonl.py @@ -6,7 +6,7 @@ import time # Internal Packages from src.processor.text_to_jsonl import TextToJsonl -from src.utils.helpers import get_absolute_path, is_none_or_empty +from src.utils.helpers import get_absolute_path, is_none_or_empty, timer from src.utils.constants import empty_escape_sequences from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.rawconfig import Entry @@ -30,38 +30,30 @@ class MarkdownToJsonl(TextToJsonl): markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter) # Extract Entries from specified Markdown files - start = time.time() - current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files)) - end = time.time() - logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds") + with timer("Parse entries from Markdown files into dictionaries", logger): + current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files)) # Split entries by max tokens supported by model - start = time.time() - current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) - end = time.time() - logger.debug(f"Split entries by max token size supported by model: {end - start} seconds") + with timer("Split entries by max token size supported by model", logger): + current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) # Identify, mark and merge any new entries with previous entries - start = time.time() - if not previous_entries: - entries_with_ids = list(enumerate(current_entries)) - else: - entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) - end = time.time() - logger.debug(f"Identify new or updated entries: {end - start} seconds") + with timer("Identify new or updated entries", logger): + if not previous_entries: + entries_with_ids = list(enumerate(current_entries)) + else: + entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) - # Process Each Entry from All Notes Files - start = time.time() - entries = list(map(lambda entry: entry[1], entries_with_ids)) - jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries) + with timer("Write markdown entries to JSONL file", logger): + # Process Each Entry from All Notes Files + entries = list(map(lambda entry: entry[1], entries_with_ids)) + jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries) - # Compress JSONL formatted Data - if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file) - elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file) - end = time.time() - logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds") + # Compress JSONL formatted Data + if output_file.suffix == ".gz": + compress_jsonl_data(jsonl_data, output_file) + elif output_file.suffix == ".jsonl": + dump_jsonl(jsonl_data, output_file) return entries_with_ids diff --git a/src/processor/org_mode/org_to_jsonl.py b/src/processor/org_mode/org_to_jsonl.py index 5ad68b77..f2c301cd 100644 --- a/src/processor/org_mode/org_to_jsonl.py +++ b/src/processor/org_mode/org_to_jsonl.py @@ -7,7 +7,7 @@ from typing import Iterable # Internal Packages from src.processor.org_mode import orgnode from src.processor.text_to_jsonl import TextToJsonl -from src.utils.helpers import get_absolute_path, is_none_or_empty +from src.utils.helpers import get_absolute_path, is_none_or_empty, timer from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.rawconfig import Entry from src.utils import state @@ -29,24 +29,18 @@ class OrgToJsonl(TextToJsonl): exit(1) # Get Org Files to Process - start = time.time() - org_files = OrgToJsonl.get_org_files(org_files, org_file_filter) + with timer("Get org files to process", logger): + org_files = OrgToJsonl.get_org_files(org_files, org_file_filter) # Extract Entries from specified Org files - start = time.time() - entry_nodes, file_to_entries = self.extract_org_entries(org_files) - end = time.time() - logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds") + with timer("Parse entries from org files into OrgNode objects", logger): + entry_nodes, file_to_entries = self.extract_org_entries(org_files) - start = time.time() - current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries) - end = time.time() - logger.debug(f"Convert OrgNodes into list of entries: {end - start} seconds") + with timer("Convert OrgNodes into list of entries", logger): + current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries) - start = time.time() - current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) - end = time.time() - logger.debug(f"Split entries by max token size supported by model: {end - start} seconds") + with timer("Split entries by max token size supported by model", logger): + current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) # Identify, mark and merge any new entries with previous entries if not previous_entries: @@ -55,17 +49,15 @@ class OrgToJsonl(TextToJsonl): entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) # Process Each Entry from All Notes Files - start = time.time() - entries = map(lambda entry: entry[1], entries_with_ids) - jsonl_data = self.convert_org_entries_to_jsonl(entries) + with timer("Write org entries to JSONL file", logger): + entries = map(lambda entry: entry[1], entries_with_ids) + jsonl_data = self.convert_org_entries_to_jsonl(entries) - # Compress JSONL formatted Data - if output_file.suffix == ".gz": - compress_jsonl_data(jsonl_data, output_file) - elif output_file.suffix == ".jsonl": - dump_jsonl(jsonl_data, output_file) - end = time.time() - logger.debug(f"Write org entries to JSONL file: {end - start} seconds") + # Compress JSONL formatted Data + if output_file.suffix == ".gz": + compress_jsonl_data(jsonl_data, output_file) + elif output_file.suffix == ".jsonl": + dump_jsonl(jsonl_data, output_file) return entries_with_ids diff --git a/src/processor/text_to_jsonl.py b/src/processor/text_to_jsonl.py index 4d88a612..3c198784 100644 --- a/src/processor/text_to_jsonl.py +++ b/src/processor/text_to_jsonl.py @@ -4,6 +4,7 @@ import hashlib import time import logging from typing import Callable +from src.utils.helpers import timer # Internal Packages from src.utils.rawconfig import Entry, TextContentConfig @@ -40,35 +41,31 @@ class TextToJsonl(ABC): def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]: # Hash all current and previous entries to identify new entries - start = time.time() - current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries)) - previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries)) - end = time.time() - logger.debug(f"Hash previous, current entries: {end - start} seconds") + with timer("Hash previous, current entries", logger): + current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries)) + previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries)) - start = time.time() - hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) - hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries)) + with timer("Identify, Mark, Combine new, existing entries", logger): + hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) + hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries)) - # All entries that did not exist in the previous set are to be added - new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes) - # All entries that exist in both current and previous sets are kept - existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) + # All entries that did not exist in the previous set are to be added + new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes) + # All entries that exist in both current and previous sets are kept + existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) - # Mark new entries with -1 id to flag for later embeddings generation - new_entries = [ - (-1, hash_to_current_entries[entry_hash]) - for entry_hash in new_entry_hashes - ] - # Set id of existing entries to their previous ids to reuse their existing encoded embeddings - existing_entries = [ - (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash]) - for entry_hash in existing_entry_hashes - ] + # Mark new entries with -1 id to flag for later embeddings generation + new_entries = [ + (-1, hash_to_current_entries[entry_hash]) + for entry_hash in new_entry_hashes + ] + # Set id of existing entries to their previous ids to reuse their existing encoded embeddings + existing_entries = [ + (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash]) + for entry_hash in existing_entry_hashes + ] - existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) - entries_with_ids = existing_entries_sorted + new_entries - end = time.time() - logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds") + existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) + entries_with_ids = existing_entries_sorted + new_entries return entries_with_ids \ No newline at end of file diff --git a/src/routers/api.py b/src/routers/api.py index b78de882..0da30eb3 100644 --- a/src/routers/api.py +++ b/src/routers/api.py @@ -10,6 +10,7 @@ from fastapi import APIRouter # Internal Packages from src.configure import configure_processor, configure_search from src.search_type import image_search, text_search +from src.utils.helpers import timer from src.utils.rawconfig import FullConfig, SearchResponse from src.utils.config import SearchType from src.utils import state, constants @@ -47,7 +48,6 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti # initialize variables user_query = q.strip() results_count = n - query_start, query_end, collate_start, collate_end = None, None, None, None # return cached results, if available query_cache_key = f'{user_query}-{n}-{t}-{r}' @@ -57,73 +57,58 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti if (t == SearchType.Org or t == None) and state.model.orgmode_search: # query org-mode notes - query_start = time.time() - hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r) - query_end = time.time() + with timer("Query took", logger): + hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r) # collate and return results - collate_start = time.time() - results = text_search.collate_results(hits, entries, results_count) - collate_end = time.time() + with timer("Collating results took", logger): + results = text_search.collate_results(hits, entries, results_count) if (t == SearchType.Music or t == None) and state.model.music_search: # query music library - query_start = time.time() - hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r) - query_end = time.time() + with timer("Query took", logger): + hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r) # collate and return results - collate_start = time.time() - results = text_search.collate_results(hits, entries, results_count) - collate_end = time.time() + with timer("Collating results took", logger): + results = text_search.collate_results(hits, entries, results_count) if (t == SearchType.Markdown or t == None) and state.model.markdown_search: # query markdown files - query_start = time.time() - hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r) - query_end = time.time() + with timer("Query took", logger): + hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r) # collate and return results - collate_start = time.time() - results = text_search.collate_results(hits, entries, results_count) - collate_end = time.time() + with timer("Collating results took", logger): + results = text_search.collate_results(hits, entries, results_count) if (t == SearchType.Ledger or t == None) and state.model.ledger_search: # query transactions - query_start = time.time() - hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r) - query_end = time.time() + with timer("Query took", logger): + hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r) # collate and return results - collate_start = time.time() - results = text_search.collate_results(hits, entries, results_count) - collate_end = time.time() + with timer("Collating results took", logger): + results = text_search.collate_results(hits, entries, results_count) if (t == SearchType.Image or t == None) and state.model.image_search: # query images - query_start = time.time() - hits = image_search.query(user_query, results_count, state.model.image_search) - output_directory = constants.web_directory / 'images' - query_end = time.time() + with timer("Query took", logger): + hits = image_search.query(user_query, results_count, state.model.image_search) + output_directory = constants.web_directory / 'images' # collate and return results - collate_start = time.time() - results = image_search.collate_results( - hits, - image_names=state.model.image_search.image_names, - output_directory=output_directory, - image_files_url='/static/images', - count=results_count) - collate_end = time.time() + with timer("Collating results took", logger): + results = image_search.collate_results( + hits, + image_names=state.model.image_search.image_names, + output_directory=output_directory, + image_files_url='/static/images', + count=results_count) # Cache results state.query_cache[query_cache_key] = results - if query_start and query_end: - logger.debug(f"Query took {query_end - query_start:.3f} seconds") - if collate_start and collate_end: - logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds") - return results diff --git a/src/search_filter/date_filter.py b/src/search_filter/date_filter.py index 4f1242c9..bcacc190 100644 --- a/src/search_filter/date_filter.py +++ b/src/search_filter/date_filter.py @@ -12,7 +12,7 @@ import dateparser as dtparse # Internal Packages from src.search_filter.base_filter import BaseFilter -from src.utils.helpers import LRU +from src.utils.helpers import LRU, timer logger = logging.getLogger(__name__) @@ -34,19 +34,16 @@ class DateFilter(BaseFilter): def load(self, entries, *args, **kwargs): - start = time.time() - for id, entry in enumerate(entries): - # Extract dates from entry - for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)): - # Convert date string in entry to unix timestamp - try: - date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() - except ValueError: - continue - self.date_to_entry_ids[date_in_entry].add(id) - end = time.time() - logger.debug(f"Created date filter index: {end - start} seconds") - + with timer("Created date filter index", logger): + for id, entry in enumerate(entries): + # Extract dates from entry + for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)): + # Convert date string in entry to unix timestamp + try: + date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() + except ValueError: + continue + self.date_to_entry_ids[date_in_entry].add(id) def can_filter(self, raw_query): "Check if query contains date filters" @@ -56,10 +53,8 @@ class DateFilter(BaseFilter): def apply(self, query, entries): "Find entries containing any dates that fall within date range specified in query" # extract date range specified in date filter of query - start = time.time() - query_daterange = self.extract_date_range(query) - end = time.time() - logger.debug(f"Extract date range to filter from query: {end - start} seconds") + with timer("Extract date range to filter from query", logger): + query_daterange = self.extract_date_range(query) # if no date in query, return all entries if query_daterange is None: @@ -80,14 +75,12 @@ class DateFilter(BaseFilter): self.load(entries) # find entries containing any dates that fall with date range specified in query - start = time.time() - entries_to_include = set() - for date_in_entry in self.date_to_entry_ids.keys(): - # Check if date in entry is within date range specified in query - if query_daterange[0] <= date_in_entry < query_daterange[1]: - entries_to_include |= self.date_to_entry_ids[date_in_entry] - end = time.time() - logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + with timer("Mark entries satisfying filter", logger): + entries_to_include = set() + for date_in_entry in self.date_to_entry_ids.keys(): + # Check if date in entry is within date range specified in query + if query_daterange[0] <= date_in_entry < query_daterange[1]: + entries_to_include |= self.date_to_entry_ids[date_in_entry] # cache results self.cache[cache_key] = entries_to_include diff --git a/src/search_filter/file_filter.py b/src/search_filter/file_filter.py index 95635207..35fb078a 100644 --- a/src/search_filter/file_filter.py +++ b/src/search_filter/file_filter.py @@ -7,7 +7,7 @@ from collections import defaultdict # Internal Packages from src.search_filter.base_filter import BaseFilter -from src.utils.helpers import LRU +from src.utils.helpers import LRU, timer logger = logging.getLogger(__name__) @@ -22,32 +22,28 @@ class FileFilter(BaseFilter): self.cache = LRU() def load(self, entries, *args, **kwargs): - start = time.time() - for id, entry in enumerate(entries): - self.file_to_entry_map[getattr(entry, self.entry_key)].add(id) - end = time.time() - logger.debug(f"Created file filter index: {end - start} seconds") + with timer("Created file filter index", logger): + for id, entry in enumerate(entries): + self.file_to_entry_map[getattr(entry, self.entry_key)].add(id) def can_filter(self, raw_query): return re.search(self.file_filter_regex, raw_query) is not None def apply(self, query, entries): # Extract file filters from raw query - start = time.time() - raw_files_to_search = re.findall(self.file_filter_regex, query) - if not raw_files_to_search: - return query, set(range(len(entries))) + with timer("Extract files_to_search from query", logger): + raw_files_to_search = re.findall(self.file_filter_regex, query) + if not raw_files_to_search: + return query, set(range(len(entries))) - # Convert simple file filters with no path separator into regex - # e.g. "file:notes.org" -> "file:.*notes.org" - files_to_search = [] - for file in sorted(raw_files_to_search): - if '/' not in file and '\\' not in file and '*' not in file: - files_to_search += [f'*{file}'] - else: - files_to_search += [file] - end = time.time() - logger.debug(f"Extract files_to_search from query: {end - start} seconds") + # Convert simple file filters with no path separator into regex + # e.g. "file:notes.org" -> "file:.*notes.org" + files_to_search = [] + for file in sorted(raw_files_to_search): + if '/' not in file and '\\' not in file and '*' not in file: + files_to_search += [f'*{file}'] + else: + files_to_search += [file] # Return item from cache if exists query = re.sub(self.file_filter_regex, '', query).strip() @@ -61,17 +57,13 @@ class FileFilter(BaseFilter): self.load(entries, regenerate=False) # Mark entries that contain any blocked_words for exclusion - start = time.time() - - included_entry_indices = set.union(*[self.file_to_entry_map[entry_file] - for entry_file in self.file_to_entry_map.keys() - for search_file in files_to_search - if fnmatch.fnmatch(entry_file, search_file)], set()) - if not included_entry_indices: - return query, {} - - end = time.time() - logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + with timer("Mark entries satisfying filter", logger): + included_entry_indices = set.union(*[self.file_to_entry_map[entry_file] + for entry_file in self.file_to_entry_map.keys() + for search_file in files_to_search + if fnmatch.fnmatch(entry_file, search_file)], set()) + if not included_entry_indices: + return query, {} # Cache results self.cache[cache_key] = included_entry_indices diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index e1cfea6a..684847b3 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -6,7 +6,7 @@ from collections import defaultdict # Internal Packages from src.search_filter.base_filter import BaseFilter -from src.utils.helpers import LRU +from src.utils.helpers import LRU, timer logger = logging.getLogger(__name__) @@ -24,17 +24,15 @@ class WordFilter(BaseFilter): def load(self, entries, *args, **kwargs): - start = time.time() - self.cache = {} # Clear cache on filter (re-)load - entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' - # Create map of words to entries they exist in - for entry_index, entry in enumerate(entries): - for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()): - if word == '': - continue - self.word_to_entry_index[word].add(entry_index) - end = time.time() - logger.debug(f"Created word filter index: {end - start} seconds") + with timer("Created word filter index", logger): + self.cache = {} # Clear cache on filter (re-)load + entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' + # Create map of words to entries they exist in + for entry_index, entry in enumerate(entries): + for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()): + if word == '': + continue + self.word_to_entry_index[word].add(entry_index) return self.word_to_entry_index @@ -50,14 +48,10 @@ class WordFilter(BaseFilter): def apply(self, query, entries): "Find entries containing required and not blocked words specified in query" # Separate natural query from required, blocked words filters - start = time.time() - - required_words = set([word.lower() for word in re.findall(self.required_regex, query)]) - blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)]) - query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip() - - end = time.time() - logger.debug(f"Extract required, blocked filters from query: {end - start} seconds") + with timer("Extract required, blocked filters from query", logger): + required_words = set([word.lower() for word in re.findall(self.required_regex, query)]) + blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)]) + query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip() if len(required_words) == 0 and len(blocked_words) == 0: return query, set(range(len(entries))) @@ -72,20 +66,16 @@ class WordFilter(BaseFilter): if not self.word_to_entry_index: self.load(entries, regenerate=False) - start = time.time() - # mark entries that contain all required_words for inclusion - entries_with_all_required_words = set(range(len(entries))) - if len(required_words) > 0: - entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) + with timer("Mark entries satisfying filter", logger): + entries_with_all_required_words = set(range(len(entries))) + if len(required_words) > 0: + entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) - # mark entries that contain any blocked_words for exclusion - entries_with_any_blocked_words = set() - if len(blocked_words) > 0: - entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words]) - - end = time.time() - logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + # mark entries that contain any blocked_words for exclusion + entries_with_any_blocked_words = set() + if len(blocked_words) > 0: + entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words]) # get entries satisfying inclusion and exclusion filters included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index 5eecef31..08b4ec5b 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -13,7 +13,7 @@ from tqdm import trange import torch # Internal Packages -from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model +from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer from src.utils.config import ImageSearchModel from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse @@ -147,27 +147,21 @@ def query(raw_query, count, model: ImageSearchModel): logger.info(f"Find Images by Text: {query}") # Now we encode the query (which can either be an image or a text string) - start = time.time() - query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) - end = time.time() - logger.debug(f"Query Encode Time: {end - start:.3f} seconds") + with timer("Query Encode Time", logger): + query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. - start = time.time() - image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']} - for result - in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]} - end = time.time() - logger.debug(f"Search Time: {end - start:.3f} seconds") + with timer("Search Time", logger): + image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']} + for result + in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]} # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. if model.image_metadata_embeddings: - start = time.time() - metadata_hits = {result['corpus_id']: result['score'] - for result - in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]} - end = time.time() - logger.debug(f"Metadata Search Time: {end - start:.3f} seconds") + with timer("Metadata Search Time", logger): + metadata_hits = {result['corpus_id']: result['score'] + for result + in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]} # Sum metadata, image scores of the highest ranked images for corpus_id, score in metadata_hits.items(): diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 43d6d187..8a758a20 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -12,7 +12,7 @@ from src.search_filter.base_filter import BaseFilter # Internal Packages from src.utils import state -from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model +from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer from src.utils.config import TextSearchModel from src.utils.models import BaseEncoder from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry @@ -96,6 +96,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> # Filter query, entries and embeddings before semantic search query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters) + # If no entries left after filtering, return empty results if entries is None or len(entries) == 0: return [], [] @@ -105,17 +106,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> return hits, entries # Encode the query using the bi-encoder - start = time.time() - question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) - question_embedding = util.normalize_embeddings(question_embedding) - end = time.time() - logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}") + with timer("Query Encode Time", logger, state.device): + question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) + question_embedding = util.normalize_embeddings(question_embedding) # Find relevant entries for the query - start = time.time() - hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0] - end = time.time() - logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}") + with timer("Search Time", logger, state.device): + hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0] # Score all retrieved entries using the cross-encoder if rank_results: @@ -170,36 +167,29 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]: '''Filter query, entries and embeddings before semantic search''' - start_filter = time.time() - included_entry_indices = set(range(len(entries))) - filters_in_query = [filter for filter in filters if filter.can_filter(query)] - for filter in filters_in_query: - query, included_entry_indices_by_filter = filter.apply(query, entries) - included_entry_indices.intersection_update(included_entry_indices_by_filter) - # Get entries (and associated embeddings) satisfying all filters - if not included_entry_indices: - return '', [], torch.tensor([], device=state.device) - else: - start = time.time() - entries = [entries[id] for id in included_entry_indices] - corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)) - end = time.time() - logger.debug(f"Keep entries satisfying all filters: {end - start} seconds") + with timer("Total Filter Time", logger, state.device): + included_entry_indices = set(range(len(entries))) + filters_in_query = [filter for filter in filters if filter.can_filter(query)] + for filter in filters_in_query: + query, included_entry_indices_by_filter = filter.apply(query, entries) + included_entry_indices.intersection_update(included_entry_indices_by_filter) - end_filter = time.time() - logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}") + # Get entries (and associated embeddings) satisfying all filters + if not included_entry_indices: + return '', [], torch.tensor([], device=state.device) + else: + entries = [entries[id] for id in included_entry_indices] + corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device)) return query, entries, corpus_embeddings def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]: '''Score all retrieved entries using the cross-encoder''' - start = time.time() - cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits] - cross_scores = cross_encoder.predict(cross_inp) - end = time.time() - logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}") + with timer("Cross-Encoder Predict Time", logger, state.device): + cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits] + cross_scores = cross_encoder.predict(cross_inp) # Store cross-encoder scores in results dictionary for ranking for idx in range(len(cross_scores)): @@ -210,12 +200,10 @@ def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[E def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]: '''Order results by cross-encoder score followed by bi-encoder score''' - start = time.time() - hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score - if rank_results: - hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score - end = time.time() - logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}") + with timer("Rank Time", logger, state.device): + hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score + if rank_results: + hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score return hits @@ -223,11 +211,12 @@ def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]: '''Deduplicate entries by raw entry text before showing to users Compiled entries are split by max tokens supported by ML models. This can result in duplicate hits, entries shown to user.''' - start = time.time() - seen, original_hits_count = set(), len(hits) - hits = [hit for hit in hits - if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)] - duplicate_hits = original_hits_count - len(hits) - end = time.time() - logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates") + + with timer("Deduplication Time", logger, state.device): + seen, original_hits_count = set(), len(hits) + hits = [hit for hit in hits + if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)] + duplicate_hits = original_hits_count - len(hits) + + logger.debug(f"Removed {duplicate_hits} duplicates") return hits diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 1bac6e81..5ecbfef7 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -2,10 +2,12 @@ from __future__ import annotations # to avoid quoting type hints import logging import sys +import torch from collections import OrderedDict from importlib import import_module from os.path import join from pathlib import Path +from time import perf_counter from typing import Optional, Union, TYPE_CHECKING if TYPE_CHECKING: @@ -81,6 +83,25 @@ def get_class_by_name(name: str) -> object: return getattr(import_module(module_name), class_name) +class timer: + '''Context manager to log time taken for a block of code to run''' + def __init__(self, message: str, logger: logging.Logger, device: torch.device = None): + self.message = message + self.logger = logger + self.device = device + + def __enter__(self): + self.start = perf_counter() + return self + + def __exit__(self, *_): + elapsed = perf_counter() - self.start + if self.device is None: + self.logger.debug(f"{self.message}: {elapsed:.3f} seconds") + else: + self.logger.debug(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}") + + class LRU(OrderedDict): def __init__(self, *args, capacity=128, **kwargs): self.capacity = capacity