mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 13:22:12 +00:00
Make cross-encoder re-rank results if query param set on /search API
- Improve search speed by ~10x
Tested on corpus of 125K lines, 12.5K entries
- Allow cross-encoder to re-rank results by settings &?r=true when querying /search API
- It's an optional param that default to False
- Earlier all results were re-ranked by cross-encoder
- Making this configurable allows for much faster results, if desired
but for lower accuracy
This commit is contained in:
@@ -63,7 +63,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
|
||||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0):
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
|
||||
"Search for entries that answer the query"
|
||||
query = raw_query
|
||||
|
||||
@@ -108,21 +108,23 @@ def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list =
|
||||
print(f"Search Time: {end - start:.3f} seconds")
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
start = time.time()
|
||||
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
|
||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
|
||||
if rank_results:
|
||||
start = time.time()
|
||||
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
|
||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
hits[idx]['cross-score'] = cross_scores[idx]
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
hits[idx]['cross-score'] = cross_scores[idx]
|
||||
|
||||
# Order results by cross-encoder score followed by bi-encoder score
|
||||
start = time.time()
|
||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Rank Time: {end - start:.3f} seconds")
|
||||
@@ -152,7 +154,7 @@ def collate_results(hits, entries, count=5):
|
||||
return [
|
||||
{
|
||||
"entry": entries[hit['corpus_id']]['raw'],
|
||||
"score": f"{hit['cross-score']:.3f}"
|
||||
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
|
||||
}
|
||||
for hit
|
||||
in hits[0:count]]
|
||||
|
||||
Reference in New Issue
Block a user