Create and use a context manager to time code

Use the timer context manager in all places where code was being timed

- Benefits
  - Deduplicates timing code scattered across the codebase.
  - Provides a single place to manage perf timing code.
  - Uses consistent timing log patterns.
This commit is contained in:
Debanjum Singh Solanky
2023-01-09 19:43:19 -03:00
parent 93f39dbd43
commit aa22d83172
11 changed files with 235 additions and 298 deletions

View File

@@ -6,7 +6,7 @@ import time
# Internal Packages # Internal Packages
from src.processor.text_to_jsonl import TextToJsonl from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
from src.utils.constants import empty_escape_sequences from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import Entry from src.utils.rawconfig import Entry
@@ -30,38 +30,30 @@ class BeancountToJsonl(TextToJsonl):
beancount_files = BeancountToJsonl.get_beancount_files(beancount_files, beancount_file_filter) beancount_files = BeancountToJsonl.get_beancount_files(beancount_files, beancount_file_filter)
# Extract Entries from specified Beancount files # Extract Entries from specified Beancount files
start = time.time() with timer("Parse transactions from Beancount files into dictionaries", logger):
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files)) current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
end = time.time()
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
# Split entries by max tokens supported by model # Split entries by max tokens supported by model
start = time.time() with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
end = time.time()
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
start = time.time() with timer("Identify new or updated transaction", logger):
if not previous_entries: if not previous_entries:
entries_with_ids = list(enumerate(current_entries)) entries_with_ids = list(enumerate(current_entries))
else: else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time()
logger.debug(f"Identify new or updated transaction: {end - start} seconds")
# Process Each Entry from All Notes Files with timer("Write transactions to JSONL file", logger):
start = time.time() # Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids)) entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries) jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries)
# Compress JSONL formatted Data # Compress JSONL formatted Data
if output_file.suffix == ".gz": if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file) compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl": elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file) dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write transactions to JSONL file: {end - start} seconds")
return entries_with_ids return entries_with_ids

View File

@@ -6,7 +6,7 @@ import time
# Internal Packages # Internal Packages
from src.processor.text_to_jsonl import TextToJsonl from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
from src.utils.constants import empty_escape_sequences from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import Entry from src.utils.rawconfig import Entry
@@ -30,38 +30,30 @@ class MarkdownToJsonl(TextToJsonl):
markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter) markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter)
# Extract Entries from specified Markdown files # Extract Entries from specified Markdown files
start = time.time() with timer("Parse entries from Markdown files into dictionaries", logger):
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files)) current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
end = time.time()
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
# Split entries by max tokens supported by model # Split entries by max tokens supported by model
start = time.time() with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
end = time.time()
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
start = time.time() with timer("Identify new or updated entries", logger):
if not previous_entries: if not previous_entries:
entries_with_ids = list(enumerate(current_entries)) entries_with_ids = list(enumerate(current_entries))
else: else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time()
logger.debug(f"Identify new or updated entries: {end - start} seconds")
# Process Each Entry from All Notes Files with timer("Write markdown entries to JSONL file", logger):
start = time.time() # Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids)) entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries) jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
# Compress JSONL formatted Data # Compress JSONL formatted Data
if output_file.suffix == ".gz": if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file) compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl": elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file) dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds")
return entries_with_ids return entries_with_ids

View File

