From acc909126003adde64bd283a7def30c3f91d365d Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sat, 20 Aug 2022 13:21:21 +0300
Subject: [PATCH] Use MPS on Apple Mac M1 to GPU accelerate Encode, Query
 Performance

- Note: Support for MPS in Pytorch is currently in v1.13.0 nightly builds
- Users will have to wait for PyTorch MPS support to land in stable builds
- Until then the code can be tweaked and tested to make use of the GPU
  acceleration on newer Macs
---
 src/utils/state.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/utils/state.py b/src/utils/state.py
index 4f194331..b5c082d6 100644
--- a/src/utils/state.py
+++ b/src/utils/state.py
@@ -1,3 +1,5 @@
+# Standard Packages
+from packaging import version
 # External Packages
 import torch
 from pathlib import Path
@@ -12,7 +14,15 @@ model = SearchModels()
 processor_config = ProcessorConfigModel()
 config_file: Path = ""
 verbose: int = 0
-device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # Set device to GPU if available
 host: str = None
 port: int = None
-cli_args = None
\ No newline at end of file
+cli_args = None
+
+if torch.cuda.is_available():
+    # Use CUDA GPU
+    device = torch.device("cuda:0")
+elif version.parse(torch.__version__) >= version.parse("1.13.0.dev") and torch.backends.mps.is_available():
+    # Use Apple M1 Metal Acceleration
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")