Create and use a context manager to time code

Use the timer context manager in all places where code was being timed

- Benefits
  - Deduplicates timing code scattered across the codebase.
  - Provides a single place to manage perf timing code.
  - Uses a consistent timing log pattern.
This commit is contained in:
Debanjum Singh Solanky
2023-01-09 19:43:19 -03:00
parent 93f39dbd43
commit aa22d83172
11 changed files with 235 additions and 298 deletions

View File

@@ -6,7 +6,7 @@ import time
# Internal Packages
from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import Entry
@@ -30,38 +30,30 @@ class BeancountToJsonl(TextToJsonl):
beancount_files = BeancountToJsonl.get_beancount_files(beancount_files, beancount_file_filter)
# Extract Entries from specified Beancount files
start = time.time()
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
end = time.time()
logger.debug(f"Parse transactions from Beancount files into dictionaries: {end - start} seconds")
with timer("Parse transactions from Beancount files into dictionaries", logger):
current_entries = BeancountToJsonl.convert_transactions_to_maps(*BeancountToJsonl.extract_beancount_transactions(beancount_files))
# Split entries by max tokens supported by model
start = time.time()
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
end = time.time()
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
# Identify, mark and merge any new entries with previous entries
start = time.time()
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time()
logger.debug(f"Identify new or updated transaction: {end - start} seconds")
with timer("Identify new or updated transaction", logger):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
# Process Each Entry from All Notes Files
start = time.time()
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries)
with timer("Write transactions to JSONL file", logger):
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = BeancountToJsonl.convert_transaction_maps_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write transactions to JSONL file: {end - start} seconds")
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
return entries_with_ids

View File

@@ -6,7 +6,7 @@ import time
# Internal Packages
from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
from src.utils.constants import empty_escape_sequences
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import Entry
@@ -30,38 +30,30 @@ class MarkdownToJsonl(TextToJsonl):
markdown_files = MarkdownToJsonl.get_markdown_files(markdown_files, markdown_file_filter)
# Extract Entries from specified Markdown files
start = time.time()
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
end = time.time()
logger.debug(f"Parse entries from Markdown files into dictionaries: {end - start} seconds")
with timer("Parse entries from Markdown files into dictionaries", logger):
current_entries = MarkdownToJsonl.convert_markdown_entries_to_maps(*MarkdownToJsonl.extract_markdown_entries(markdown_files))
# Split entries by max tokens supported by model
start = time.time()
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
end = time.time()
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
# Identify, mark and merge any new entries with previous entries
start = time.time()
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
end = time.time()
logger.debug(f"Identify new or updated entries: {end - start} seconds")
with timer("Identify new or updated entries", logger):
if not previous_entries:
entries_with_ids = list(enumerate(current_entries))
else:
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
# Process Each Entry from All Notes Files
start = time.time()
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
with timer("Write markdown entries to JSONL file", logger):
# Process Each Entry from All Notes Files
entries = list(map(lambda entry: entry[1], entries_with_ids))
jsonl_data = MarkdownToJsonl.convert_markdown_maps_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write markdown entries to JSONL file: {end - start} seconds")
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
return entries_with_ids

View File