@@ -7,7 +7,7 @@ from typing import Iterable
# Internal Packages # Internal Packages
from src.processor.org_mode import orgnode from src.processor.org_mode import orgnode
from src.processor.text_to_jsonl import TextToJsonl from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
from src.utils.jsonl import dump_jsonl, compress_jsonl_data from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import Entry from src.utils.rawconfig import Entry
from src.utils import state from src.utils import state
@@ -29,24 +29,18 @@ class OrgToJsonl(TextToJsonl):
exit(1) exit(1)
# Get Org Files to Process # Get Org Files to Process
start = time.time() with timer("Get org files to process", logger):
org_files = OrgToJsonl.get_org_files(org_files, org_file_filter) org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)
# Extract Entries from specified Org files # Extract Entries from specified Org files
start = time.time() with timer("Parse entries from org files into OrgNode objects", logger):
entry_nodes, file_to_entries = self.extract_org_entries(org_files) entry_nodes, file_to_entries = self.extract_org_entries(org_files)
end = time.time()
logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds")
start = time.time() with timer("Convert OrgNodes into list of entries", logger):
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries) current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
end = time.time()
logger.debug(f"Convert OrgNodes into list of entries: {end - start} seconds")
start = time.time() with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256) current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
end = time.time()
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
# Identify, mark and merge any new entries with previous entries # Identify, mark and merge any new entries with previous entries
if not previous_entries: if not previous_entries:
@@ -55,17 +49,15 @@ class OrgToJsonl(TextToJsonl):
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger) entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
# Process Each Entry from All Notes Files # Process Each Entry from All Notes Files
start = time.time() with timer("Write org entries to JSONL file", logger):
entries = map(lambda entry: entry[1], entries_with_ids) entries = map(lambda entry: entry[1], entries_with_ids)
jsonl_data = self.convert_org_entries_to_jsonl(entries) jsonl_data = self.convert_org_entries_to_jsonl(entries)
# Compress JSONL formatted Data # Compress JSONL formatted Data
if output_file.suffix == ".gz": if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file) compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl": elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file) dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write org entries to JSONL file: {end - start} seconds")
return entries_with_ids return entries_with_ids

View File

@@ -4,6 +4,7 @@ import hashlib
import time import time
import logging import logging
from typing import Callable from typing import Callable
from src.utils.helpers import timer
# Internal Packages # Internal Packages
from src.utils.rawconfig import Entry, TextContentConfig from src.utils.rawconfig import Entry, TextContentConfig
@@ -40,35 +41,31 @@ class TextToJsonl(ABC):
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]: def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
# Hash all current and previous entries to identify new entries # Hash all current and previous entries to identify new entries
start = time.time() with timer("Hash previous, current entries", logger):
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries)) current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries)) previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
end = time.time()
logger.debug(f"Hash previous, current entries: {end - start} seconds")
start = time.time() with timer("Identify, Mark, Combine new, existing entries", logger):
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries)) hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries)) hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
# All entries that did not exist in the previous set are to be added # All entries that did not exist in the previous set are to be added
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes) new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
# All entries that exist in both current and previous sets are kept # All entries that exist in both current and previous sets are kept
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes) existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# Mark new entries with -1 id to flag for later embeddings generation # Mark new entries with -1 id to flag for later embeddings generation
new_entries = [ new_entries = [
(-1, hash_to_current_entries[entry_hash]) (-1, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes for entry_hash in new_entry_hashes
] ]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings # Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [ existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash]) (previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
for entry_hash in existing_entry_hashes for entry_hash in existing_entry_hashes
] ]
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0]) existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries entries_with_ids = existing_entries_sorted + new_entries
end = time.time()
logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
return entries_with_ids return entries_with_ids

View File

