Make cross-encoder re-rank results if query param set on /search API

- Improve search speed by ~10x
  Tested on corpus of 125K lines, 12.5K entries

- Allow cross-encoder to re-rank results by setting &r=true when querying the /search API
  - It's an optional param that defaults to False
  - Earlier all results were re-ranked by cross-encoder
  - Making this configurable allows for much faster results, if desired,
    at the cost of lower accuracy
This commit is contained in:
Debanjum Singh Solanky
2022-07-26 22:56:36 +04:00
parent b1e64fd4a8
commit 1168244c92
4 changed files with 22 additions and 19 deletions

View File

@@ -63,7 +63,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0):
def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
"Search for entries that answer the query"
query = raw_query
@@ -108,21 +108,23 @@ def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list =
print(f"Search Time: {end - start:.3f} seconds")
# Score all retrieved entries using the cross-encoder
start = time.time()
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time()
if verbose > 1:
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
if rank_results:
start = time.time()
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp)
end = time.time()
if verbose > 1:
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
hits[idx]['cross-score'] = cross_scores[idx]
# Store cross-encoder scores in results dictionary for ranking
for idx in range(len(cross_scores)):
hits[idx]['cross-score'] = cross_scores[idx]
# Order results by cross-encoder score followed by bi-encoder score
start = time.time()
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time()
if verbose > 1:
print(f"Rank Time: {end - start:.3f} seconds")
@@ -152,7 +154,7 @@ def collate_results(hits, entries, count=5):
return [
{
"entry": entries[hit['corpus_id']]['raw'],
"score": f"{hit['cross-score']:.3f}"
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
}
for hit
in hits[0:count]]