@@ -7,7 +7,7 @@ from typing import Iterable
# Internal Packages
from src.processor.org_mode import orgnode
from src.processor.text_to_jsonl import TextToJsonl
from src.utils.helpers import get_absolute_path, is_none_or_empty
from src.utils.helpers import get_absolute_path, is_none_or_empty, timer
from src.utils.jsonl import dump_jsonl, compress_jsonl_data
from src.utils.rawconfig import Entry
from src.utils import state
@@ -29,24 +29,18 @@ class OrgToJsonl(TextToJsonl):
exit(1)
# Get Org Files to Process
start = time.time()
org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)
with timer("Get org files to process", logger):
org_files = OrgToJsonl.get_org_files(org_files, org_file_filter)
# Extract Entries from specified Org files
start = time.time()
entry_nodes, file_to_entries = self.extract_org_entries(org_files)
end = time.time()
logger.debug(f"Parse entries from org files into OrgNode objects: {end - start} seconds")
with timer("Parse entries from org files into OrgNode objects", logger):
entry_nodes, file_to_entries = self.extract_org_entries(org_files)
start = time.time()
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
end = time.time()
logger.debug(f"Convert OrgNodes into list of entries: {end - start} seconds")
with timer("Convert OrgNodes into list of entries", logger):
current_entries = self.convert_org_nodes_to_entries(entry_nodes, file_to_entries, index_heading_entries)
start = time.time()
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
end = time.time()
logger.debug(f"Split entries by max token size supported by model: {end - start} seconds")
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
# Identify, mark and merge any new entries with previous entries
if not previous_entries:
@@ -55,17 +49,15 @@ class OrgToJsonl(TextToJsonl):
entries_with_ids = self.mark_entries_for_update(current_entries, previous_entries, key='compiled', logger=logger)
# Process Each Entry from All Notes Files
start = time.time()
entries = map(lambda entry: entry[1], entries_with_ids)
jsonl_data = self.convert_org_entries_to_jsonl(entries)
with timer("Write org entries to JSONL file", logger):
entries = map(lambda entry: entry[1], entries_with_ids)
jsonl_data = self.convert_org_entries_to_jsonl(entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
end = time.time()
logger.debug(f"Write org entries to JSONL file: {end - start} seconds")
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
compress_jsonl_data(jsonl_data, output_file)
elif output_file.suffix == ".jsonl":
dump_jsonl(jsonl_data, output_file)
return entries_with_ids

View File

@@ -4,6 +4,7 @@ import hashlib
import time
import logging
from typing import Callable
from src.utils.helpers import timer
# Internal Packages
from src.utils.rawconfig import Entry, TextContentConfig
@@ -40,35 +41,31 @@ class TextToJsonl(ABC):
def mark_entries_for_update(self, current_entries: list[Entry], previous_entries: list[Entry], key='compiled', logger=None) -> list[tuple[int, Entry]]:
# Hash all current and previous entries to identify new entries
start = time.time()
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
end = time.time()
logger.debug(f"Hash previous, current entries: {end - start} seconds")
with timer("Hash previous, current entries", logger):
current_entry_hashes = list(map(TextToJsonl.hash_func(key), current_entries))
previous_entry_hashes = list(map(TextToJsonl.hash_func(key), previous_entries))
start = time.time()
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
with timer("Identify, Mark, Combine new, existing entries", logger):
hash_to_current_entries = dict(zip(current_entry_hashes, current_entries))
hash_to_previous_entries = dict(zip(previous_entry_hashes, previous_entries))
# All entries that did not exist in the previous set are to be added
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
# All entries that exist in both current and previous sets are kept
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# All entries that did not exist in the previous set are to be added
new_entry_hashes = set(current_entry_hashes) - set(previous_entry_hashes)
# All entries that exist in both current and previous sets are kept
existing_entry_hashes = set(current_entry_hashes) & set(previous_entry_hashes)
# Mark new entries with -1 id to flag for later embeddings generation
new_entries = [
(-1, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
for entry_hash in existing_entry_hashes
]
# Mark new entries with -1 id to flag for later embeddings generation
new_entries = [
(-1, hash_to_current_entries[entry_hash])
for entry_hash in new_entry_hashes
]
# Set id of existing entries to their previous ids to reuse their existing encoded embeddings
existing_entries = [
(previous_entry_hashes.index(entry_hash), hash_to_previous_entries[entry_hash])
for entry_hash in existing_entry_hashes
]
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries
end = time.time()
logger.debug(f"Identify, Mark, Combine new, existing entries: {end - start} seconds")
existing_entries_sorted = sorted(existing_entries, key=lambda e: e[0])
entries_with_ids = existing_entries_sorted + new_entries
return entries_with_ids

View File

@@ -10,6 +10,7 @@ from fastapi import APIRouter
# Internal Packages
from src.configure import configure_processor, configure_search
from src.search_type import image_search, text_search
from src.utils.helpers import timer
from src.utils.rawconfig import FullConfig, SearchResponse
from src.utils.config import SearchType
from src.utils import state, constants
@@ -47,7 +48,6 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
# initialize variables
user_query = q.strip()
results_count = n
query_start, query_end, collate_start, collate_end = None, None, None, None
# return cached results, if available
query_cache_key = f'{user_query}-{n}-{t}-{r}'
@@ -57,73 +57,58 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
# query org-mode notes
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
query_end = time.time()
with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count)
if (t == SearchType.Music or t == None) and state.model.music_search:
# query music library
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
query_end = time.time()
with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count)
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
# query markdown files
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
query_end = time.time()
with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count)
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
# query transactions
query_start = time.time()
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
query_end = time.time()
with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
# collate and return results
collate_start = time.time()
results = text_search.collate_results(hits, entries, results_count)
collate_end = time.time()
with timer("Collating results took", logger):
results = text_search.collate_results(hits, entries, results_count)
if (t == SearchType.Image or t == None) and state.model.image_search:
# query images
query_start = time.time()
hits = image_search.query(user_query, results_count, state.model.image_search)
output_directory = constants.web_directory / 'images'
query_end = time.time()
with timer("Query took", logger):
hits = image_search.query(user_query, results_count, state.model.image_search)
output_directory = constants.web_directory / 'images'
# collate and return results
collate_start = time.time()
results = image_search.collate_results(
hits,
image_names=state.model.image_search.image_names,
output_directory=output_directory,
image_files_url='/static/images',
count=results_count)
collate_end = time.time()
with timer("Collating results took", logger):
results = image_search.collate_results(
hits,
image_names=state.model.image_search.image_names,
output_directory=output_directory,
image_files_url='/static/images',
count=results_count)
# Cache results
state.query_cache[query_cache_key] = results
if query_start and query_end:
logger.debug(f"Query took {query_end - query_start:.3f} seconds")
if collate_start and collate_end:
logger.debug(f"Collating results took {collate_end - collate_start:.3f} seconds")
return results

