mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-03 13:19:16 +00:00
- Note: Support for MPS in PyTorch is currently only available in the v1.13.0 nightly builds. - Users will have to wait for PyTorch MPS support to land in stable builds. - Until then, the code can be tweaked and tested to make use of GPU acceleration on newer Macs.
29 lines
745 B
Python
29 lines
745 B
Python
# Standard Packages
from pathlib import Path
from typing import Optional, Union

# External Packages
import torch
from packaging import version

# Internal Packages
from src.utils.config import SearchModels, ProcessorConfigModel
from src.utils.rawconfig import FullConfig

# Application Global State
# Module-level names holding shared application state. NOTE(review): they are
# presumably read and reassigned by other modules via this module — confirm
# against callers before changing any default below.
config = FullConfig()
model = SearchModels()
processor_config = ProcessorConfigModel()
# Starts as an empty string, not a Path, so the annotation admits both;
# presumably replaced with a real Path once a config file is chosen (confirm).
config_file: Union[Path, str] = ""
# Integer verbosity setting, 0 by default (higher is presumably more verbose).
verbose: int = 0
# Unset (None) until assigned; Optional makes the None default type-correct.
host: Optional[str] = None
port: Optional[int] = None
# Unset until assigned elsewhere; presumably the parsed CLI arguments — verify.
cli_args = None

def _pick_torch_device() -> torch.device:
    """Return the torch device models should run on.

    Preference order: the first CUDA GPU, then Apple Metal (MPS), then CPU.
    MPS is probed only on torch >= 1.13.0.dev; because the version test is
    evaluated first and combined with a short-circuiting ``and``,
    ``torch.backends.mps`` is never accessed on older torch builds.
    """
    if torch.cuda.is_available():
        # Use CUDA GPU
        return torch.device("cuda:0")

    # MPS support was only shipping in the 1.13.0 nightly builds when this was written.
    torch_has_mps = version.parse(torch.__version__) >= version.parse("1.13.0.dev")
    if torch_has_mps and torch.backends.mps.is_available():
        # Use Apple M1 Metal Acceleration
        return torch.device("mps")

    return torch.device("cpu")


device = _pick_torch_device()