From d55d7d53dca111caa7d5cfe2fb3ab6f508e9d6ef Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Thu, 5 Jan 2023 15:10:34 -0300
Subject: [PATCH] Fix GPU usage by Khoj on Macs to speed up search and indexing

- Ensure all tensors are on MPS device before doing operations across them
- Background
  - GPU is used by default for Khoj on MacOS now
  - Needed PyTorch > 1.13.0 on Macs to use GPU, which we do now
  - MPS should speed up search and indexing on MacOS
---
 src/search_type/text_search.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py
index 65b234a2..53eb3c3d 100644
--- a/src/search_type/text_search.py
+++ b/src/search_type/text_search.py
@@ -68,7 +68,7 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, em
     if new_entries:
         new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
         existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
-        existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
+        existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)) if existing_entry_ids else torch.Tensor(device=state.device)
         corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
     # Else compute the corpus embeddings from scratch
     else:
@@ -102,12 +102,12 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
         else:
             start = time.time()
             entries = [entries[id] for id in included_entry_indices]
-            corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
+            corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
             end = time.time()
             logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
 
     end_filter = time.time()
-    logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds")
+    logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
 
     if entries is None or len(entries) == 0:
         return [], []