7
7
8
8
# pyre-strict
9
9
10
- # pyre-ignore-all-errors[56]
10
+ # pyre-ignore-all-errors[53, 56]
11
11
12
12
import unittest
13
+ from typing import Tuple , Type
13
14
14
15
import hypothesis .strategies as st
15
16
import numpy as np
20
21
21
22
if open_source :
22
23
# pyre-ignore[21]
23
- from test_utils import gpu_available
24
+ from test_utils import cpu_and_maybe_gpu , gpu_available
24
25
else :
25
26
import fbgemm_gpu .sparse_ops # noqa: F401, E402
26
- from fbgemm_gpu .test .test_utils import gpu_available
27
+ from fbgemm_gpu .test .test_utils import cpu_and_maybe_gpu , gpu_available
27
28
28
29
29
30
class CumSumTest (unittest .TestCase ):
30
31
@given (
31
32
n = st .integers (min_value = 0 , max_value = 10 ),
32
- long_index = st .booleans (),
33
+ index_types = st .sampled_from (
34
+ [
35
+ (torch .int64 , np .int64 ),
36
+ (torch .int32 , np .int32 ),
37
+ (torch .float32 , np .float32 ),
38
+ ]
39
+ ),
40
+ device = cpu_and_maybe_gpu (),
33
41
)
34
42
@settings (verbosity = Verbosity .verbose , max_examples = 20 , deadline = None )
35
- def test_cumsum (self , n : int , long_index : bool ) -> None :
36
- index_dtype = torch .int64 if long_index else torch .int32
37
- np_index_dtype = np .int64 if long_index else np .int32
43
+ def test_cumsum (
44
+ self ,
45
+ n : int ,
46
+ index_types : Tuple [Type [object ], Type [object ]],
47
+ device : torch .device ,
48
+ ) -> None :
49
+ (pt_index_dtype , np_index_dtype ) = index_types
50
+
51
+ # The CPU variants of asynchronous_*_cumsum support floats, since some
52
+ # downstream tests appear to be relying on this behavior. As such, the
53
+ # test is disabled for GPU + float test cases.
54
+ if device == torch .device ("cuda" ) and pt_index_dtype is torch .float32 :
55
+ return
38
56
39
- # cpu tests
40
- x = torch .randint (low = 0 , high = 100 , size = (n ,)).type (index_dtype )
57
+ # pyre-ignore-errors[16]
58
+ x = torch .randint (low = 0 , high = 100 , size = (n ,)).type (pt_index_dtype ). to ( device )
41
59
ze = torch .ops .fbgemm .asynchronous_exclusive_cumsum (x )
42
60
zi = torch .ops .fbgemm .asynchronous_inclusive_cumsum (x )
43
61
zc = torch .ops .fbgemm .asynchronous_complete_cumsum (x )
62
+
44
63
torch .testing .assert_close (
45
64
torch .from_numpy (np .cumsum (x .cpu ().numpy ()).astype (np_index_dtype )),
46
65
zi .cpu (),
@@ -59,68 +78,59 @@ def test_cumsum(self, n: int, long_index: bool) -> None:
59
78
)
60
79
61
80
# meta tests
62
- mx = torch .randint (low = 0 , high = 100 , size = (n ,)).type (index_dtype ).to ("meta" )
81
+ # pyre-ignore-errors[16]
82
+ mx = torch .randint (low = 0 , high = 100 , size = (n ,)).type (pt_index_dtype ).to ("meta" )
83
+
63
84
mze = torch .ops .fbgemm .asynchronous_exclusive_cumsum (mx )
64
85
self .assertEqual (ze .size (), mze .size ())
65
- # mzi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(mx)
66
- # self.assertEqual(zi.size(), mzi.size())
86
+
87
+ mzi = torch .ops .fbgemm .asynchronous_inclusive_cumsum (mx )
88
+ self .assertEqual (zi .size (), mzi .size ())
89
+
67
90
mzc = torch .ops .fbgemm .asynchronous_complete_cumsum (mx )
68
91
self .assertEqual (zc .size (), mzc .size ())
69
92
70
- if gpu_available :
71
- x = x .cuda ()
72
- ze = torch .ops .fbgemm .asynchronous_exclusive_cumsum (x )
73
- zi = torch .ops .fbgemm .asynchronous_inclusive_cumsum (x )
74
- zc = torch .ops .fbgemm .asynchronous_complete_cumsum (x )
75
- torch .testing .assert_close (
76
- torch .from_numpy (np .cumsum (x .cpu ().numpy ()).astype (np_index_dtype )),
77
- zi .cpu (),
78
- )
79
- torch .testing .assert_close (
80
- torch .from_numpy (
81
- (np .cumsum ([0 ] + x .cpu ().numpy ().tolist ())[:- 1 ]).astype (
82
- np_index_dtype
83
- )
84
- ),
85
- ze .cpu (),
86
- )
87
- torch .testing .assert_close (
88
- torch .from_numpy (
89
- (np .cumsum ([0 ] + x .cpu ().numpy ().tolist ())).astype (np_index_dtype )
90
- ),
91
- zc .cpu (),
92
- )
93
-
94
93
@given (
95
94
n = st .integers (min_value = 0 , max_value = 60 ),
96
95
b = st .integers (min_value = 0 , max_value = 10 ),
97
- long_index = st .booleans (),
96
+ index_types = st .sampled_from (
97
+ [
98
+ (torch .int64 , np .int64 ),
99
+ (torch .int32 , np .int32 ),
100
+ (torch .float32 , np .float32 ),
101
+ ]
102
+ ),
103
+ device = cpu_and_maybe_gpu (),
98
104
)
99
105
@settings (verbosity = Verbosity .verbose , max_examples = 20 , deadline = None )
100
106
def test_asynchronous_complete_cumsum_2d (
101
- self , n : int , b : int , long_index : bool
107
+ self ,
108
+ n : int ,
109
+ b : int ,
110
+ index_types : Tuple [Type [object ], Type [object ]],
111
+ device : torch .device ,
102
112
) -> None :
103
- index_dtype = torch . int64 if long_index else torch . int32
104
-
105
- def test_asynchronous_complete_cumsum_2d_helper ( x : torch . Tensor ) -> None :
106
- np_index_dtype = np . int64 if long_index else np . int32
107
- zc = torch . ops . fbgemm . asynchronous_complete_cumsum ( x )
108
- zeros = torch .zeros ( b , 1 )
109
- torch . testing . assert_close (
110
- torch . from_numpy (
111
- np . cumsum (
112
- torch .concat ([ zeros , x . cpu ()], dim = 1 ). numpy (), axis = 1
113
- ). astype ( np_index_dtype )
114
- ),
115
- zc . cpu (),
116
- )
117
-
118
- x = torch .randint ( low = 0 , high = 100 , size = ( b , n )). type ( index_dtype )
119
- # cpu test
120
- test_asynchronous_complete_cumsum_2d_helper ( x )
121
- if gpu_available :
122
- # gpu test
123
- test_asynchronous_complete_cumsum_2d_helper ( x . cuda () )
113
+ ( pt_index_dtype , np_index_dtype ) = index_types
114
+
115
+ # The CPU variants of asynchronous_*_cumsum support floats, since some
116
+ # downstream tests appear to be relying on this behavior. As such, the
117
+ # test is disabled for GPU + float test cases.
118
+ if device == torch .device ( "cuda" ) and pt_index_dtype is torch . float32 :
119
+ return
120
+
121
+ # pyre-ignore-errors[16]
122
+ x = torch .randint ( low = 0 , high = 100 , size = ( b , n )). type ( pt_index_dtype ). to ( device )
123
+
124
+ zc = torch . ops . fbgemm . asynchronous_complete_cumsum ( x )
125
+ zeros = torch . zeros ( b , 1 )
126
+ torch . testing . assert_close (
127
+ torch . from_numpy (
128
+ np . cumsum ( torch .concat ([ zeros , x . cpu ()], dim = 1 ). numpy (), axis = 1 ). astype (
129
+ np_index_dtype
130
+ )
131
+ ),
132
+ zc . cpu (),
133
+ )
124
134
125
135
126
136
# NOTE(review): presumably registers additional device/backend variants of
# CumSumTest via the shared fbgemm_gpu test utilities — confirm against the
# extend_test_class definition; nothing in this file shows its behavior.
extend_test_class (CumSumTest )
0 commit comments