@@ -10,6 +10,7 @@ from fastapi import APIRouter
# Internal Packages # Internal Packages
from src.configure import configure_processor, configure_search from src.configure import configure_processor, configure_search
from src.search_type import image_search, text_search from src.search_type import image_search, text_search
from src.utils.helpers import timer
from src.utils.rawconfig import FullConfig, SearchResponse from src.utils.rawconfig import FullConfig, SearchResponse
from src.utils.config import SearchType from src.utils.config import SearchType
from src.utils import state, constants from src.utils import state, constants
@@ -47,7 +48,6 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
# initialize variables # initialize variables
user_query = q.strip() user_query = q.strip()
results_count = n results_count = n
query_start, query_end, collate_start, collate_end = None, None, None, None
# return cached results, if available # return cached results, if available
query_cache_key = f'{user_query}-{n}-{t}-{r}' query_cache_key = f'{user_query}-{n}-{t}-{r}'
@@ -57,73 +57,58 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
if (t == SearchType.Org or t == None) and state.model.orgmode_search: if (t == SearchType.Org or t == None) and state.model.orgmode_search:
# query org-mode notes # query org-mode notes
query_start = time.time() with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r) hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
query_end = time.time()
# collate and return results # collate and return results
collate_start = time.time() with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count) results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Music or t == None) and state.model.music_search: if (t == SearchType.Music or t == None) and state.model.music_search:
# query music library # query music library
query_start = time.time() with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r) hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
query_end = time.time()
# collate and return results # collate and return results
collate_start = time.time() with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count) results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Markdown or t == None) and state.model.markdown_search: if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
# query markdown files # query markdown files
query_start = time.time() with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r) hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
query_end = time.time()
# collate and return results # collate and return results
collate_start = time.time() with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count) results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Ledger or t == None) and state.model.ledger_search: if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
# query transactions # query transactions
query_start = time.time() with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r) hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
query_end = time.time()
# collate and return results # collate and return results
collate_start = time.time() with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count) results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
if (t == SearchType.Image or t == None) and state.model.image_search: if (t == SearchType.Image or t == None) and state.model.image_search:
# query images # query images
query_start = time.time() with timer("Query took", logger):
hits = image_search.query(user_query, results_count, state.model.image_search) hits = image_search.query(user_query, results_count, state.model.image_search)
output_directory = constants.web_directory / 'images' output_directory = constants.web_directory / 'images'
query_end = time.time()
# collate and return results # collate and return results
collate_start = time.time() with timer("Collating results took", logger):
results = image_search.collate_results( results = image_search.collate_results(
hits, hits,
image_names=state.model.image_search.image_names, image_names=state.model.image_search.image_names,
output_directory=output_directory, output_directory=output_directory,
image_files_url='/static/images', image_files_url='/static/images',
count=results_count) count=results_count)
collate_end = time.time()
# Cache results # Cache results
state.query_cache[query_cache_key] = results state.query_cache[query_cache_key] = results
if query_start and query_end:
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
if collate_start and collate_end:
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
return results return results

View File

