mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-04 21:29:12 +00:00
Create and use a context manager to time code
Use the timer context manager in all places where code was being timed - Benefits - Deduplicate timing code scattered across codebase. - Provides single place to manage perf timing code - Use consistent timing log patterns
This commit is contained in:
@@ -13,7 +13,7 @@ from tqdm import trange
|
||||
import torch
|
||||
|
||||
# Internal Packages
|
||||
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model
|
||||
from src.utils.helpers import get_absolute_path, get_from_dict, resolve_absolute_path, load_model, timer
|
||||
from src.utils.config import ImageSearchModel
|
||||
from src.utils.rawconfig import ImageContentConfig, ImageSearchConfig, SearchResponse
|
||||
|
||||
@@ -147,27 +147,21 @@ def query(raw_query, count, model: ImageSearchModel):
|
||||
logger.info(f"Find Images by Text: {query}")
|
||||
|
||||
# Now we encode the query (which can either be an image or a text string)
|
||||
start = time.time()
|
||||
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||
end = time.time()
|
||||
logger.debug(f"Query Encode Time: {end - start:.3f} seconds")
|
||||
with timer("Query Encode Time", logger):
|
||||
query_embedding = model.image_encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||
start = time.time()
|
||||
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
|
||||
end = time.time()
|
||||
logger.debug(f"Search Time: {end - start:.3f} seconds")
|
||||
with timer("Search Time", logger):
|
||||
image_hits = {result['corpus_id']: {'image_score': result['score'], 'score': result['score']}
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_embeddings, top_k=count)[0]}
|
||||
|
||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image metadata embeddings.
|
||||
if model.image_metadata_embeddings:
|
||||
start = time.time()
|
||||
metadata_hits = {result['corpus_id']: result['score']
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
|
||||
end = time.time()
|
||||
logger.debug(f"Metadata Search Time: {end - start:.3f} seconds")
|
||||
with timer("Metadata Search Time", logger):
|
||||
metadata_hits = {result['corpus_id']: result['score']
|
||||
for result
|
||||
in util.semantic_search(query_embedding, model.image_metadata_embeddings, top_k=count)[0]}
|
||||
|
||||
# Sum metadata, image scores of the highest ranked images
|
||||
for corpus_id, score in metadata_hits.items():
|
||||
|
||||
@@ -12,7 +12,7 @@ from src.search_filter.base_filter import BaseFilter
|
||||
|
||||
# Internal Packages
|
||||
from src.utils import state
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model
|
||||
from src.utils.helpers import get_absolute_path, is_none_or_empty, resolve_absolute_path, load_model, timer
|
||||
from src.utils.config import TextSearchModel
|
||||
from src.utils.models import BaseEncoder
|
||||
from src.utils.rawconfig import SearchResponse, TextSearchConfig, TextContentConfig, Entry
|
||||
@@ -96,6 +96,7 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
|
||||
|
||||
# Filter query, entries and embeddings before semantic search
|
||||
query, entries, corpus_embeddings = apply_filters(query, entries, corpus_embeddings, model.filters)
|
||||
|
||||
# If no entries left after filtering, return empty results
|
||||
if entries is None or len(entries) == 0:
|
||||
return [], []
|
||||
@@ -105,17 +106,13 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
|
||||
return hits, entries
|
||||
|
||||
# Encode the query using the bi-encoder
|
||||
start = time.time()
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
end = time.time()
|
||||
logger.debug(f"Query Encode Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True, device=state.device)
|
||||
question_embedding = util.normalize_embeddings(question_embedding)
|
||||
|
||||
# Find relevant entries for the query
|
||||
start = time.time()
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||
end = time.time()
|
||||
logger.debug(f"Search Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
with timer("Search Time", logger, state.device):
|
||||
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0]
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
if rank_results:
|
||||
@@ -170,36 +167,29 @@ def setup(text_to_jsonl: Type[TextToJsonl], config: TextContentConfig, search_co
|
||||
|
||||
def apply_filters(query: str, entries: list[Entry], corpus_embeddings: torch.Tensor, filters: list[BaseFilter]) -> tuple[str, list[Entry], torch.Tensor]:
|
||||
'''Filter query, entries and embeddings before semantic search'''
|
||||
start_filter = time.time()
|
||||
included_entry_indices = set(range(len(entries)))
|
||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||
for filter in filters_in_query:
|
||||
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
||||
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
||||
|
||||
# Get entries (and associated embeddings) satisfying all filters
|
||||
if not included_entry_indices:
|
||||
return '', [], torch.tensor([], device=state.device)
|
||||
else:
|
||||
start = time.time()
|
||||
entries = [entries[id] for id in included_entry_indices]
|
||||
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
|
||||
end = time.time()
|
||||
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
|
||||
with timer("Total Filter Time", logger, state.device):
|
||||
included_entry_indices = set(range(len(entries)))
|
||||
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
|
||||
for filter in filters_in_query:
|
||||
query, included_entry_indices_by_filter = filter.apply(query, entries)
|
||||
included_entry_indices.intersection_update(included_entry_indices_by_filter)
|
||||
|
||||
end_filter = time.time()
|
||||
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
|
||||
# Get entries (and associated embeddings) satisfying all filters
|
||||
if not included_entry_indices:
|
||||
return '', [], torch.tensor([], device=state.device)
|
||||
else:
|
||||
entries = [entries[id] for id in included_entry_indices]
|
||||
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
|
||||
|
||||
return query, entries, corpus_embeddings
|
||||
|
||||
|
||||
def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[Entry], hits: list[dict]) -> list[dict]:
|
||||
'''Score all retrieved entries using the cross-encoder'''
|
||||
start = time.time()
|
||||
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
||||
cross_scores = cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
logger.debug(f"Cross-Encoder Predict Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_inp = [[query, entries[hit['corpus_id']].compiled] for hit in hits]
|
||||
cross_scores = cross_encoder.predict(cross_inp)
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
@@ -210,12 +200,10 @@ def cross_encoder_score(cross_encoder: CrossEncoder, query: str, entries: list[E
|
||||
|
||||
def sort_results(rank_results: bool, hits: list[dict]) -> list[dict]:
|
||||
'''Order results by cross-encoder score followed by bi-encoder score'''
|
||||
start = time.time()
|
||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
end = time.time()
|
||||
logger.debug(f"Rank Time: {end - start:.3f} seconds on device: {state.device}")
|
||||
with timer("Rank Time", logger, state.device):
|
||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
return hits
|
||||
|
||||
|
||||
@@ -223,11 +211,12 @@ def deduplicate_results(entries: list[Entry], hits: list[dict]) -> list[dict]:
|
||||
'''Deduplicate entries by raw entry text before showing to users
|
||||
Compiled entries are split by max tokens supported by ML models.
|
||||
This can result in duplicate hits, entries shown to user.'''
|
||||
start = time.time()
|
||||
seen, original_hits_count = set(), len(hits)
|
||||
hits = [hit for hit in hits
|
||||
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
|
||||
duplicate_hits = original_hits_count - len(hits)
|
||||
end = time.time()
|
||||
logger.debug(f"Deduplication Time: {end - start:.3f} seconds. Removed {duplicate_hits} duplicates")
|
||||
|
||||
with timer("Deduplication Time", logger, state.device):
|
||||
seen, original_hits_count = set(), len(hits)
|
||||
hits = [hit for hit in hits
|
||||
if entries[hit['corpus_id']].raw not in seen and not seen.add(entries[hit['corpus_id']].raw)]
|
||||
duplicate_hits = original_hits_count - len(hits)
|
||||
|
||||
logger.debug(f"Removed {duplicate_hits} duplicates")
|
||||
return hits
|
||||
|
||||
Reference in New Issue
Block a user