From eace7c621532f6e095fa8fd0fee5b5c37a5bc987 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 9 Jan 2023 12:49:11 -0300 Subject: [PATCH] Use torch.tensor as torch.Tensor cannot create tensor on MPS device - `torch.Tensor' is apparently a legacy tensor constructor - Using that to create tensor on MPS devices throws error: RuntimeError: legacy constructor expects device type: cpu but device type: mps was passed - `torch.tensor' can handle creating tensors on Mac GPU (MPS) fine --- src/search_type/text_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 816972cc..1929eb18 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -69,7 +69,7 @@ def compute_embeddings(entries_with_ids: list[tuple[int, Entry]], bi_encoder, em if new_entries: new_embeddings = bi_encoder.encode(new_entries, convert_to_tensor=True, device=state.device, show_progress_bar=True) existing_entry_ids = [id for id, _ in entries_with_ids if id != -1] - existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)) if existing_entry_ids else torch.Tensor(device=state.device) + existing_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(existing_entry_ids, device=state.device)) if existing_entry_ids else torch.tensor([], device=state.device) corpus_embeddings = torch.cat([existing_embeddings, new_embeddings], dim=0) # Else compute the corpus embeddings from scratch else: