  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
-
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
@@ -23,30 +22,32 @@ namespace fbgemm_gpu {
 
 namespace {
 
-template <int kBlkN, class DataType, class SmemLayout>
+template <int kBlkNOrM, class DataType, class IndexType, class SmemLayout>
 struct SharedStorage {
   static constexpr int kPipeMax = cute::size<0>(SmemLayout{});
   static constexpr int kTmaAlignment = 128;
   static constexpr int kMbarAlignemnt = 8;
 
-  cute::array_aligned<int32_t, kBlkN> index;
+  cute::array_aligned<IndexType, kBlkNOrM> index;
   cute::array_aligned<DataType, cute::cosize_v<SmemLayout>, kTmaAlignment> data;
 
   CUTE_ALIGNAS(kMbarAlignemnt) uint64_t tma_load_barrier[kPipeMax];
 };
 
 template <
+    bool IsGather,
     class ProblemShape,
     class TileShape,
     class DataType,
+    class IndexType,
     class SmemLayout,
     class TmaLoad,
     class TmaStore>
-__global__ static void gather_along_first_dim_kernel(
+__global__ static void gather_or_scatter_along_first_dim_kernel(
     ProblemShape problem_shape,
     TileShape tile_shape,
     CUTLASS_GRID_CONSTANT TmaLoad const tma_load_input,
-    const int32_t* index,
+    const IndexType* index,
     CUTLASS_GRID_CONSTANT TmaStore const tma_store_output) {
   // Input shape: A [M, K]
   // Output shape: B [N, K]
@@ -55,23 +56,24 @@ __global__ static void gather_along_first_dim_kernel(
   int K = cute::get<2>(problem_shape);
 
   static_assert(cute::is_static<TileShape>::value);
-  constexpr int kBlkN = cute::size<0>(tile_shape);
+  constexpr int kBlkNOrM = cute::size<0>(tile_shape);
   constexpr int kBlkK = cute::size<1>(tile_shape);
 
-  using SmemT = SharedStorage<kBlkN, DataType, SmemLayout>;
+  using SmemT = SharedStorage<kBlkNOrM, DataType, IndexType, SmemLayout>;
   constexpr int kPipeMax = SmemT::kPipeMax;
 
   extern __shared__ char smem_raw[];
   SmemT& smem = *reinterpret_cast<SmemT*>(smem_raw);
 
-  const int n_offset = blockIdx.x * kBlkN;
-  if (n_offset >= N) {
+  int indexing_dim = IsGather ? N : M;
+  const int n_or_m_offset = blockIdx.x * kBlkNOrM;
+  if (n_or_m_offset >= indexing_dim) {
     return;
   }
 
   // Straight-forward direct global read of indices.
-  if (threadIdx.x < kBlkN && n_offset + threadIdx.x < N) {
-    smem.index[threadIdx.x] = index[n_offset + threadIdx.x];
+  if (threadIdx.x < kBlkNOrM && n_or_m_offset + threadIdx.x < indexing_dim) {
+    smem.index[threadIdx.x] = index[n_or_m_offset + threadIdx.x];
   }
   __syncthreads();
@@ -85,17 +87,24 @@ __global__ static void gather_along_first_dim_kernel(
 
   constexpr int kTmaTransactionBytes = kBlkK * sizeof(DataType);
   const int kNumKTiles = ((K + kBlkK - 1) / kBlkK);
-  const int kNumNKTiles = kBlkN * kNumKTiles;
-  const int kNumIterations = kNumNKTiles + kPipeMax - 1;
+  const int kNumNOrMKTiles = kBlkNOrM * kNumKTiles;
+  const int kNumIterations = kNumNOrMKTiles + kPipeMax - 1;
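+  // The extra kPipeMax - 1 iterations drain the software pipeline: the store
+  // phase below runs kPipeMax - 1 iterations behind the load phase.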
 
   for (int iteration = 0; iteration < kNumIterations; ++iteration) {
     // Load.
-    if (iteration < kNumNKTiles) {
+    if (iteration < kNumNOrMKTiles) {
       int load_pipe = iteration % kPipeMax;
 
-      int n = iteration / kNumKTiles;
-      int k = iteration % kNumKTiles;
-      int m = smem.index[n];
+      int m, n, k;
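+      // Gather: the source row m is looked up through the index buffer.
+      // Scatter: source rows are read in order starting at n_or_m_offset.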
+      if constexpr (IsGather) {
+        n = iteration / kNumKTiles;
+        k = iteration % kNumKTiles;
+        m = smem.index[n];
+      } else {
+        m = iteration / kNumKTiles + n_or_m_offset;
+        k = iteration % kNumKTiles;
+        // n is not needed here
+      }
 
       cute::tma_store_wait<kPipeMax - 1>();
@@ -125,15 +134,23 @@ __global__ static void gather_along_first_dim_kernel(
       int processing_index = iteration - kPipeMax + 1;
       int store_pipe = processing_index % kPipeMax;
 
-      int n = processing_index / kNumKTiles;
-      int k = processing_index % kNumKTiles;
+      int m, n, k;
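+      // Gather: destination rows of B are written in order. Scatter: the
+      // destination row n is looked up through the index buffer.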
+      if constexpr (IsGather) {
+        n = processing_index / kNumKTiles + n_or_m_offset;
+        k = processing_index % kNumKTiles;
+        // m is not needed here
+      } else {
+        m = processing_index / kNumKTiles;
+        k = processing_index % kNumKTiles;
+        n = smem.index[m];
+      }
 
       cute::wait_barrier(smem.tma_load_barrier[store_pipe], 0);
 
       cute::Tensor tAgB = cute::local_tile(
           gB,
           cute::Tile<cute::_1, cute::Int<kBlkK>>{},
-          cute::make_coord(n + n_offset, k));
+          cute::make_coord(n, k));
       cute::Tensor tAsA = cute::local_tile(
           sA,
           cute::Tile<cute::_1, cute::Int<kBlkK>>{},
@@ -151,45 +168,60 @@ __global__ static void gather_along_first_dim_kernel(
   cute::tma_store_wait<0>();
 }
 
-} // namespace
+template <class T>
+struct TorchDTypeTrait {};
 
-// TODO(shikaili): Templatize it and make it supports more configurations.
-at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
-  using DataType = cutlass::bfloat16_t;
-  constexpr auto kDataTypeEnum = at::kBFloat16;
-  using IndexType = int32_t;
-  constexpr auto kIndexTypeEnum = at::kInt;
-  constexpr int kTmaGmemAlignment = 16;
+template <>
+struct TorchDTypeTrait<cutlass::bfloat16_t> {
+  static auto dtype() {
+    return at::kBFloat16;
+  };
+};
 
-  bool compatible = (data.dtype() == kDataTypeEnum && data.is_contiguous() &&
-                     data.dim() == 2) &&
-      (index.dtype() == kIndexTypeEnum && index.is_contiguous() &&
-       index.dim() == 1) &&
-      (data.size(1) * sizeof(DataType) % kTmaGmemAlignment == 0);
+template <>
+struct TorchDTypeTrait<int32_t> {
+  static auto dtype() {
+    return at::kInt;
+  };
+};
 
-  if (!compatible) {
-    return at::index_select(data, 0, index);
-  }
+template <>
+struct TorchDTypeTrait<int64_t> {
+  static auto dtype() {
+    return at::kLong;
+  };
+};
 
-  const int M = data.size(0);
-  const int K = data.size(1);
-  const int N = index.size(0);
+template <
+    bool IsGather,
+    class DataType,
+    class IndexType,
+    class TMAStoreInst = cute::SM90_TMA_STORE>
+void gather_or_scatter_along_first_dim(
+    at::Tensor src,
+    at::Tensor index,
+    at::Tensor dst) {
+  assert(src.dtype() == TorchDTypeTrait<DataType>::dtype());
+  assert(dst.dtype() == TorchDTypeTrait<DataType>::dtype());
+  assert(index.dtype() == TorchDTypeTrait<IndexType>::dtype());
+
+  const int M = src.size(0);
+  const int K = src.size(1);
+  const int N = dst.size(0);
 
   auto src_gmem_layout =
       cute::make_layout(cute::make_shape(M, K), cute::make_stride(K, 1));
   auto src_gmem_tensor = cute::make_tensor(
-      cute::make_gmem_ptr(reinterpret_cast<DataType*>(data.data_ptr())),
+      cute::make_gmem_ptr(reinterpret_cast<DataType*>(src.data_ptr())),
       src_gmem_layout);
 
-  at::Tensor output = at::empty(
-      {N, K}, at::TensorOptions().dtype(at::kBFloat16).device(data.device()));
   auto dst_gmem_layout =
       cute::make_layout(cute::make_shape(N, K), cute::make_stride(K, 1));
   auto dst_gmem_tensor = cute::make_tensor(
-      cute::make_gmem_ptr(reinterpret_cast<DataType*>(output.data_ptr())),
+      cute::make_gmem_ptr(reinterpret_cast<DataType*>(dst.data_ptr())),
       dst_gmem_layout);
 
-  constexpr int kBlkN = 1;
+  constexpr int kBlkNOrM = 1;
   constexpr int kBlkK = 256;
   constexpr int kPipeMax = 4;
@@ -199,16 +231,15 @@ at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
   auto tma_load = cute::make_tma_copy(
       cute::SM90_TMA_LOAD{}, src_gmem_tensor, smem_layout(0, cute::_, cute::_));
   auto tma_store = cute::make_tma_copy(
-      cute::SM90_TMA_STORE{},
-      dst_gmem_tensor,
-      smem_layout(0, cute::_, cute::_));
+      TMAStoreInst{}, dst_gmem_tensor, smem_layout(0, cute::_, cute::_));
 
   auto problem_shape = cute::make_shape(M, N, K);
-  auto tile_shape = cute::make_shape(cute::Int<kBlkN>{}, cute::Int<kBlkK>{});
+  auto tile_shape = cute::make_shape(cute::Int<kBlkNOrM>{}, cute::Int<kBlkK>{});
 
-  using SmemT = SharedStorage<kBlkN, DataType, decltype(smem_layout)>;
+  using SmemT =
+      SharedStorage<kBlkNOrM, DataType, IndexType, decltype(smem_layout)>;
 
-  int num_ctas = (N + kBlkN - 1) / kBlkN;
+  int num_ctas = ((IsGather ? N : M) + kBlkNOrM - 1) / kBlkNOrM;
   dim3 grid_dims(num_ctas, 1, 1);
   dim3 block_dims(32, 1, 1);
   dim3 cluster_dims(1, 1, 1);
@@ -217,10 +248,12 @@ at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
 
   cutlass::ClusterLaunchParams launch_params{
       grid_dims, block_dims, cluster_dims, smem_size, stream};
-  void* kernel = (void*)gather_along_first_dim_kernel<
+  void* kernel = (void*)gather_or_scatter_along_first_dim_kernel<
+      IsGather,
       decltype(problem_shape),
       decltype(tile_shape),
       DataType,
+      IndexType,
       decltype(smem_layout),
       decltype(tma_load),
       decltype(tma_store)>;
@@ -242,8 +275,78 @@ at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
     cudaError_t error = cudaGetLastError();
     CUTE_ERROR_EXIT(error);
   }
+}
+
+} // namespace
+
+at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
+  constexpr int kTmaGmemAlignment = 16;
+
+  if (data.is_contiguous() && data.dim() == 2 && index.is_contiguous() &&
+      index.dim() == 1) {
+    const int M = data.size(0);
+    const int K = data.size(1);
+    const int N = index.size(0);
+    // TODO(shikaili): Make it support more configurations.
+    if (data.dtype() == at::kBFloat16 &&
+        (K * sizeof(cutlass::bfloat16_t) % kTmaGmemAlignment == 0)) {
+      at::Tensor output = at::empty(
+          {N, K},
+          at::TensorOptions().dtype(at::kBFloat16).device(data.device()));
+      if (index.dtype() == at::kInt) {
+        gather_or_scatter_along_first_dim<
+            true,
+            cutlass::bfloat16_t,
+            int32_t,
+            cute::SM90_TMA_STORE>(data, index, output);
+        return output;
+      } else if (index.dtype() == at::kLong) {
+        gather_or_scatter_along_first_dim<
+            true,
+            cutlass::bfloat16_t,
+            int64_t,
+            cute::SM90_TMA_STORE>(data, index, output);
+        return output;
+      }
+    }
+  }
+  return at::index_select(data, 0, index);
+}
+
+void scatter_add_along_first_dim(
+    at::Tensor dst,
+    at::Tensor src,
+    at::Tensor index) {
+  constexpr int kTmaGmemAlignment = 16;
+
+  if (dst.is_contiguous() && dst.dim() == 2 && src.is_contiguous() &&
+      src.dim() == 2 && index.is_contiguous() && index.dim() == 1) {
+    const int M = src.size(0);
+    const int K = src.size(1);
+    const int N = index.size(0);
+    assert(dst.size(1) == K);
+    // TODO(shikaili): Make it support more configurations.
+    if (dst.dtype() == at::kBFloat16 && src.dtype() == at::kBFloat16 &&
+        (K * sizeof(cutlass::bfloat16_t) % kTmaGmemAlignment == 0)) {
+      if (index.dtype() == at::kInt) {
+        gather_or_scatter_along_first_dim<
+            false,
+            cutlass::bfloat16_t,
+            int32_t,
+            cute::SM90_TMA_REDUCE_ADD>(src, index, dst);
+        return;
+      } else if (index.dtype() == at::kLong) {
+        gather_or_scatter_along_first_dim<
+            false,
+            cutlass::bfloat16_t,
+            int64_t,
+            cute::SM90_TMA_REDUCE_ADD>(src, index, dst);
+        return; // Avoid also applying the scatter_add_ fallback below.
+      }
+    }
+  }
 
-  return output;
+  const int K = src.size(1);
+  dst.scatter_add_(0, index.to(at::kLong).unsqueeze(1).expand({-1, K}), src);
 }
 
 #endif
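
Reference semantics (a sketch under assumptions, not part of the commit): both
entry points compute the same result as the plain ATen fallbacks they already
contain; the TMA kernels only engage for contiguous 2-D bf16 tensors whose row
width is 16-byte aligned, with int32 or int64 indices. The helper names below
(gather_ref, scatter_add_ref) are hypothetical.

#include <ATen/ATen.h>

// gather_along_first_dim semantics: out[i, :] = data[index[i], :].
at::Tensor gather_ref(at::Tensor data, at::Tensor index) {
  return at::index_select(data, 0, index);
}

// scatter_add_along_first_dim semantics: dst[index[i], :] += src[i, :].
void scatter_add_ref(at::Tensor dst, at::Tensor src, at::Tensor index) {
  const int64_t K = src.size(1);
  dst.scatter_add_(0, index.to(at::kLong).unsqueeze(1).expand({-1, K}), src);
}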