
Commit 5914cd3

q10 authored and facebook-github-bot committed
Replace AT_DISPATCH with FBGEMM_DISPATCH, pt 4 (#2385)
Summary: Pull Request resolved: #2385 - Replace AT_DISPATCH with FBGEMM_DISPATCH, pt 4
Reviewed By: spcyppt
Differential Revision: D54501814
1 parent d3a6166 · commit 5914cd3

3 files changed: +72 −62 lines changed


fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 2 additions & 2 deletions
@@ -1183,7 +1183,7 @@ Tensor asynchronous_exclusive_cumsum_cpu(const Tensor& t_in) {
 
   const auto t_in_contig = t_in.expect_contiguous();
   auto output = native_empty_like(*t_in_contig);
-  FBGEMM_DISPATCH_INTEGRAL_TYPES(
+  FBGEMM_DISPATCH_ALL_TYPES(
       t_in_contig->scalar_type(),
       "asynchronous_exclusive_cumsum_cpu_kernel",
       [&] {
@@ -1200,7 +1200,7 @@ Tensor asynchronous_inclusive_cumsum_cpu(const Tensor& t_in) {
 
   const auto t_in_contig = t_in.expect_contiguous();
   auto output = native_empty_like(*t_in_contig);
-  FBGEMM_DISPATCH_INTEGRAL_TYPES(
+  FBGEMM_DISPATCH_ALL_TYPES(
       t_in_contig->scalar_type(),
       "asynchronous_inclusive_cumsum_cpu_kernel",
       [&] {
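With the switch from FBGEMM_DISPATCH_INTEGRAL_TYPES to FBGEMM_DISPATCH_ALL_TYPES, the CPU cumsum operators now dispatch on floating-point inputs as well as integral ones, which is the behavior the updated test below exercises. A minimal sketch of what this enables, assuming fbgemm_gpu is installed so the torch.ops.fbgemm operators are registered; the input values are hypothetical:

import torch
import fbgemm_gpu  # noqa: F401  # importing registers the torch.ops.fbgemm operators

# A float32 input now dispatches on CPU instead of failing the dtype dispatch.
x = torch.tensor([1.5, 2.0, 3.5], dtype=torch.float32)

ze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(x)  # tensor([0.0, 1.5, 3.5])
zi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(x)  # tensor([1.5, 3.5, 7.0])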

fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
   const dim3 blocks(cuda_calc_xblock_count(
       reordered_cat_ad_offsets.numel() - 1,
       NUM_WARPS)); // one warp per sample
-  FBGEMM_DISPATCH_INTEGRAL_TYPES(
+  FBGEMM_DISPATCH_ALL_TYPES(
       cat_ad_indices.scalar_type(), "narrow_broadcast_indices_kernel_1", [&] {
         AT_DISPATCH_INDEX_TYPES(
             cat_ad_offsets.scalar_type(),

fbgemm_gpu/test/sparse/cumsum_test.py

Lines changed: 69 additions & 59 deletions
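The rewritten test below draws its device from a cpu_and_maybe_gpu() helper imported from test_utils. The helper's implementation is not part of this diff; a hypothetical sketch of such a Hypothesis strategy, assuming it simply samples from the devices available on the test machine:

import hypothesis.strategies as st
import torch

def cpu_and_maybe_gpu() -> st.SearchStrategy:
    # Hypothetical sketch: always offer the CPU device, and add CUDA only
    # when a GPU is actually present.
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices.append(torch.device("cuda"))
    return st.sampled_from(devices)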
@@ -7,9 +7,10 @@
 
 # pyre-strict
 
-# pyre-ignore-all-errors[56]
+# pyre-ignore-all-errors[53,56]
 
 import unittest
+from typing import Tuple, Type
 
 import hypothesis.strategies as st
 import numpy as np
@@ -20,27 +21,45 @@
 
 if open_source:
     # pyre-ignore[21]
-    from test_utils import gpu_available
+    from test_utils import cpu_and_maybe_gpu, gpu_available
 else:
     import fbgemm_gpu.sparse_ops  # noqa: F401, E402
-    from fbgemm_gpu.test.test_utils import gpu_available
+    from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu, gpu_available
 
 
 class CumSumTest(unittest.TestCase):
     @given(
         n=st.integers(min_value=0, max_value=10),
-        long_index=st.booleans(),
+        index_types=st.sampled_from(
+            [
+                (torch.int64, np.int64),
+                (torch.int32, np.int32),
+                (torch.float32, np.float32),
+            ]
+        ),
+        device=cpu_and_maybe_gpu(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
-    def test_cumsum(self, n: int, long_index: bool) -> None:
-        index_dtype = torch.int64 if long_index else torch.int32
-        np_index_dtype = np.int64 if long_index else np.int32
+    def test_cumsum(
+        self,
+        n: int,
+        index_types: Tuple[Type[object], Type[object]],
+        device: torch.device,
+    ) -> None:
+        (pt_index_dtype, np_index_dtype) = index_types
+
+        # The CPU variants of asynchronous_*_cumsum support floats, since some
+        # downstream tests appear to be relying on this behavior. As such, the
+        # test is disabled for GPU + float test cases.
+        if device == torch.device("cuda") and pt_index_dtype is torch.float32:
+            return
 
-        # cpu tests
-        x = torch.randint(low=0, high=100, size=(n,)).type(index_dtype)
+        # pyre-ignore-errors[16]
+        x = torch.randint(low=0, high=100, size=(n,)).type(pt_index_dtype).to(device)
         ze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(x)
         zi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(x)
         zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
+
         torch.testing.assert_close(
             torch.from_numpy(np.cumsum(x.cpu().numpy()).astype(np_index_dtype)),
             zi.cpu(),
@@ -59,68 +78,59 @@ def test_cumsum(self, n: int, long_index: bool) -> None:
         )
 
         # meta tests
-        mx = torch.randint(low=0, high=100, size=(n,)).type(index_dtype).to("meta")
+        # pyre-ignore-errors[16]
+        mx = torch.randint(low=0, high=100, size=(n,)).type(pt_index_dtype).to("meta")
+
         mze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(mx)
         self.assertEqual(ze.size(), mze.size())
-        # mzi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(mx)
-        # self.assertEqual(zi.size(), mzi.size())
+
+        mzi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(mx)
+        self.assertEqual(zi.size(), mzi.size())
+
         mzc = torch.ops.fbgemm.asynchronous_complete_cumsum(mx)
         self.assertEqual(zc.size(), mzc.size())
 
-        if gpu_available:
-            x = x.cuda()
-            ze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(x)
-            zi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(x)
-            zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
-            torch.testing.assert_close(
-                torch.from_numpy(np.cumsum(x.cpu().numpy()).astype(np_index_dtype)),
-                zi.cpu(),
-            )
-            torch.testing.assert_close(
-                torch.from_numpy(
-                    (np.cumsum([0] + x.cpu().numpy().tolist())[:-1]).astype(
-                        np_index_dtype
-                    )
-                ),
-                ze.cpu(),
-            )
-            torch.testing.assert_close(
-                torch.from_numpy(
-                    (np.cumsum([0] + x.cpu().numpy().tolist())).astype(np_index_dtype)
-                ),
-                zc.cpu(),
-            )
-
     @given(
         n=st.integers(min_value=0, max_value=60),
         b=st.integers(min_value=0, max_value=10),
-        long_index=st.booleans(),
+        index_types=st.sampled_from(
+            [
+                (torch.int64, np.int64),
+                (torch.int32, np.int32),
+                (torch.float32, np.float32),
+            ]
+        ),
+        device=cpu_and_maybe_gpu(),
    )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
     def test_asynchronous_complete_cumsum_2d(
-        self, n: int, b: int, long_index: bool
+        self,
+        n: int,
+        b: int,
+        index_types: Tuple[Type[object], Type[object]],
+        device: torch.device,
     ) -> None:
-        index_dtype = torch.int64 if long_index else torch.int32
-
-        def test_asynchronous_complete_cumsum_2d_helper(x: torch.Tensor) -> None:
-            np_index_dtype = np.int64 if long_index else np.int32
-            zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
-            zeros = torch.zeros(b, 1)
-            torch.testing.assert_close(
-                torch.from_numpy(
-                    np.cumsum(
-                        torch.concat([zeros, x.cpu()], dim=1).numpy(), axis=1
-                    ).astype(np_index_dtype)
-                ),
-                zc.cpu(),
-            )
-
-        x = torch.randint(low=0, high=100, size=(b, n)).type(index_dtype)
-        # cpu test
-        test_asynchronous_complete_cumsum_2d_helper(x)
-        if gpu_available:
-            # gpu test
-            test_asynchronous_complete_cumsum_2d_helper(x.cuda())
+        (pt_index_dtype, np_index_dtype) = index_types
+
+        # The CPU variants of asynchronous_*_cumsum support floats, since some
+        # downstream tests appear to be relying on this behavior. As such, the
+        # test is disabled for GPU + float test cases.
+        if device == torch.device("cuda") and pt_index_dtype is torch.float32:
+            return
+
+        # pyre-ignore-errors[16]
+        x = torch.randint(low=0, high=100, size=(b, n)).type(pt_index_dtype).to(device)
+
+        zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
+        zeros = torch.zeros(b, 1)
+        torch.testing.assert_close(
+            torch.from_numpy(
+                np.cumsum(torch.concat([zeros, x.cpu()], dim=1).numpy(), axis=1).astype(
+                    np_index_dtype
+                )
+            ),
+            zc.cpu(),
+        )
 
 
 extend_test_class(CumSumTest)
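For reference, the three cumsum variants asserted in this test differ only in the leading zero and whether the grand total is kept. A small NumPy illustration with hypothetical values, mirroring the expressions used in the assertions above:

import numpy as np

x = np.array([3, 1, 4], dtype=np.int64)

inclusive = np.cumsum(x)                      # [3, 4, 8]
exclusive = np.cumsum([0] + x.tolist())[:-1]  # [0, 3, 4]
complete = np.cumsum([0] + x.tolist())        # [0, 3, 4, 8], one element longer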