@@ -12,7 +12,7 @@ import dateparser as dtparse
# Internal Packages # Internal Packages
from src.search_filter.base_filter import BaseFilter from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU from src.utils.helpers import LRU, timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -34,19 +34,16 @@ class DateFilter(BaseFilter):
def load(self, entries, *args, **kwargs): def load(self, entries, *args, **kwargs):
start = time.time() with timer("Created date filter index", logger):
for id, entry in enumerate(entries): for id, entry in enumerate(entries):
# Extract dates from entry # Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)): for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
# Convert date string in entry to unix timestamp # Convert date string in entry to unix timestamp
try: try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp() date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
except ValueError: except ValueError:
continue continue
self.date_to_entry_ids[date_in_entry].add(id) self.date_to_entry_ids[date_in_entry].add(id)
end = time.time()
logger.debug(f"Created date filter index: {end - start} seconds")
def can_filter(self, raw_query): def can_filter(self, raw_query):
"Check if query contains date filters" "Check if query contains date filters"
@@ -56,10 +53,8 @@ class DateFilter(BaseFilter):
def apply(self, query, entries): def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query" "Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query # extract date range specified in date filter of query
start = time.time() with timer("Extract date range to filter from query", logger):
query_daterange = self.extract_date_range(query) query_daterange = self.extract_date_range(query)
end = time.time()
logger.debug(f"Extract date range to filter from query: {end - start} seconds")
# if no date in query, return all entries # if no date in query, return all entries
if query_daterange is None: if query_daterange is None:
@@ -80,14 +75,12 @@ class DateFilter(BaseFilter):
self.load(entries) self.load(entries)
# find entries containing any dates that fall with date range specified in query # find entries containing any dates that fall with date range specified in query
start = time.time() with timer("Mark entries satisfying filter", logger):
entries_to_include = set() entries_to_include = set()
for date_in_entry in self.date_to_entry_ids.keys(): for date_in_entry in self.date_to_entry_ids.keys():
# Check if date in entry is within date range specified in query # Check if date in entry is within date range specified in query
if query_daterange[0] <= date_in_entry < query_daterange[1]: if query_daterange[0] <= date_in_entry < query_daterange[1]:
entries_to_include |= self.date_to_entry_ids[date_in_entry] entries_to_include |= self.date_to_entry_ids[date_in_entry]
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# cache results # cache results
self.cache[cache_key] = entries_to_include self.cache[cache_key] = entries_to_include

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
# Internal Packages # Internal Packages
from src.search_filter.base_filter import BaseFilter from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU from src.utils.helpers import LRU, timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,32 +22,28 @@ class FileFilter(BaseFilter):
self.cache = LRU() self.cache = LRU()
def load(self, entries, *args, **kwargs): def load(self, entries, *args, **kwargs):
start = time.time() with timer("Created file filter index", logger):
for id, entry in enumerate(entries): for id, entry in enumerate(entries):
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id) self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
end = time.time()
logger.debug(f"Created file filter index: {end - start} seconds")
def can_filter(self, raw_query): def can_filter(self, raw_query):
return re.search(self.file_filter_regex, raw_query) is not None return re.search(self.file_filter_regex, raw_query) is not None
def apply(self, query, entries): def apply(self, query, entries):
# Extract file filters from raw query # Extract file filters from raw query
start = time.time() with timer("Extract files_to_search from query", logger):
raw_files_to_search = re.findall(self.file_filter_regex, query) raw_files_to_search = re.findall(self.file_filter_regex, query)
if not raw_files_to_search: if not raw_files_to_search:
return query, set(range(len(entries))) return query, set(range(len(entries)))
# Convert simple file filters with no path separator into regex # Convert simple file filters with no path separator into regex
# e.g. "file:notes.org" -> "file:.*notes.org" # e.g. "file:notes.org" -> "file:.*notes.org"
files_to_search = [] files_to_search = []
for file in sorted(raw_files_to_search): for file in sorted(raw_files_to_search):
if '/' not in file and '\\' not in file and '*' not in file: if '/' not in file and '\\' not in file and '*' not in file:
files_to_search += [f'*{file}'] files_to_search += [f'*{file}']
else: else:
files_to_search += [file] files_to_search += [file]
end = time.time()
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
# Return item from cache if exists # Return item from cache if exists
query = re.sub(self.file_filter_regex, '', query).strip() query = re.sub(self.file_filter_regex, '', query).strip()
@@ -61,17 +57,13 @@ class FileFilter(BaseFilter):
self.load(entries, regenerate=False) self.load(entries, regenerate=False)
# Mark entries that contain any blocked_words for exclusion # Mark entries that contain any blocked_words for exclusion
start = time.time() with timer("Mark entries satisfying filter", logger):
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file] for entry_file in self.file_to_entry_map.keys()
for entry_file in self.file_to_entry_map.keys() for search_file in files_to_search
for search_file in files_to_search if fnmatch.fnmatch(entry_file, search_file)], set())
if fnmatch.fnmatch(entry_file, search_file)], set()) if not included_entry_indices:
if not included_entry_indices: return query, {}
return query, {}
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# Cache results # Cache results
self.cache[cache_key] = included_entry_indices self.cache[cache_key] = included_entry_indices

View File

