diff --git a/src/utils/state.py b/src/utils/state.py index 4f194331..b5c082d6 100644 --- a/src/utils/state.py +++ b/src/utils/state.py @@ -1,3 +1,5 @@ +# Standard Packages +from packaging import version # External Packages import torch from pathlib import Path @@ -12,7 +14,15 @@ model = SearchModels() processor_config = ProcessorConfigModel() config_file: Path = "" verbose: int = 0 -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") # Set device to GPU if available host: str = None port: int = None -cli_args = None \ No newline at end of file +cli_args = None + +if torch.cuda.is_available(): + # Use CUDA GPU + device = torch.device("cuda:0") +elif version.parse(torch.__version__) >= version.parse("1.13.0.dev") and torch.backends.mps.is_available(): + # Use Apple M1 Metal Acceleration + device = torch.device("mps") +else: + device = torch.device("cpu")