View File

@@ -12,7 +12,7 @@ import dateparser as dtparse
# Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU
from src.utils.helpers import LRU, timer
logger = logging.getLogger(__name__)
@@ -34,19 +34,16 @@ class DateFilter(BaseFilter):
def load(self, entries, *args, **kwargs):
start = time.time()
for id, entry in enumerate(entries):
# Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
# Convert date string in entry to unix timestamp
try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
except ValueError:
continue
self.date_to_entry_ids[date_in_entry].add(id)
end = time.time()
logger.debug(f"Created date filter index: {end - start} seconds")
with timer("Created date filter index", logger):
for id, entry in enumerate(entries):
# Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', getattr(entry, self.entry_key)):
# Convert date string in entry to unix timestamp
try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
except ValueError:
continue
self.date_to_entry_ids[date_in_entry].add(id)
def can_filter(self, raw_query):
"Check if query contains date filters"
@@ -56,10 +53,8 @@ class DateFilter(BaseFilter):
def apply(self, query, entries):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
start = time.time()
query_daterange = self.extract_date_range(query)
end = time.time()
logger.debug(f"Extract date range to filter from query: {end - start} seconds")
with timer("Extract date range to filter from query", logger):
query_daterange = self.extract_date_range(query)
# if no date in query, return all entries
if query_daterange is None:
@@ -80,14 +75,12 @@ class DateFilter(BaseFilter):
self.load(entries)
# find entries containing any dates that fall with date range specified in query
start = time.time()
entries_to_include = set()
for date_in_entry in self.date_to_entry_ids.keys():
# Check if date in entry is within date range specified in query
if query_daterange[0] <= date_in_entry < query_daterange[1]:
entries_to_include |= self.date_to_entry_ids[date_in_entry]
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
with timer("Mark entries satisfying filter", logger):
entries_to_include = set()
for date_in_entry in self.date_to_entry_ids.keys():
# Check if date in entry is within date range specified in query
if query_daterange[0] <= date_in_entry < query_daterange[1]:
entries_to_include |= self.date_to_entry_ids[date_in_entry]
# cache results
self.cache[cache_key] = entries_to_include

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
# Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU
from src.utils.helpers import LRU, timer
logger = logging.getLogger(__name__)
@@ -22,32 +22,28 @@ class FileFilter(BaseFilter):
self.cache = LRU()
def load(self, entries, *args, **kwargs):
start = time.time()
for id, entry in enumerate(entries):
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
end = time.time()
logger.debug(f"Created file filter index: {end - start} seconds")
with timer("Created file filter index", logger):
for id, entry in enumerate(entries):
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
def can_filter(self, raw_query):
return re.search(self.file_filter_regex, raw_query) is not None
def apply(self, query, entries):
# Extract file filters from raw query
start = time.time()
raw_files_to_search = re.findall(self.file_filter_regex, query)
if not raw_files_to_search:
return query, set(range(len(entries)))
with timer("Extract files_to_search from query", logger):
raw_files_to_search = re.findall(self.file_filter_regex, query)
if not raw_files_to_search:
return query, set(range(len(entries)))
# Convert simple file filters with no path separator into regex
# e.g. "file:notes.org" -> "file:.*notes.org"
files_to_search = []
for file in sorted(raw_files_to_search):
if '/' not in file and '\\' not in file and '*' not in file:
files_to_search += [f'*{file}']
else:
files_to_search += [file]
end = time.time()
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
# Convert simple file filters with no path separator into regex
# e.g. "file:notes.org" -> "file:.*notes.org"
files_to_search = []
for file in sorted(raw_files_to_search):
if '/' not in file and '\\' not in file and '*' not in file:
files_to_search += [f'*{file}']
else:
files_to_search += [file]
# Return item from cache if exists
query = re.sub(self.file_filter_regex, '', query).strip()
@@ -61,17 +57,13 @@ class FileFilter(BaseFilter):
self.load(entries, regenerate=False)
# Mark entries that contain any blocked_words for exclusion
start = time.time()
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
for entry_file in self.file_to_entry_map.keys()
for search_file in files_to_search
if fnmatch.fnmatch(entry_file, search_file)], set())
if not included_entry_indices:
return query, {}
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
with timer("Mark entries satisfying filter", logger):
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
for entry_file in self.file_to_entry_map.keys()
for search_file in files_to_search
if fnmatch.fnmatch(entry_file, search_file)], set())
if not included_entry_indices:
return query, {}
# Cache results
self.cache[cache_key] = included_entry_indices