@@ -6,7 +6,7 @@ from collections import defaultdict
# Internal Packages # Internal Packages
from src.search_filter.base_filter import BaseFilter from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU from src.utils.helpers import LRU, timer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -24,17 +24,15 @@ class WordFilter(BaseFilter):
def load(self, entries, *args, **kwargs): def load(self, entries, *args, **kwargs):
start = time.time() with timer("Created word filter index", logger):
self.cache = {} # Clear cache on filter (re-)load self.cache = {} # Clear cache on filter (re-)load
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
# Create map of words to entries they exist in # Create map of words to entries they exist in
for entry_index, entry in enumerate(entries): for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()): for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
if word == '': if word == '':
continue continue
self.word_to_entry_index[word].add(entry_index) self.word_to_entry_index[word].add(entry_index)
end = time.time()
logger.debug(f"Created word filter index: {end - start} seconds")
return self.word_to_entry_index return self.word_to_entry_index
@@ -50,14 +48,10 @@ class WordFilter(BaseFilter):
def apply(self, query, entries): def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query" "Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters # Separate natural query from required, blocked words filters
start = time.time() with timer("Extract required, blocked filters from query", logger):
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
required_words = set([word.lower() for word in re.findall(self.required_regex, query)]) blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)]) query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
end = time.time()
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
if len(required_words) == 0 and len(blocked_words) == 0: if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(entries))) return query, set(range(len(entries)))
@@ -72,20 +66,16 @@ class WordFilter(BaseFilter):
if not self.word_to_entry_index: if not self.word_to_entry_index:
self.load(entries, regenerate=False) self.load(entries, regenerate=False)
start = time.time()
# mark entries that contain all required_words for inclusion # mark entries that contain all required_words for inclusion
entries_with_all_required_words = set(range(len(entries))) with timer("Mark entries satisfying filter", logger):
if len(required_words) > 0: entries_with_all_required_words = set(range(len(entries)))
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) if len(required_words) > 0:
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
# mark entries that contain any blocked_words for exclusion # mark entries that contain any blocked_words for exclusion
entries_with_any_blocked_words = set() entries_with_any_blocked_words = set()
if len(blocked_words) > 0: if len(blocked_words) > 0:
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words]) entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# get entries satisfying inclusion and exclusion filters # get entries satisfying inclusion and exclusion filters
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words

View File

@@ -13,7 +13,7 @@ from tqdm import trange
import torch import torch
# Internal Packages # Internal Packages
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
from src.utils.config import ImageSearchModel from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
@@ -147,27 +147,21 @@ def query(raw_query, count, model: ImageSearchModel):
logger.info(f"Find Images by Text: {query}") logger.info(f"Find Images by Text: {query}")
# Now we encode the query (which can either be an image or a text string) # Now we encode the query (which can either be an image or a text string)
start = time.time() with timer("Query Encode Time", logger):
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False) query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
end = time.time()
logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
start = time.time() with timer("Search Time", logger):
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']} image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
for result for result
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]} in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
end = time.time()
logger.debug(f"Search Time: {end - start:.3f} seconds")
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings: if model.image_metadata_embeddings:
start = time.time() with timer("Metadata Search Time", logger):
metadata_hits = {result['corpus_id']: result['score'] metadata_hits = {result['corpus_id']: result['score']
for result for result
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]} in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
end = time.time()
logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
# Sum metadata, image scores of the highest ranked images # Sum metadata, image scores of the highest ranked images
for corpus_id, score in metadata_hits.items(): for corpus_id, score in metadata_hits.items():

View File

