
Commit d0e2dd1

levendlee authored and facebook-github-bot committed
torch.ops.fbgemm.scatter_add_along_first_dim. (#3720)
Summary:
Pull Request resolved: #3720
X-link: facebookresearch/FBGEMM#804

TMA-based scatter_add operation optimized for large shapes. Hyperparameters could be fine-tuned for better performance; however, the expected headroom is small.

Differential Revision: D69957147
1 parent 0ac79a9 · commit d0e2dd1
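Based on the schema this commit registers ("scatter_add_along_first_dim(Tensor Dst, Tensor Src, Tensor Index) -> ()"), a minimal usage sketch follows. The op mutates Dst in place, adding row i of Src into row Index[i] of Dst. The import path, shapes, and device setup are illustrative assumptions, not part of the commit.

# Hedged sketch: assumes fbgemm_gpu is built with the gen_ai experimental ops and
# that importing the package below registers torch.ops.fbgemm.*; adjust as needed.
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  (assumed import path that loads the ops)

# src is [M, K], dst is [N, K], index has one target row per src row (length M).
# For the TMA path, rows must stay 16-byte aligned (K % 8 == 0 for bf16);
# otherwise the op falls back to eager scatter_add_.
M, N, K = 4096, 1024, 512
src = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
dst = torch.zeros(N, K, dtype=torch.bfloat16, device="cuda")
index = torch.randint(0, N, (M,), dtype=torch.int32, device="cuda")

# In place: dst[index[i], :] += src[i, :]
torch.ops.fbgemm.scatter_add_along_first_dim(dst, src, index)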

3 files changed: +221 -53 lines changed

fbgemm_gpu/experimental/gen_ai/src/gather/gather.cpp renamed to fbgemm_gpu/experimental/gen_ai/src/gather_scatter/gather_scatter.cpp

Lines changed: 9 additions & 1 deletion
@@ -15,13 +15,21 @@ namespace fbgemm_gpu {
 
 at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index);
 
+void scatter_add_along_first_dim(
+    at::Tensor dst,
+    at::Tensor src,
+    at::Tensor index);
+
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  m.set_python_module("fbgemm_gpu.experimental.gen_ai.gather");
+  m.set_python_module("fbgemm_gpu.experimental.gen_ai.gather_scatter");
   m.def("gather_along_first_dim(Tensor Data, Tensor Index) -> Tensor");
+  m.def(
+      "scatter_add_along_first_dim(Tensor Dst, Tensor Src, Tensor Index) -> ()");
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("gather_along_first_dim", gather_along_first_dim);
+  m.impl("scatter_add_along_first_dim", scatter_add_along_first_dim);
 }
 
 #endif
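The registration above exposes the new op in the fbgemm namespace and points its Python module at fbgemm_gpu.experimental.gen_ai.gather_scatter. A hedged sketch of confirming the registration from Python (the import that loads the ops library is an assumption and may differ per build):

# Hedged check: confirm the new op is visible once the fbgemm_gpu ops are loaded.
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  (assumed entry point that loads the ops)

assert hasattr(torch.ops.fbgemm, "scatter_add_along_first_dim")
print(torch.ops.fbgemm.scatter_add_along_first_dim)  # OpOverloadPacket for the new op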

fbgemm_gpu/experimental/gen_ai/src/gather/gather.cu renamed to fbgemm_gpu/experimental/gen_ai/src/gather_scatter/gather_scatter.cu

Lines changed: 155 additions & 52 deletions
@@ -5,7 +5,6 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
-
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 
@@ -23,30 +22,32 @@ namespace fbgemm_gpu {
 
 namespace {
 
-template <int kBlkN, class DataType, class SmemLayout>
+template <int kBlkNOrM, class DataType, class IndexType, class SmemLayout>
 struct SharedStorage {
   static constexpr int kPipeMax = cute::size<0>(SmemLayout{});
   static constexpr int kTmaAlignment = 128;
   static constexpr int kMbarAlignemnt = 8;
 
-  cute::array_aligned<int32_t, kBlkN> index;
+  cute::array_aligned<IndexType, kBlkNOrM> index;
   cute::array_aligned<DataType, cute::cosize_v<SmemLayout>, kTmaAlignment> data;
 
   CUTE_ALIGNAS(kMbarAlignemnt) uint64_t tma_load_barrier[kPipeMax];
 };
 
 template <
+    bool IsGather,
     class ProblemShape,
     class TileShape,
     class DataType,
+    class IndexType,
     class SmemLayout,
     class TmaLoad,
     class TmaStore>
-__global__ static void gather_along_first_dim_kernel(
+__global__ static void gather_or_scatter_along_first_dim_kernel(
     ProblemShape problem_shape,
     TileShape tile_shape,
     CUTLASS_GRID_CONSTANT TmaLoad const tma_load_input,
-    const int32_t* index,
+    const IndexType* index,
     CUTLASS_GRID_CONSTANT TmaStore const tma_store_output) {
   // Input shape: A [M, K]
   // Output shape: B [N, K]
@@ -55,23 +56,24 @@ __global__ static void gather_along_first_dim_kernel(
   int K = cute::get<2>(problem_shape);
 
   static_assert(cute::is_static<TileShape>::value);
-  constexpr int kBlkN = cute::size<0>(tile_shape);
+  constexpr int kBlkNOrM = cute::size<0>(tile_shape);
   constexpr int kBlkK = cute::size<1>(tile_shape);
 
-  using SmemT = SharedStorage<kBlkN, DataType, SmemLayout>;
+  using SmemT = SharedStorage<kBlkNOrM, DataType, IndexType, SmemLayout>;
   constexpr int kPipeMax = SmemT::kPipeMax;
 
   extern __shared__ char smem_raw[];
   SmemT& smem = *reinterpret_cast<SmemT*>(smem_raw);
 
-  const int n_offset = blockIdx.x * kBlkN;
-  if (n_offset >= N) {
+  int indexing_dim = IsGather ? N : M;
+  const int n_or_m_offset = blockIdx.x * kBlkNOrM;
+  if (n_or_m_offset >= indexing_dim) {
     return;
   }
 
   // Straight-forward direct global read of indices.
-  if (threadIdx.x < kBlkN && n_offset + threadIdx.x < N) {
-    smem.index[threadIdx.x] = index[n_offset + threadIdx.x];
+  if (threadIdx.x < kBlkNOrM && n_or_m_offset + threadIdx.x < indexing_dim) {
+    smem.index[threadIdx.x] = index[n_or_m_offset + threadIdx.x];
   }
   __syncthreads();
 
@@ -85,17 +87,24 @@ __global__ static void gather_along_first_dim_kernel(
 
   constexpr int kTmaTransactionBytes = kBlkK * sizeof(DataType);
   const int kNumKTiles = ((K + kBlkK - 1) / kBlkK);
-  const int kNumNKTiles = kBlkN * kNumKTiles;
-  const int kNumIterations = kNumNKTiles + kPipeMax - 1;
+  const int kNumNOrMKTiles = kBlkNOrM * kNumKTiles;
+  const int kNumIterations = kNumNOrMKTiles + kPipeMax - 1;
 
   for (int iteration = 0; iteration < kNumIterations; ++iteration) {
     // Load.
-    if (iteration < kNumNKTiles) {
+    if (iteration < kNumNOrMKTiles) {
       int load_pipe = iteration % kPipeMax;
 
-      int n = iteration / kNumKTiles;
-      int k = iteration % kNumKTiles;
-      int m = smem.index[n];
+      int m, n, k;
+      if constexpr (IsGather) {
+        n = iteration / kNumKTiles;
+        k = iteration % kNumKTiles;
+        m = smem.index[n];
+      } else {
+        m = iteration / kNumKTiles + n_or_m_offset;
+        k = iteration % kNumKTiles;
+        // n is not needed here
+      }
 
       cute::tma_store_wait<kPipeMax - 1>();
 
@@ -125,15 +134,23 @@ __global__ static void gather_along_first_dim_kernel(
       int processing_index = iteration - kPipeMax + 1;
       int store_pipe = processing_index % kPipeMax;
 
-      int n = processing_index / kNumKTiles;
-      int k = processing_index % kNumKTiles;
+      int m, n, k;
+      if constexpr (IsGather) {
+        n = processing_index / kNumKTiles + n_or_m_offset;
+        k = processing_index % kNumKTiles;
+        // m is not needed here
+      } else {
+        m = processing_index / kNumKTiles;
+        k = processing_index % kNumKTiles;
+        n = smem.index[m];
+      }
 
       cute::wait_barrier(smem.tma_load_barrier[store_pipe], 0);
 
       cute::Tensor tAgB = cute::local_tile(
           gB,
           cute::Tile<cute::_1, cute::Int<kBlkK>>{},
-          cute::make_coord(n + n_offset, k));
+          cute::make_coord(n, k));
       cute::Tensor tAsA = cute::local_tile(
           sA,
           cute::Tile<cute::_1, cute::Int<kBlkK>>{},
@@ -151,45 +168,60 @@ __global__ static void gather_along_first_dim_kernel(
   cute::tma_store_wait<0>();
 }
 
-} // namespace
+template <class T>
+struct TorchDTypeTrait {};
 
-// TODO(shikaili): Templatize it and make it supports more configurations.
-at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
-  using DataType = cutlass::bfloat16_t;
-  constexpr auto kDataTypeEnum = at::kBFloat16;
-  using IndexType = int32_t;
-  constexpr auto kIndexTypeEnum = at::kInt;
-  constexpr int kTmaGmemAlignment = 16;
+template <>
+struct TorchDTypeTrait<cutlass::bfloat16_t> {
+  static auto dtype() {
+    return at::kBFloat16;
+  };
+};
 
-  bool compatible = (data.dtype() == kDataTypeEnum && data.is_contiguous() &&
-                     data.dim() == 2) &&
-      (index.dtype() == kIndexTypeEnum && index.is_contiguous() &&
-       index.dim() == 1) &&
-      (data.size(1) * sizeof(DataType) % kTmaGmemAlignment == 0);
+template <>
+struct TorchDTypeTrait<int32_t> {
+  static auto dtype() {
+    return at::kInt;
+  };
+};
 
-  if (!compatible) {
-    return at::index_select(data, 0, index);
-  }
+template <>
+struct TorchDTypeTrait<int64_t> {
+  static auto dtype() {
+    return at::kLong;
+  };
+};
 
-  const int M = data.size(0);
-  const int K = data.size(1);
-  const int N = index.size(0);
+template <
+    bool IsGather,
+    class DataType,
+    class IndexType,
+    class TMAStoreInst = cute::SM90_TMA_STORE>
+void gather_or_scatter_along_first_dim(
+    at::Tensor src,
+    at::Tensor index,
+    at::Tensor dst) {
+  assert(src.dtype() == TorchDTypeTrait<DataType>::dtype());
+  assert(dst.dtype() == TorchDTypeTrait<DataType>::dtype());
+  assert(index.dtype() == TorchDTypeTrait<IndexType>::dtype());
+
+  const int M = src.size(0);
+  const int K = src.size(1);
+  const int N = dst.size(0);
 
   auto src_gmem_layout =
       cute::make_layout(cute::make_shape(M, K), cute::make_stride(K, 1));
   auto src_gmem_tensor = cute::make_tensor(
-      cute::make_gmem_ptr(reinterpret_cast<DataType*>(data.data_ptr())),
+      cute::make_gmem_ptr(reinterpret_cast<DataType*>(src.data_ptr())),
       src_gmem_layout);
 
-  at::Tensor output = at::empty(
-      {N, K}, at::TensorOptions().dtype(at::kBFloat16).device(data.device()));
   auto dst_gmem_layout =
       cute::make_layout(cute::make_shape(N, K), cute::make_stride(K, 1));
   auto dst_gmem_tensor = cute::make_tensor(
-      cute::make_gmem_ptr(reinterpret_cast<DataType*>(output.data_ptr())),
+      cute::make_gmem_ptr(reinterpret_cast<DataType*>(dst.data_ptr())),
      dst_gmem_layout);
 
-  constexpr int kBlkN = 1;
+  constexpr int kBlkNOrM = 1;
   constexpr int kBlkK = 256;
   constexpr int kPipeMax = 4;
 
@@ -199,16 +231,15 @@ at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
   auto tma_load = cute::make_tma_copy(
       cute::SM90_TMA_LOAD{}, src_gmem_tensor, smem_layout(0, cute::_, cute::_));
   auto tma_store = cute::make_tma_copy(
-      cute::SM90_TMA_STORE{},
-      dst_gmem_tensor,
-      smem_layout(0, cute::_, cute::_));
+      TMAStoreInst{}, dst_gmem_tensor, smem_layout(0, cute::_, cute::_));
 
   auto problem_shape = cute::make_shape(M, N, K);
-  auto tile_shape = cute::make_shape(cute::Int<kBlkN>{}, cute::Int<kBlkK>{});
+  auto tile_shape = cute::make_shape(cute::Int<kBlkNOrM>{}, cute::Int<kBlkK>{});
 
-  using SmemT = SharedStorage<kBlkN, DataType, decltype(smem_layout)>;
+  using SmemT =
+      SharedStorage<kBlkNOrM, DataType, IndexType, decltype(smem_layout)>;
 
-  int num_ctas = (N + kBlkN - 1) / kBlkN;
+  int num_ctas = ((IsGather ? N : M) + kBlkNOrM - 1) / kBlkNOrM;
   dim3 grid_dims(num_ctas, 1, 1);
   dim3 block_dims(32, 1, 1);
   dim3 cluster_dims(1, 1, 1);
@@ -217,10 +248,12 @@ at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
 
   cutlass::ClusterLaunchParams launch_params{
       grid_dims, block_dims, cluster_dims, smem_size, stream};
-  void* kernel = (void*)gather_along_first_dim_kernel<
+  void* kernel = (void*)gather_or_scatter_along_first_dim_kernel<
+      IsGather,
       decltype(problem_shape),
       decltype(tile_shape),
       DataType,
+      IndexType,
       decltype(smem_layout),
       decltype(tma_load),
       decltype(tma_store)>;
@@ -242,8 +275,78 @@ at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
     cudaError_t error = cudaGetLastError();
     CUTE_ERROR_EXIT(error);
   }
+}
+
+} // namespace
+
+at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
+  constexpr int kTmaGmemAlignment = 16;
+
+  if (data.is_contiguous() && data.dim() == 2 && index.is_contiguous() &&
+      index.dim() == 1) {
+    const int M = data.size(0);
+    const int K = data.size(1);
+    const int N = index.size(0);
+    // TODO(shikaili): Make it supports more configurations.
+    if (data.dtype() == at::kBFloat16 &&
+        (K * sizeof(cutlass::bfloat16_t) % kTmaGmemAlignment == 0)) {
+      at::Tensor output = at::empty(
+          {N, K},
+          at::TensorOptions().dtype(at::kBFloat16).device(data.device()));
+      if (index.dtype() == at::kInt) {
+        gather_or_scatter_along_first_dim<
+            true,
+            cutlass::bfloat16_t,
+            int32_t,
+            cute::SM90_TMA_STORE>(data, index, output);
+        return output;
+      } else if (index.dtype() == at::kLong) {
+        gather_or_scatter_along_first_dim<
+            true,
+            cutlass::bfloat16_t,
+            int64_t,
+            cute::SM90_TMA_STORE>(data, index, output);
+        return output;
+      }
+    }
+  }
+  return at::index_select(data, 0, index);
+}
+
+void scatter_add_along_first_dim(
+    at::Tensor dst,
+    at::Tensor src,
+    at::Tensor index) {
+  constexpr int kTmaGmemAlignment = 16;
+
+  if (dst.is_contiguous() && dst.dim() == 2 && src.is_contiguous() &&
+      src.dim() == 2 && index.is_contiguous() && index.dim() == 1) {
+    const int M = src.size(0);
+    const int K = src.size(1);
+    const int N = index.size(0);
+    assert(dst.size(1) == K);
+    // TODO(shikaili): Make it supports more configurations.
+    if (dst.dtype() == at::kBFloat16 && src.dtype() == at::kBFloat16 &&
+        (K * sizeof(cutlass::bfloat16_t) % kTmaGmemAlignment == 0)) {
+      if (index.dtype() == at::kInt) {
+        gather_or_scatter_along_first_dim<
+            false,
+            cutlass::bfloat16_t,
+            int32_t,
+            cute::SM90_TMA_REDUCE_ADD>(src, index, dst);
+        return;
+      } else if (index.dtype() == at::kLong) {
+        gather_or_scatter_along_first_dim<
+            false,
+            cutlass::bfloat16_t,
+            int64_t,
+            cute::SM90_TMA_REDUCE_ADD>(src, index, dst);
+      }
+    }
+  }
 
-  return output;
+  const int K = src.size(1);
+  dst.scatter_add_(0, index.to(at::kLong).unsqueeze(1).expand({-1, K}), src);
 }
 
 #endif
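A hedged correctness-check sketch for the new op, mirroring the eager fallback used at the end of scatter_add_along_first_dim above (dst.scatter_add_ on a long index expanded along K). The shapes, tolerances, and harness names are illustrative and assume the op is loaded as in the earlier sketches.

# Hedged reference check mirroring the fallback path in scatter_add_along_first_dim.
import torch

def reference_scatter_add(dst, src, index):
    # Same semantics as the kernel: dst[index[i], :] += src[i, :]
    K = src.size(1)
    out = dst.clone()
    out.scatter_add_(0, index.long().unsqueeze(1).expand(-1, K), src)
    return out

M, N, K = 8192, 2048, 256
src = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
dst = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
index = torch.randint(0, N, (M,), dtype=torch.int64, device="cuda")

expected = reference_scatter_add(dst, src, index)
torch.ops.fbgemm.scatter_add_along_first_dim(dst, src, index)
# bf16 accumulation order differs between the two paths, so compare with loose tolerances.
torch.testing.assert_close(dst, expected, atol=1e-1, rtol=1e-2)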
