@@ -105,27 +105,34 @@ using Cast = cutlass::epilogue::fusion::Sm90Compute<
105
105
DtypeEpilogue,
106
106
cutlass::FloatRoundStyle::round_to_nearest>;
107
107
108
// Compile-time trait keyed on two boolean flags. The primary template is only
// declared here; each flag combination is provided as an explicit
// specialization that exposes `type` (passed to the mainloop collective
// builder) and `epilogue_type` (passed to the epilogue collective builder).
- template <bool PingPong , bool FastAccum>
108
+ template <bool LargeTile , bool FastAccum>
109
109
struct Schedule ;
110
110
111
111
// LargeTile = false, FastAccum = false: KernelTmaWarpSpecialized mainloop
// schedule paired with the TmaWarpSpecialized epilogue schedule.
template <>
112
- struct Schedule </* PingPong =*/ false , /* FastAccum=*/ false > {
112
+ struct Schedule </* LargeTile =*/ false , /* FastAccum=*/ false > {
113
113
using type = cutlass::gemm::KernelTmaWarpSpecialized;
114
+ using epilogue_type = cutlass::epilogue::TmaWarpSpecialized;
114
115
};
115
116
116
117
// LargeTile = true, FastAccum = false: the only specialization that selects
// the cooperative mainloop schedule together with the cooperative epilogue
// schedule.
template <>
117
- struct Schedule </* PingPong=*/ true , /* FastAccum=*/ false > {
118
- using type = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
118
+ struct Schedule </* LargeTile=*/ true , /* FastAccum=*/ false > {
119
+ // For a 128x128x128 tile with FastAccum = false, the pingpong schedule
120
+ // leads to spilling, and the warp-specialized schedule without pingpong
121
+ // is slow, so the cooperative schedule is used instead.
122
+ using type = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
123
+ using epilogue_type = cutlass::epilogue::TmaWarpSpecializedCooperative;
119
124
};
120
125
121
126
// LargeTile = false, FastAccum = true: KernelTmaWarpSpecializedFP8FastAccum
// mainloop schedule paired with the TmaWarpSpecialized epilogue schedule.
template <>
122
- struct Schedule </* PingPong =*/ false , /* FastAccum=*/ true > {
127
+ struct Schedule </* LargeTile =*/ false , /* FastAccum=*/ true > {
123
128
using type = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
129
+ using epilogue_type = cutlass::epilogue::TmaWarpSpecialized;
124
130
};
125
131
126
132
// LargeTile = true, FastAccum = true: keeps the pingpong mainloop schedule
// (KernelTmaWarpSpecializedPingpongFP8FastAccum) and pairs it with the
// non-cooperative TmaWarpSpecialized epilogue schedule.
template <>
127
- struct Schedule </* PingPong =*/ true , /* FastAccum=*/ true > {
133
+ struct Schedule </* LargeTile =*/ true , /* FastAccum=*/ true > {
128
134
using type = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
135
+ using epilogue_type = cutlass::epilogue::TmaWarpSpecialized;
129
136
};
130
137
131
138
int ceildiv (int a, int b) {
@@ -140,7 +147,6 @@ int round_up_to_nearest_multiple(int a, int b) {
140
147
template <
141
148
typename TileShape,
142
149
typename ClusterShape,
143
- typename PingPong,
144
150
typename Transposed,
145
151
typename FastAccum,
146
152
typename DtypeA,
@@ -226,6 +232,8 @@ void f8f8bf16_rowwise_impl(
226
232
Bias,
227
233
AccumScale>>;
228
234
235
+ constexpr bool large_tile = std::is_same_v<TileShape, cute::Shape<cute::_128, cute::_128, cute::_128>>;
236
+
229
237
using CollectiveEpilogue =
230
238
typename cutlass::epilogue::collective::CollectiveBuilder<
231
239
ArchTag,
@@ -241,7 +249,7 @@ void f8f8bf16_rowwise_impl(
241
249
DtypeOutput,
242
250
LayoutOutput,
243
251
AlignmentOutput,
244
- cutlass::epilogue::TmaWarpSpecialized ,
252
+ typename Schedule<large_tile, FastAccum::value>::epilogue_type ,
245
253
EpilogueEVT>::CollectiveOp;
246
254
247
255
using CollectiveMainloop =
@@ -259,7 +267,7 @@ void f8f8bf16_rowwise_impl(
259
267
ClusterShape,
260
268
cutlass::gemm::collective::StageCountAutoCarveout<static_cast <int >(
261
269
sizeof (typename CollectiveEpilogue::SharedStorage))>,
262
- typename Schedule<PingPong::value , FastAccum::value>::type>::
270
+ typename Schedule<large_tile , FastAccum::value>::type>::
263
271
CollectiveOp;
264
272
265
273
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
@@ -370,13 +378,11 @@ void dispatch_fp8_rowwise_kernel_on_tile_size(
370
378
return f8f8bf16_rowwise_impl<
371
379
/* TileShape=*/ cute::Shape<cute::_64, cute::_128, cute::_128>,
372
380
ClusterShape,
373
- /* PingPong=*/ std::false_type,
374
381
Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
375
382
} else {
376
383
return f8f8bf16_rowwise_impl<
377
384
/* TileShape=*/ cute::Shape<cute::_128, cute::_128, cute::_128>,
378
385
ClusterShape,
379
- /* PingPong=*/ std::true_type,
380
386
Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
381
387
}
382
388
}
0 commit comments