
Commit affd1b6

jwfromm and hwu36 authored
[EVT] Add support for Row/Col broadcast PtrArray (NVIDIA#2033)
* Add group support to EVT row/col broadcast.
* small modifications

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
1 parent 6f55278 commit affd1b6

File tree

1 file changed: +39 -14 lines changed


include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp

Lines changed: 39 additions & 14 deletions
@@ -972,14 +972,20 @@ compute_row_broadcast_stages() {
 template<
   int Stages,
   class CtaTileShapeMNK,
-  class ElementInput,
-  class ElementCompute = ElementInput,
+  class ElementInput_,
+  class ElementCompute = cute::remove_pointer_t<ElementInput_>,
   class StrideMNL_ = Stride<_0,_1,_0>,
-  int Alignment = 128 / sizeof_bits_v<ElementInput>,
+  int Alignment = 128 / sizeof_bits_v<cute::remove_pointer_t<ElementInput_>>,
   bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
 >
 struct Sm90RowBroadcast {
   using StrideMNL = StrideMNL_;
+  // Get base element input type.
+  using ElementInput = cute::remove_pointer_t<ElementInput_>;
+  // Check if input is an array of pointers.
+  static constexpr bool IsArrayOfPointers = is_same_v<ElementInput*, ElementInput_>;
+  using PtrRowType = cute::conditional_t<IsArrayOfPointers, ElementInput const* const*, ElementInput const*>;
+
   static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining");

   static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<1>(StrideMNL{}))>, bool>; // row vector or scalar broadcast
@@ -991,7 +997,7 @@ struct Sm90RowBroadcast {
   };

   struct Arguments {
-    ElementInput const* ptr_row = nullptr;
+    PtrRowType ptr_row = nullptr;
     ElementInput null_default = ElementInput(0);
     StrideMNL dRow = {};
   };
@@ -1036,7 +1042,7 @@ struct Sm90RowBroadcast {
       is_zero_ = params.null_default == ElementCompute(0);
     }
     // Dynamic non-batched scalar broadcast
-    else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) {
+    else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0) && !IsArrayOfPointers) {
       is_zero_ = params.ptr_row[0] == ElementInput(0);
     }
   }
@@ -1183,7 +1189,13 @@ struct Sm90RowBroadcast {

     auto layout_M = make_layout(M, repeat_like(M, _0{}));
     auto layout_L = make_layout(L, get<2>(params.dRow));
-    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_layout(layout_M,layout_N,layout_L));
+    ElementInput const* ptr_row;
+    if constexpr(IsArrayOfPointers) {
+      ptr_row = params.ptr_row[l];
+    } else {
+      ptr_row = params.ptr_row;
+    }
+    Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L));
     Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));  // (CTA_M, CTA_N)
     Tensor sRow = make_tensor(make_smem_ptr(smem),
       make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{}));  // (CTA_M, CTA_N)
@@ -1220,14 +1232,20 @@ struct Sm90RowBroadcast {
 template<
   int Stages,
   class CtaTileShapeMNK,
-  class ElementInput,
-  class ElementCompute = ElementInput,
+  class ElementInput_,
+  class ElementCompute = cute::remove_pointer_t<ElementInput_>,
   class StrideMNL_ = Stride<_1,_0,_0>,
-  int Alignment = 128 / sizeof_bits_v<ElementInput>,
+  int Alignment = 128 / sizeof_bits_v<cute::remove_pointer_t<ElementInput_>>,
   bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
 >
 struct Sm90ColBroadcast {
   using StrideMNL = StrideMNL_;
+  // Get base element input type.
+  using ElementInput = cute::remove_pointer_t<ElementInput_>;
+  // Check if input is an array of pointers.
+  static constexpr bool IsArrayOfPointers = is_same_v<ElementInput*, ElementInput_>;
+  using PtrColType = cute::conditional_t<IsArrayOfPointers, ElementInput const* const*, ElementInput const*>;
+
   static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining");

   static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<0>(StrideMNL{}))>, bool>; // Column vector or scalar broadcast
@@ -1238,13 +1256,13 @@ struct Sm90ColBroadcast {
   struct SharedStorage { };

   struct Arguments {
-    ElementInput const* ptr_col = nullptr;
+    PtrColType ptr_col = nullptr;
     ElementInput null_default = ElementInput(0);
     StrideMNL dCol = {};
   };

   struct Params {
-    ElementInput const* ptr_col = nullptr;
+    PtrColType ptr_col = nullptr;
     ElementCompute null_default = ElementCompute(0);
     StrideMNL dCol = {};
   };
@@ -1301,7 +1319,7 @@ struct Sm90ColBroadcast {
       is_zero_ = params.null_default == ElementCompute(0);
     }
     // Dynamic non-batched scalar broadcast
-    else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) {
+    else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0) && !IsArrayOfPointers) {
       is_zero_ = params.ptr_col[0] == ElementInput(0);
     }
   }
@@ -1398,6 +1416,7 @@ struct Sm90ColBroadcast {
   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

     auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
     auto layout_M = [&] () CUTLASS_LAMBDA_FUNC_INLINE {
       auto shape_M = get<0>(args.problem_shape_mnkl);
       if constexpr (IsDynamicBroadcast) {
@@ -1416,11 +1435,17 @@ struct Sm90ColBroadcast {

     auto layout_N = make_layout(N, repeat_like(N, _0{}));
     auto layout_L = make_layout(L, get<2>(params.dCol));
-    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(layout_M,layout_N,layout_L));
+    ElementInput const* ptr_col;
+    if constexpr(IsArrayOfPointers) {
+      ptr_col = params.ptr_col[l];
+    } else {
+      ptr_col = params.ptr_col;
+    }
+    Tensor mCol = make_tensor(make_gmem_ptr(ptr_col), make_layout(layout_M,layout_N,layout_L));
     Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
       mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);

-    Tensor mCol_static = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(make_layout(M),layout_N,layout_L));
+    Tensor mCol_static = make_tensor(make_gmem_ptr(ptr_col), make_layout(make_layout(M),layout_N,layout_L));
     Tensor tCgCol_static = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
       mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
     Tensor tCrCol = make_tensor_like<ElementCompute>(tCgCol_static);  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
