Skip to content

Commit 14781cb

Browse files
NripeshNrusty1s
andauthored
Add apple silicon GPU Acceleration Support ("mps") (#335)
* Add apple silicon GPU Acceleration Support * format fix * update * update * update * Fix module not found error --------- Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
1 parent 40693ab commit 14781cb

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

benchmark/main.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import time
2-
import os.path as osp
1+
import argparse
32
import itertools
3+
import os.path as osp
4+
import time
45

5-
import argparse
6-
import wget
76
import torch
7+
import wget
88
from scipy.io import loadmat
9-
109
from torch_scatter import scatter_add
10+
1111
from torch_sparse.tensor import SparseTensor
1212

1313
short_rows = [
@@ -62,6 +62,9 @@ def time_func(func, x):
6262
try:
6363
if torch.cuda.is_available():
6464
torch.cuda.synchronize()
65+
elif torch.backends.mps.is_available():
66+
import torch.mps
67+
torch.mps.synchronize()
6568
t = time.perf_counter()
6669

6770
if not args.with_backward:
@@ -77,6 +80,9 @@ def time_func(func, x):
7780

7881
if torch.cuda.is_available():
7982
torch.cuda.synchronize()
83+
elif torch.backends.mps.is_available():
84+
import torch.mps
85+
torch.mps.synchronize()
8086
return time.perf_counter() - t
8187
except RuntimeError as e:
8288
if 'out of memory' not in str(e):

torch_sparse/testing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
devices = [torch.device('cpu')]
1717
if torch.cuda.is_available():
1818
devices += [torch.device('cuda:0')]
19+
if torch.backends.mps.is_available():
20+
devices += [torch.device('mps')]
1921

2022

2123
def tensor(x: Any, dtype: torch.dtype, device: torch.device):

0 commit comments

Comments
 (0)