Trace query response performance and display timings in verbose mode

This commit is contained in:
Debanjum Singh Solanky
2022-07-26 21:03:53 +04:00
parent d8efcd559f
commit f094c86204
3 changed files with 64 additions and 14 deletions

View File

@@ -2,6 +2,7 @@
import argparse
import pathlib
from copy import deepcopy
import time
# External Packages
import torch
@@ -62,38 +63,62 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = []):
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0):
"Search for entries that answer the query"
# Copy original embeddings, entries to filter them for query
start = time.time()
query = raw_query
corpus_embeddings = deepcopy(model.corpus_embeddings)
entries = deepcopy(model.entries)
end = time.time()
if verbose > 1:
print(f"Copy Time: {end - start:.3f} seconds")
# Filter query, entries and embeddings before semantic search
start = time.time()
for filter in filters:
query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings)
if entries is None or len(entries) == 0:
return [], []
end = time.time()
if verbose > 1:
print(f"Filter Time: {end - start:.3f} seconds")
# Encode the query using the bi-encoder
start = time.time()
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
question_embedding.to(device)
question_embedding = util.normalize_embeddings(question_embedding)
end = time.time()
if verbose > 1:
print(f"Query Encode Time: {end - start:.3f} seconds")
# Find relevant entries for the query
start = time.time()
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
end = time.time()
if verbose > 1:
print(f"Search Time: {end - start:.3f} seconds")
# Score all retrieved entries using the cross-encoder
start = time.time()
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time()
if verbose > 1:
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
hits[idx]['cross-score'] = cross_scores[idx]
# Order results by cross-encoder score followed by bi-encoder score
start = time.time()
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time()
if verbose > 1:
print(f"Rank Time: {end - start:.3f} seconds")
return hits, entries