mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 13:23:15 +00:00
Trace query response performance and display timings in verbose mode
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
from copy import deepcopy
|
||||
import time
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
@@ -62,38 +63,62 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
|
||||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = []):
|
||||
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0):
|
||||
"Search for entries that answer the query"
|
||||
# Copy original embeddings, entries to filter them for query
|
||||
start = time.time()
|
||||
query = raw_query
|
||||
corpus_embeddings = deepcopy(model.corpus_embeddings)
|
||||
entries = deepcopy(model.entries)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Copy Time: {end - start:.3f} seconds")
|
||||
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
start = time.time()
|
||||
for filter in filters:
|
||||
query, entries, corpus_embeddings = filter(query, entries, corpus_embeddings)
|
||||
if entries is None or len(entries) == 0:
|
||||
return [], []
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Filter Time: {end - start:.3f} seconds")
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
start = time.time()
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True)
|
||||
question_embedding.to(device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Query Encode Time: {end - start:.3f} seconds")
|
||||
|
||||
# Find relevant entries for the query
|
||||
start = time.time()
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Search Time: {end - start:.3f} seconds")
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
start = time.time()
|
||||
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
|
||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
hits[idx]['cross-score'] = cross_scores[idx]
|
||||
|
||||
# Order results by cross-encoder score followed by bi-encoder score
|
||||
start = time.time()
|
||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Rank Time: {end - start:.3f} seconds")
|
||||
|
||||
return hits, entries
|
||||
|
||||
|
||||
Reference in New Issue
Block a user