Create and use a context manager to time code

Use the timer context manager in all places where code was being timed

- Benefits
  - Deduplicate timing code scattered across the codebase
  - Provide a single place to manage performance timing code
  - Use consistent timing log patterns
This commit is contained in:
Debanjum Singh Solanky
2023-01-09 19:43:19 -03:00
parent 93f39dbd43
commit aa22d83172
11 changed files with 235 additions and 298 deletions

View File

@@ -13,7 +13,7 @@ from tqdm import trange
import torch
# Internal Packages
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
@@ -147,27 +147,21 @@ def query(raw_query, count, model: ImageSearchModel):
logger.info(f"Find Images by Text: {query}")
# Now we encode the query (which can either be an image or a text string)
start = time.time()
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
end = time.time()
logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
with timer("Query Encode Time", logger):
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
start = time.time()
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
for result
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
end = time.time()
logger.debug(f"Search Time: {end - start:.3f} seconds")
with timer("Search Time", logger):
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
for result
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
if model.image_metadata_embeddings:
start = time.time()
metadata_hits = {result['corpus_id']: result['score']
for result
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
end = time.time()
logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
with timer("Metadata Search Time", logger):
metadata_hits = {result['corpus_id']: result['score']
for result
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
# Sum metadata, image scores of the highest ranked images
for corpus_id, score in metadata_hits.items():

View File

@@ -12,7 +12,7 @@ from src.search_filter.base_filter import BaseFilter
# Internal Packages
from src.utils import state
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
from src.utils.config import TextSearchModel
from src.utils.models import BaseEncoder
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
@@ -96,6 +96,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
# Filter query, entries and embeddings before semantic search
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
# If no entries left after filtering, return empty results
if entries is None or len(entries) == 0:
return [], []
@@ -105,17 +106,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
return hits, entries
# Encode the query using the bi-encoder
start = time.time()
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
end = time.time()
logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
with timer("Query Encode Time", logger, state.device):
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
question_embedding = util.normalize_embeddings(question_embedding)
# Find relevant entries for the query
start = time.time()
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
end = time.time()
logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
with timer("Search Time", logger, state.device):
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
# Score all retrieved entries using the cross-encoder
if rank_results:
@@ -170,36 +167,29 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
'''Filter query, entries and embeddings before semantic search'''
start_filter = time.time()
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
for filter in filters_in_query:
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return '', [], torch.tensor([], device=state.device)
else:
start = time.time()
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
end = time.time()
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
with timer("Total Filter Time", logger, state.device):
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
for filter in filters_in_query:
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
end_filter = time.time()
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return '', [], torch.tensor([], device=state.device)
else:
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
return query, entries, corpus_embeddings
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
'''Score all retrieved entries using the cross-encoder'''
start = time.time()
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
end = time.time()
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
@@ -210,12 +200,10 @@ def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[E
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
'''Order results by cross-encoder score followed by bi-encoder score'''
start = time.time()
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time()
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
return hits
@@ -223,11 +211,12 @@ def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
'''Deduplicate entries by raw entry text before showing to users
Compiled entries are split by max tokens supported by ML models.
This can result in duplicate hits, entries shown to user.'''
start = time.time()
seen, original_hits_count = set(), len(hits)
hits = [hit for hit in hits
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
duplicate_hits = original_hits_count - len(hits)
end = time.time()
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
with timer("Deduplication Time", logger, state.device):
seen, original_hits_count = set(), len(hits)
hits = [hit for hit in hits
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
duplicate_hits = original_hits_count - len(hits)
logger.debug(f"Removed {duplicate_hits} duplicates")
return hits