View File

@@ -6,7 +6,7 @@ from collections import defaultdict
# Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU
from src.utils.helpers import LRU, timer
logger = logging.getLogger(__name__)
@@ -24,17 +24,15 @@ class WordFilter(BaseFilter):
def load(self, entries, *args, **kwargs):
start = time.time()
self.cache = {} # Clear cache on filter (re-)load
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
# Create map of words to entries they exist in
for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
if word == '':
continue
self.word_to_entry_index[word].add(entry_index)
end = time.time()
logger.debug(f"Created word filter index: {end - start} seconds")
with timer("Created word filter index", logger):
self.cache = {} # Clear cache on filter (re-)load
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\''
# Create map of words to entries they exist in
for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, getattr(entry, self.entry_key).lower()):
if word == '':
continue
self.word_to_entry_index[word].add(entry_index)
return self.word_to_entry_index
@@ -50,14 +48,10 @@ class WordFilter(BaseFilter):
def apply(self, query, entries):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from required, blocked words filters
start = time.time()
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
end = time.time()
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
with timer("Extract required, blocked filters from query", logger):
required_words = set([word.lower() for word in re.findall(self.required_regex, query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', query)).strip()
if len(required_words) == 0 and len(blocked_words) == 0:
return query, set(range(len(entries)))
@@ -72,20 +66,16 @@ class WordFilter(BaseFilter):
if not self.word_to_entry_index:
self.load(entries, regenerate=False)
start = time.time()
# mark entries that contain all required_words for inclusion
entries_with_all_required_words = set(range(len(entries)))
if len(required_words) > 0:
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
with timer("Mark entries satisfying filter", logger):
entries_with_all_required_words = set(range(len(entries)))
if len(required_words) > 0:
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
# mark entries that contain any blocked_words for exclusion
entries_with_any_blocked_words = set()
if len(blocked_words) > 0:
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# mark entries that contain any blocked_words for exclusion
entries_with_any_blocked_words = set()
if len(blocked_words) > 0:
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
# get entries satisfying inclusion and exclusion filters
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words

View File

@@ -13,7 +13,7 @@ from tqdm import trange
import torch
# Internal Packages
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
@@ -147,27 +147,21 @@ def query(raw_query, count, model: ImageSearchModel):
logger.info(f"Find Images by Text: {query}")
# Now we encode the query (which can either be an image or a text string)
start = time.time()
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
end = time.time()
logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
with timer("Query Encode Time", logger):
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
start = time.time()
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
for result
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
end = time.time()
logger.debug(f"Search Time: {end - start:.3f} seconds")
with timer("Search Time", logger):
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
for result
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings:
start = time.time()
metadata_hits = {result['corpus_id']: result['score']
for result
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
end = time.time()
logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
with timer("Metadata Search Time", logger):
metadata_hits = {result['corpus_id']: result['score']
for result
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
# Sum metadata, image scores of the highest ranked images
for corpus_id, score in metadata_hits.items():

View File

@@ -12,7 +12,7 @@ from src.search_filter.base_filter import BaseFilter
# Internal Packages
from src.utils import state
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
from src.utils.config import TextSearchModel
from src.utils.models import BaseEncoder
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
@@ -96,6 +96,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
# Filter query, entries and embeddings before semantic search
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
# If no entries left after filtering, return empty results
if entries is None or len(entries) == 0:
return [], []
@@ -105,17 +106,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
return hits, entries
# Encode the query using the bi-encoder
start = time.time()
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
end = time.time()
logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
with timer("Query Encode Time", logger, state.device):
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
# Find relevant entries for the query
start = time.time()
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
end = time.time()
logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
with timer("Search Time", logger, state.device):
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
# Score all retrieved entries using the cross-encoder
if rank_results:
@@ -170,36 +167,29 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
'''Filter query, entries and embeddings before semantic search'''
start_filter = time.time()
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
for filter in filters_in_query:
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return '', [], torch.tensor([], device=state.device)
else:
start = time.time()
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
end = time.time()
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
with timer("Total Filter Time", logger, state.device):
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
for filter in filters_in_query:
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
end_filter = time.time()
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return '', [], torch.tensor([], device=state.device)
else:
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
return query, entries, corpus_embeddings
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
'''Score all retrieved entries using the cross-encoder'''
# NOTE(review): the manual time.time() instrumentation below duplicates the
# timer context-manager block that follows — pre-refactor residue; only the
# timer-based variant should remain.
start = time.time()
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
end = time.time()
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
# Pair the query with each hit's compiled entry text and score every pair
# with the cross-encoder; prediction time is logged by the timer manager
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
# Store cross-encoder scores in results dictionary for ranking
# NOTE(review): loop body is truncated in this view — presumably writes
# cross_scores[idx] onto hits[idx]; confirm against the full file.
for idx in range(len(cross_scores)):
@@ -210,12 +200,10 @@ def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[E
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
    '''Order results by cross-encoder score followed by bi-encoder score.

    Sorts hits in place and returns the same list. The bi-encoder sort runs
    first so that, when rank_results is set, the cross-encoder score becomes
    the primary key and the bi-encoder score the tie-breaking order
    (Python's sort is stable).
    '''
    with timer("Rank Time", logger, state.device):
        hits.sort(key=lambda x: x['score'], reverse=True)  # sort by bi-encoder score
        if rank_results:
            hits.sort(key=lambda x: x['cross-score'], reverse=True)  # sort by cross-encoder score
    return hits
def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
    '''Deduplicate entries by raw entry text before showing to users
    Compiled entries are split by max tokens supported by ML models.
    This can result in duplicate hits, entries shown to user.'''
    with timer("Deduplication Time", logger, state.device):
        seen, original_hits_count = set(), len(hits)
        # set.add returns None, so `not seen.add(...)` is always True: it
        # records the raw text as seen while keeping the first hit for it
        hits = [hit for hit in hits
                if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]

        duplicate_hits = original_hits_count - len(hits)
        logger.debug(f"Removed {duplicate_hits} duplicates")
    return hits

View File

@@ -2,10 +2,12 @@
from __future__ import annotations # to avoid quoting type hints
import logging
import sys
import torch
from collections import OrderedDict
from importlib import import_module
from os.path import join
from pathlib import Path
from time import perf_counter
from typing import Optional, Union, TYPE_CHECKING
if TYPE_CHECKING:
@@ -81,6 +83,25 @@ def get_class_by_name(name: str) -> object:
return getattr(import_module(module_name), class_name)
class timer:
    '''Context manager to log time taken for a block of code to run

    Logs "<message>: <elapsed> seconds" at DEBUG level on exit, appending
    the device when one is given. Uses time.perf_counter for monotonic,
    high-resolution interval timing. The elapsed time is logged even when
    the wrapped block raises; the exception is not suppressed.
    '''
    def __init__(self, message: str, logger: logging.Logger, device: Optional[torch.device] = None):
        # message: label prefixed to the emitted timing log line
        # logger: logger the DEBUG timing line is written to
        # device: optional compute device to mention in the log line
        self.message = message
        self.logger = logger
        self.device = device

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, *_):
        # Returning None (falsy) lets any exception from the block propagate
        elapsed = perf_counter() - self.start
        if self.device is None:
            self.logger.debug(f"{self.message}: {elapsed:.3f} seconds")
        else:
            self.logger.debug(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
class LRU(OrderedDict):
def __init__(self, *args, capacity=128, **kwargs):
self.capacity = capacity