6
6
# LICENSE file in the root directory of this source tree.
7
7
8
8
# pyre-strict
9
-
10
9
# pyre-ignore-all-errors[56]
11
10
12
11
import unittest
12
+ from typing import Tuple , Type
13
13
14
14
import hypothesis .strategies as st
15
15
import numpy as np
20
20
21
21
if open_source :
22
22
# pyre-ignore[21]
23
- from test_utils import gpu_available
23
+ from test_utils import cpu_and_maybe_gpu , gpu_available
24
24
else :
25
25
import fbgemm_gpu .sparse_ops # noqa: F401, E402
26
- from fbgemm_gpu .test .test_utils import gpu_available
26
+ from fbgemm_gpu .test .test_utils import cpu_and_maybe_gpu , gpu_available
27
27
28
28
29
29
class CumSumTest (unittest .TestCase ):
30
30
@given (
31
31
n = st .integers (min_value = 0 , max_value = 10 ),
32
- long_index = st .booleans (),
32
+ index_types = st .sampled_from (
33
+ [
34
+ (torch .int64 , np .int64 ),
35
+ (torch .int32 , np .int32 ),
36
+ (torch .float32 , np .float32 ),
37
+ ]
38
+ ),
39
+ device = cpu_and_maybe_gpu (),
33
40
)
34
41
@settings (verbosity = Verbosity .verbose , max_examples = 20 , deadline = None )
35
- def test_cumsum (self , n : int , long_index : bool ) -> None :
36
- index_dtype = torch .int64 if long_index else torch .int32
37
- np_index_dtype = np .int64 if long_index else np .int32
42
+ def test_cumsum (
43
+ self ,
44
+ n : int ,
45
+ index_types : Tuple [Type [object ], Type [object ]],
46
+ device : torch .device ,
47
+ ) -> None :
48
+ (pt_index_dtype , np_index_dtype ) = index_types
49
+
50
+ # The CPU variants of asynchronous_*_cumsum support floats, since some
51
+ # downstream tests appear to be relying on this behavior. As such, the
52
+ # test is disabled for GPU + float test cases.
53
+ if device == torch .device ("cuda" ) and pt_index_dtype is torch .float32 :
54
+ return
38
55
39
- # cpu tests
40
- x = torch .randint (low = 0 , high = 100 , size = (n ,)).type (index_dtype )
56
+ # pyre-ignore-errors[16]
57
+ x = torch .randint (low = 0 , high = 100 , size = (n ,)).type (pt_index_dtype ). to ( device )
41
58
ze = torch .ops .fbgemm .asynchronous_exclusive_cumsum (x )
42
59
zi = torch .ops .fbgemm .asynchronous_inclusive_cumsum (x )
43
60
zc = torch .ops .fbgemm .asynchronous_complete_cumsum (x )
61
+
44
62
torch .testing .assert_close (
45
63
torch .from_numpy (np .cumsum (x .cpu ().numpy ()).astype (np_index_dtype )),
46
64
zi .cpu (),
@@ -59,68 +77,59 @@ def test_cumsum(self, n: int, long_index: bool) -> None:
59
77
)
60
78
61
79
# meta tests
62
- mx = torch .randint (low = 0 , high = 100 , size = (n ,)).type (index_dtype ).to ("meta" )
80
+ # pyre-ignore-errors[16]
81
+ mx = torch .randint (low = 0 , high = 100 , size = (n ,)).type (pt_index_dtype ).to ("meta" )
82
+
63
83
mze = torch .ops .fbgemm .asynchronous_exclusive_cumsum (mx )
64
84
self .assertEqual (ze .size (), mze .size ())
65
- # mzi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(mx)
66
- # self.assertEqual(zi.size(), mzi.size())
85
+
86
+ mzi = torch .ops .fbgemm .asynchronous_inclusive_cumsum (mx )
87
+ self .assertEqual (zi .size (), mzi .size ())
88
+
67
89
mzc = torch .ops .fbgemm .asynchronous_complete_cumsum (mx )
68
90
self .assertEqual (zc .size (), mzc .size ())
69
91
70
- if gpu_available :
71
- x = x .cuda ()
72
- ze = torch .ops .fbgemm .asynchronous_exclusive_cumsum (x )
73
- zi = torch .ops .fbgemm .asynchronous_inclusive_cumsum (x )
74
- zc = torch .ops .fbgemm .asynchronous_complete_cumsum (x )
75
- torch .testing .assert_close (
76
- torch .from_numpy (np .cumsum (x .cpu ().numpy ()).astype (np_index_dtype )),
77
- zi .cpu (),
78
- )
79
- torch .testing .assert_close (
80
- torch .from_numpy (
81
- (np .cumsum ([0 ] + x .cpu ().numpy ().tolist ())[:- 1 ]).astype (
82
- np_index_dtype
83
- )
84
- ),
85
- ze .cpu (),
86
- )
87
- torch .testing .assert_close (
88
- torch .from_numpy (
89
- (np .cumsum ([0 ] + x .cpu ().numpy ().tolist ())).astype (np_index_dtype )
90
- ),
91
- zc .cpu (),
92
- )
93
-
94
92
@given(
    n=st.integers(min_value=0, max_value=60),
    b=st.integers(min_value=0, max_value=10),
    index_types=st.sampled_from(
        [
            (torch.int64, np.int64),
            (torch.int32, np.int32),
            (torch.float32, np.float32),
        ]
    ),
    device=cpu_and_maybe_gpu(),
)
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
def test_asynchronous_complete_cumsum_2d(
    self,
    n: int,
    b: int,
    index_types: Tuple[Type[object], Type[object]],
    device: torch.device,
) -> None:
    """Check 2-D ``asynchronous_complete_cumsum`` against a NumPy reference.

    A random ``(b, n)`` tensor with values in ``[0, 100)`` is fed to
    ``torch.ops.fbgemm.asynchronous_complete_cumsum``; the result must match
    ``np.cumsum(..., axis=1)`` applied to the same data with one zero column
    prepended.

    Args:
        n: Number of columns of the input.
        b: Number of rows of the input.
        index_types: Matching ``(torch dtype, numpy dtype)`` pair used for
            the input tensor and for the reference result.
        device: Device the operator is run on.
    """
    pt_index_dtype, np_index_dtype = index_types

    # The CPU variants of asynchronous_*_cumsum support floats, since some
    # downstream tests appear to be relying on this behavior. As such, the
    # test is disabled for GPU + float test cases.
    # Compare on device.type so that indexed devices such as "cuda:0" are
    # recognised as GPU too (torch.device("cuda:0") != torch.device("cuda")).
    if device.type == "cuda" and pt_index_dtype is torch.float32:
        return

    # pyre-ignore-errors[16]
    x = torch.randint(low=0, high=100, size=(b, n)).type(pt_index_dtype).to(device)

    zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)

    # Build the reference on CPU. The zero column uses the input's dtype so
    # the concatenation does not promote integer inputs to float32.
    x_cpu = x.cpu()
    leading_zeros = torch.zeros(b, 1, dtype=x_cpu.dtype)
    expected = np.cumsum(
        torch.concat([leading_zeros, x_cpu], dim=1).numpy(), axis=1
    ).astype(np_index_dtype)

    torch.testing.assert_close(torch.from_numpy(expected), zc.cpu())
124
133
125
134
126
135
# Pass the finished CumSumTest class to extend_test_class, which is defined
# outside this excerpt. NOTE(review): presumably it applies repo-wide test
# adjustments to the class; confirm against its definition.
extend_test_class (CumSumTest )
0 commit comments