@@ -12,7 +12,7 @@ from src.search_filter.base_filter import BaseFilter
# Internal Packages # Internal Packages
from src.utils import state from src.utils import state
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
from src.utils.config import TextSearchModel from src.utils.config import TextSearchModel
from src.utils.models import BaseEncoder from src.utils.models import BaseEncoder
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
@@ -96,6 +96,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
# Filter query, entries and embeddings before semantic search # Filter query, entries and embeddings before semantic search
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters) query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
# If no entries left after filtering, return empty results # If no entries left after filtering, return empty results
if entries is None or len(entries) == 0: if entries is None or len(entries) == 0:
return [], [] return [], []
@@ -105,17 +106,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
return hits, entries return hits, entries
# Encode the query using the bi-encoder # Encode the query using the bi-encoder
start = time.time() with timer("Query Encode Time", logger, state.device):
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device) question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding) question_embedding = util.normalize_embeddings(question_embedding)
end = time.time()
logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
# Find relevant entries for the query # Find relevant entries for the query
start = time.time() with timer("Search Time", logger, state.device):
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0] hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
end = time.time()
logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
if rank_results: if rank_results:
@@ -170,36 +167,29 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]: def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
'''Filter query, entries and embeddings before semantic search''' '''Filter query, entries and embeddings before semantic search'''
start_filter = time.time()
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
for filter in filters_in_query:
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters with timer("Total Filter Time", logger, state.device):
if not included_entry_indices: included_entry_indices = set(range(len(entries)))
return '', [], torch.tensor([], device=state.device) filters_in_query = [filter for filter in filters if filter.can_filter(query)]
else: for filter in filters_in_query:
start = time.time() query, included_entry_indices_by_filter = filter.apply(query, entries)
entries = [entries[id] for id in included_entry_indices] included_entry_indices.intersection_update(included_entry_indices_by_filter)
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
end = time.time()
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
end_filter = time.time() # Get entries (and associated embeddings) satisfying all filters
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}") if not included_entry_indices:
return '', [], torch.tensor([], device=state.device)
else:
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
return query, entries, corpus_embeddings return query, entries, corpus_embeddings
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]: def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
'''Score all retrieved entries using the cross-encoder''' '''Score all retrieved entries using the cross-encoder'''
start = time.time() with timer("Cross-Encoder Predict Time", logger, state.device):
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits] cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp) cross_scores = cross_encoder.predict(cross_inp)
end = time.time()
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
# Store cross-encoder scores in results dictionary for ranking # Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)): for idx in range(len(cross_scores)):
@@ -210,12 +200,10 @@ def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[E
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]: def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
'''Order results by cross-encoder score followed by bi-encoder score''' '''Order results by cross-encoder score followed by bi-encoder score'''
start = time.time() with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
if rank_results: if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time()
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
return hits return hits
@@ -223,11 +211,12 @@ def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
'''Deduplicate entries by raw entry text before showing to users '''Deduplicate entries by raw entry text before showing to users
Compiled entries are split by max tokens supported by ML models. Compiled entries are split by max tokens supported by ML models.
This can result in duplicate hits, entries shown to user.''' This can result in duplicate hits, entries shown to user.'''
start = time.time()
seen, original_hits_count = set(), len(hits) with timer("Deduplication Time", logger, state.device):
hits = [hit for hit in hits seen, original_hits_count = set(), len(hits)
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)] hits = [hit for hit in hits
duplicate_hits = original_hits_count - len(hits) if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
end = time.time() duplicate_hits = original_hits_count - len(hits)
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
logger.debug(f"Removed {duplicate_hits} duplicates")
return hits return hits

View File

@@ -2,10 +2,12 @@
from __future__ import annotations # to avoid quoting type hints from __future__ import annotations # to avoid quoting type hints
import logging import logging
import sys import sys
import torch
from collections import OrderedDict from collections import OrderedDict
from importlib import import_module from importlib import import_module
from os.path import join from os.path import join
from pathlib import Path from pathlib import Path
from time import perf_counter
from typing import Optional, Union, TYPE_CHECKING from typing import Optional, Union, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -81,6 +83,25 @@ def get_class_by_name(name: str) -> object:
return getattr(import_module(module_name), class_name) return getattr(import_module(module_name), class_name)
class timer:
'''Context manager to log time taken for a block of code to run'''
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
self.message = message
self.logger = logger
self.device = device
def __enter__(self):
self.start = perf_counter()
return self
def __exit__(self, *_):
elapsed = perf_counter() - self.start
if self.device is None:
self.logger.debug(f"{self.message}: {elapsed:.3f} seconds")
else:
self.logger.debug(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
class LRU(OrderedDict): class LRU(OrderedDict):
def __init__(self, *args, capacity=128, **kwargs): def __init__(self, *args, capacity=128, **kwargs):
self.capacity = capacity self.capacity = capacity