Skip to content

Commit 8d9d269

Browse files
q10 authored and facebook-github-bot committed
Replace AT_DISPATCH with FBGEMM_DISPATCH, pt 4 (#2385)
Summary:
Pull Request resolved: #2385
- Replace AT_DISPATCH with FBGEMM_DISPATCH, pt 4
Reviewed By: spcyppt
Differential Revision: D54501814
1 parent d3a6166 commit 8d9d269

File tree

3 files changed

+71
-62
lines changed

3 files changed

+71
-62
lines changed

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1183,7 +1183,7 @@ Tensor asynchronous_exclusive_cumsum_cpu(const Tensor& t_in) {
11831183

11841184
const auto t_in_contig = t_in.expect_contiguous();
11851185
auto output = native_empty_like(*t_in_contig);
1186-
FBGEMM_DISPATCH_INTEGRAL_TYPES(
1186+
FBGEMM_DISPATCH_ALL_TYPES(
11871187
t_in_contig->scalar_type(),
11881188
"asynchronous_exclusive_cumsum_cpu_kernel",
11891189
[&] {
@@ -1200,7 +1200,7 @@ Tensor asynchronous_inclusive_cumsum_cpu(const Tensor& t_in) {
12001200

12011201
const auto t_in_contig = t_in.expect_contiguous();
12021202
auto output = native_empty_like(*t_in_contig);
1203-
FBGEMM_DISPATCH_INTEGRAL_TYPES(
1203+
FBGEMM_DISPATCH_ALL_TYPES(
12041204
t_in_contig->scalar_type(),
12051205
"asynchronous_inclusive_cumsum_cpu_kernel",
12061206
[&] {

fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,7 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
210210
const dim3 blocks(cuda_calc_xblock_count(
211211
reordered_cat_ad_offsets.numel() - 1,
212212
NUM_WARPS)); // one warp per sample
213-
FBGEMM_DISPATCH_INTEGRAL_TYPES(
213+
FBGEMM_DISPATCH_ALL_TYPES(
214214
cat_ad_indices.scalar_type(), "narrow_broadcast_indices_kernel_1", [&] {
215215
AT_DISPATCH_INDEX_TYPES(
216216
cat_ad_offsets.scalar_type(),

fbgemm_gpu/test/sparse/cumsum_test.py

Lines changed: 68 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,10 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-strict
9-
109
# pyre-ignore-all-errors[56]
1110

1211
import unittest
12+
from typing import Tuple, Type
1313

1414
import hypothesis.strategies as st
1515
import numpy as np
@@ -20,27 +20,45 @@
2020

2121
if open_source:
2222
# pyre-ignore[21]
23-
from test_utils import gpu_available
23+
from test_utils import cpu_and_maybe_gpu, gpu_available
2424
else:
2525
import fbgemm_gpu.sparse_ops # noqa: F401, E402
26-
from fbgemm_gpu.test.test_utils import gpu_available
26+
from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu, gpu_available
2727

2828

2929
class CumSumTest(unittest.TestCase):
3030
@given(
3131
n=st.integers(min_value=0, max_value=10),
32-
long_index=st.booleans(),
32+
index_types=st.sampled_from(
33+
[
34+
(torch.int64, np.int64),
35+
(torch.int32, np.int32),
36+
(torch.float32, np.float32),
37+
]
38+
),
39+
device=cpu_and_maybe_gpu(),
3340
)
3441
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
35-
def test_cumsum(self, n: int, long_index: bool) -> None:
36-
index_dtype = torch.int64 if long_index else torch.int32
37-
np_index_dtype = np.int64 if long_index else np.int32
42+
def test_cumsum(
43+
self,
44+
n: int,
45+
index_types: Tuple[Type[object], Type[object]],
46+
device: torch.device,
47+
) -> None:
48+
(pt_index_dtype, np_index_dtype) = index_types
49+
50+
# The CPU variants of asynchronous_*_cumsum support floats, since some
51+
# downstream tests appear to be relying on this behavior. As such, the
52+
# test is disabled for GPU + float test cases.
53+
if device == torch.device("cuda") and pt_index_dtype is torch.float32:
54+
return
3855

39-
# cpu tests
40-
x = torch.randint(low=0, high=100, size=(n,)).type(index_dtype)
56+
# pyre-ignore-errors[16]
57+
x = torch.randint(low=0, high=100, size=(n,)).type(pt_index_dtype).to(device)
4158
ze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(x)
4259
zi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(x)
4360
zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
61+
4462
torch.testing.assert_close(
4563
torch.from_numpy(np.cumsum(x.cpu().numpy()).astype(np_index_dtype)),
4664
zi.cpu(),
@@ -59,68 +77,59 @@ def test_cumsum(self, n: int, long_index: bool) -> None:
5977
)
6078

6179
# meta tests
62-
mx = torch.randint(low=0, high=100, size=(n,)).type(index_dtype).to("meta")
80+
# pyre-ignore-errors[16]
81+
mx = torch.randint(low=0, high=100, size=(n,)).type(pt_index_dtype).to("meta")
82+
6383
mze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(mx)
6484
self.assertEqual(ze.size(), mze.size())
65-
# mzi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(mx)
66-
# self.assertEqual(zi.size(), mzi.size())
85+
86+
mzi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(mx)
87+
self.assertEqual(zi.size(), mzi.size())
88+
6789
mzc = torch.ops.fbgemm.asynchronous_complete_cumsum(mx)
6890
self.assertEqual(zc.size(), mzc.size())
6991

70-
if gpu_available:
71-
x = x.cuda()
72-
ze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(x)
73-
zi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(x)
74-
zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
75-
torch.testing.assert_close(
76-
torch.from_numpy(np.cumsum(x.cpu().numpy()).astype(np_index_dtype)),
77-
zi.cpu(),
78-
)
79-
torch.testing.assert_close(
80-
torch.from_numpy(
81-
(np.cumsum([0] + x.cpu().numpy().tolist())[:-1]).astype(
82-
np_index_dtype
83-
)
84-
),
85-
ze.cpu(),
86-
)
87-
torch.testing.assert_close(
88-
torch.from_numpy(
89-
(np.cumsum([0] + x.cpu().numpy().tolist())).astype(np_index_dtype)
90-
),
91-
zc.cpu(),
92-
)
93-
9492
@given(
9593
n=st.integers(min_value=0, max_value=60),
9694
b=st.integers(min_value=0, max_value=10),
97-
long_index=st.booleans(),
95+
index_types=st.sampled_from(
96+
[
97+
(torch.int64, np.int64),
98+
(torch.int32, np.int32),
99+
(torch.float32, np.float32),
100+
]
101+
),
102+
device=cpu_and_maybe_gpu(),
98103
)
99104
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
100105
def test_asynchronous_complete_cumsum_2d(
101-
self, n: int, b: int, long_index: bool
106+
self,
107+
n: int,
108+
b: int,
109+
index_types: Tuple[Type[object], Type[object]],
110+
device: torch.device,
102111
) -> None:
103-
index_dtype = torch.int64 if long_index else torch.int32
104-
105-
def test_asynchronous_complete_cumsum_2d_helper(x: torch.Tensor) -> None:
106-
np_index_dtype = np.int64 if long_index else np.int32
107-
zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
108-
zeros = torch.zeros(b, 1)
109-
torch.testing.assert_close(
110-
torch.from_numpy(
111-
np.cumsum(
112-
torch.concat([zeros, x.cpu()], dim=1).numpy(), axis=1
113-
).astype(np_index_dtype)
114-
),
115-
zc.cpu(),
116-
)
117-
118-
x = torch.randint(low=0, high=100, size=(b, n)).type(index_dtype)
119-
# cpu test
120-
test_asynchronous_complete_cumsum_2d_helper(x)
121-
if gpu_available:
122-
# gpu test
123-
test_asynchronous_complete_cumsum_2d_helper(x.cuda())
112+
(pt_index_dtype, np_index_dtype) = index_types
113+
114+
# The CPU variants of asynchronous_*_cumsum support floats, since some
115+
# downstream tests appear to be relying on this behavior. As such, the
116+
# test is disabled for GPU + float test cases.
117+
if device == torch.device("cuda") and pt_index_dtype is torch.float32:
118+
return
119+
120+
# pyre-ignore-errors[16]
121+
x = torch.randint(low=0, high=100, size=(b, n)).type(pt_index_dtype).to(device)
122+
123+
zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
124+
zeros = torch.zeros(b, 1)
125+
torch.testing.assert_close(
126+
torch.from_numpy(
127+
np.cumsum(torch.concat([zeros, x.cpu()], dim=1).numpy(), axis=1).astype(
128+
np_index_dtype
129+
)
130+
),
131+
zc.cpu(),
132+
)
124133

125134

126135
extend_test_class(CumSumTest)

0 commit comments

Comments
 (0)