@@ -972,14 +972,20 @@ compute_row_broadcast_stages() {
972972template <
973973 int Stages,
974974 class CtaTileShapeMNK ,
975- class ElementInput ,
976- class ElementCompute = ElementInput ,
975+ class ElementInput_ ,
976+ class ElementCompute = cute:: remove_pointer_t <ElementInput_> ,
977977 class StrideMNL_ = Stride<_0,_1,_0>,
978- int Alignment = 128 / sizeof_bits_v<ElementInput >,
978+ int Alignment = 128 / sizeof_bits_v<cute:: remove_pointer_t <ElementInput_> >,
979979 bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
980980>
981981struct Sm90RowBroadcast {
982982 using StrideMNL = StrideMNL_;
983+ // Get base element input type.
984+ using ElementInput = cute::remove_pointer_t <ElementInput_>;
985+ // Check if input is an array of pointers.
986+ static constexpr bool IsArrayOfPointers = is_same_v<ElementInput*, ElementInput_>;
987+ using PtrRowType = cute::conditional_t <IsArrayOfPointers, ElementInput const * const *, ElementInput const *>;
988+
983989 static_assert (Stages == 0 , " Row broadcast doesn't support smem pipelining" );
984990
985991 static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t <decltype (get<1 >(StrideMNL{}))>, bool >; // row vector or scalar broadcast
@@ -991,7 +997,7 @@ struct Sm90RowBroadcast {
991997 };
992998
993999 struct Arguments {
994- ElementInput const * ptr_row = nullptr ;
1000+ PtrRowType ptr_row = nullptr ;
9951001 ElementInput null_default = ElementInput(0 );
9961002 StrideMNL dRow = {};
9971003 };
@@ -1036,7 +1042,7 @@ struct Sm90RowBroadcast {
10361042 is_zero_ = params.null_default == ElementCompute (0 );
10371043 }
10381044 // Dynamic non-batched scalar broadcast
1039- else if (IsDynamicBroadcast && stride_N == bool (0 ) && stride_L == repeat_like (stride_L, 0 )) {
1045+ else if (IsDynamicBroadcast && stride_N == bool (0 ) && stride_L == repeat_like (stride_L, 0 ) && !IsArrayOfPointers ) {
10401046 is_zero_ = params.ptr_row [0 ] == ElementInput (0 );
10411047 }
10421048 }
@@ -1183,7 +1189,13 @@ struct Sm90RowBroadcast {
11831189
11841190 auto layout_M = make_layout (M, repeat_like (M, _0{}));
11851191 auto layout_L = make_layout (L, get<2 >(params.dRow ));
1186- Tensor mRow = make_tensor (make_gmem_ptr (params.ptr_row ), make_layout (layout_M,layout_N,layout_L));
1192+ ElementInput const * ptr_row;
1193+ if constexpr (IsArrayOfPointers) {
1194+ ptr_row = params.ptr_row [l];
1195+ } else {
1196+ ptr_row = params.ptr_row ;
1197+ }
1198+ Tensor mRow = make_tensor (make_gmem_ptr (ptr_row), make_layout (layout_M,layout_N,layout_L));
11871199 Tensor gRow = local_tile (mRow (_,_,l), take<0 ,2 >(args.tile_shape_mnk ), make_coord (m, n)); // (CTA_M, CTA_N)
11881200 Tensor sRow = make_tensor (make_smem_ptr (smem),
11891201 make_shape (size<0 >(CtaTileShapeMNK{}), size<1 >(CtaTileShapeMNK{})), make_shape (_0{}, _1{})); // (CTA_M, CTA_N)
@@ -1220,14 +1232,20 @@ struct Sm90RowBroadcast {
12201232template <
12211233 int Stages,
12221234 class CtaTileShapeMNK ,
1223- class ElementInput ,
1224- class ElementCompute = ElementInput ,
1235+ class ElementInput_ ,
1236+ class ElementCompute = cute:: remove_pointer_t <ElementInput_> ,
12251237 class StrideMNL_ = Stride<_1,_0,_0>,
1226- int Alignment = 128 / sizeof_bits_v<ElementInput >,
1238+ int Alignment = 128 / sizeof_bits_v<cute:: remove_pointer_t <ElementInput_> >,
12271239 bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
12281240>
12291241struct Sm90ColBroadcast {
12301242 using StrideMNL = StrideMNL_;
1243+ // Get base element input type.
1244+ using ElementInput = cute::remove_pointer_t <ElementInput_>;
1245+ // Check if input is an array of pointers.
1246+ static constexpr bool IsArrayOfPointers = is_same_v<ElementInput*, ElementInput_>;
1247+ using PtrColType = cute::conditional_t <IsArrayOfPointers, ElementInput const * const *, ElementInput const *>;
1248+
12311249 static_assert (Stages == 0 , " Column broadcast doesn't support smem pipelining" );
12321250
12331251 static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t <decltype (get<0 >(StrideMNL{}))>, bool >; // Column vector or scalar broadcast
@@ -1238,13 +1256,13 @@ struct Sm90ColBroadcast {
12381256 struct SharedStorage { };
12391257
12401258 struct Arguments {
1241- ElementInput const * ptr_col = nullptr ;
1259+ PtrColType ptr_col = nullptr ;
12421260 ElementInput null_default = ElementInput(0 );
12431261 StrideMNL dCol = {};
12441262 };
12451263
12461264 struct Params {
1247- ElementInput const * ptr_col = nullptr ;
1265+ PtrColType ptr_col = nullptr ;
12481266 ElementCompute null_default = ElementCompute(0 );
12491267 StrideMNL dCol = {};
12501268 };
@@ -1301,7 +1319,7 @@ struct Sm90ColBroadcast {
13011319 is_zero_ = params.null_default == ElementCompute (0 );
13021320 }
13031321 // Dynamic non-batched scalar broadcast
1304- else if (IsDynamicBroadcast && stride_M == bool (0 ) && stride_L == repeat_like (stride_L, 0 )) {
1322+ else if (IsDynamicBroadcast && stride_M == bool (0 ) && stride_L == repeat_like (stride_L, 0 ) && !IsArrayOfPointers ) {
13051323 is_zero_ = params.ptr_col [0 ] == ElementInput (0 );
13061324 }
13071325 }
@@ -1398,6 +1416,7 @@ struct Sm90ColBroadcast {
13981416 get_consumer_store_callbacks (ConsumerStoreArgs<Args...> const & args) {
13991417
14001418 auto [M, N, K, L] = args.problem_shape_mnkl ;
1419+ auto [m, n, k, l] = args.tile_coord_mnkl ;
14011420 auto layout_M = [&] () CUTLASS_LAMBDA_FUNC_INLINE {
14021421 auto shape_M = get<0 >(args.problem_shape_mnkl );
14031422 if constexpr (IsDynamicBroadcast) {
@@ -1416,11 +1435,17 @@ struct Sm90ColBroadcast {
14161435
14171436 auto layout_N = make_layout (N, repeat_like (N, _0{}));
14181437 auto layout_L = make_layout (L, get<2 >(params.dCol ));
1419- Tensor mCol = make_tensor (make_gmem_ptr (params.ptr_col ), make_layout (layout_M,layout_N,layout_L));
1438+ ElementInput const * ptr_col;
1439+ if constexpr (IsArrayOfPointers) {
1440+ ptr_col = params.ptr_col [l];
1441+ } else {
1442+ ptr_col = params.ptr_col ;
1443+ }
1444+ Tensor mCol = make_tensor (make_gmem_ptr (ptr_col), make_layout (layout_M,layout_N,layout_L));
14201445 Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
14211446 mCol , args.tile_shape_mnk , args.tile_coord_mnkl , args.epi_tile , args.tiled_copy , args.thread_idx );
14221447
1423- Tensor mCol_static = make_tensor (make_gmem_ptr (params. ptr_col ), make_layout (make_layout (M),layout_N,layout_L));
1448+ Tensor mCol_static = make_tensor (make_gmem_ptr (ptr_col), make_layout (make_layout (M),layout_N,layout_L));
14241449 Tensor tCgCol_static = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
14251450 mCol_static , args.tile_shape_mnk , args.tile_coord_mnkl , args.epi_tile , args.tiled_copy , args.thread_idx );
14261451 Tensor tCrCol = make_tensor_like<ElementCompute>(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
0 commit comments