mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 21:29:11 +00:00
Fix GPU usage by Khoj on Macs to speed up search and indexing
- Ensure all tensors are on the MPS device before performing operations across them
- Background
- The GPU is now used by default for Khoj on macOS
- Using the GPU on Macs requires PyTorch > 1.13.0, which Khoj now uses
- MPS should speed up search and indexing on macOS
This commit is contained in:
@@ -68,7 +68,7 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, em
|
|||||||
if new_entries:
|
if new_entries:
|
||||||
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True)
|
||||||
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
|
existing_entry_ids = [id for id, _ in entries_with_ids if id != -1]
|
||||||
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids)) if existing_entry_ids else torch.Tensor()
|
existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)) if existing_entry_ids else torch.Tensor(device=state.device)
|
||||||
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0)
|
||||||
# Else compute the corpus embeddings from scratch
|
# Else compute the corpus embeddings from scratch
|
||||||
else:
|
else:
|
||||||
@@ -102,12 +102,12 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
|
|||||||
else:
|
else:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
entries = [entries[id] for id in included_entry_indices]
|
entries = [entries[id] for id in included_entry_indices]
|
||||||
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
|
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices), device=state.device))
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
|
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
|
||||||
|
|
||||||
end_filter = time.time()
|
end_filter = time.time()
|
||||||
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds")
|
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds on device: {state.device}")
|
||||||
|
|
||||||
if entries is None or len(entries) == 0:
|
if entries is None or len(entries) == 0:
|
||||||
return [], []
|
return [], []
|
||||||
|
|||||||
Reference in New Issue
Block a user