File tree Expand file tree Collapse file tree 2 files changed +13
-5
lines changed Expand file tree Collapse file tree 2 files changed +13
-5
lines changed Original file line number Diff line number Diff line change 1
- import time
2
- import os .path as osp
1
+ import argparse
3
2
import itertools
3
+ import os .path as osp
4
+ import time
4
5
5
- import argparse
6
- import wget
7
6
import torch
7
+ import wget
8
8
from scipy .io import loadmat
9
-
10
9
from torch_scatter import scatter_add
10
+
11
11
from torch_sparse .tensor import SparseTensor
12
12
13
13
short_rows = [
@@ -62,6 +62,9 @@ def time_func(func, x):
62
62
try :
63
63
if torch .cuda .is_available ():
64
64
torch .cuda .synchronize ()
65
+ elif torch .backends .mps .is_available ():
66
+ import torch .mps
67
+ torch .mps .synchronize ()
65
68
t = time .perf_counter ()
66
69
67
70
if not args .with_backward :
@@ -77,6 +80,9 @@ def time_func(func, x):
77
80
78
81
if torch .cuda .is_available ():
79
82
torch .cuda .synchronize ()
83
+ elif torch .backends .mps .is_available ():
84
+ import torch .mps
85
+ torch .mps .synchronize ()
80
86
return time .perf_counter () - t
81
87
except RuntimeError as e :
82
88
if 'out of memory' not in str (e ):
Original file line number Diff line number Diff line change 16
16
devices = [torch .device ('cpu' )]
17
17
if torch .cuda .is_available ():
18
18
devices += [torch .device ('cuda:0' )]
19
+ if torch .backends .mps .is_available ():
20
+ devices += [torch .device ('mps' )]
19
21
20
22
21
23
def tensor (x : Any , dtype : torch .dtype , device : torch .device ):
You can’t perform that action at this time.
0 commit comments