@@ -22,7 +22,7 @@ struct identity {
22
22
T operator ()(T lhs) const { return lhs; }
23
23
};
24
24
25
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor >
25
+ template <typename ElementAcc, typename ElementD, typename TileShape >
26
26
struct TrivialEpilogue {
27
27
private:
28
28
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
@@ -44,32 +44,30 @@ struct TrivialEpilogue {
44
44
* This class provides the common load descriptors for the
45
45
* ScaledEpilogue[...] classes
46
46
*/
47
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor >
47
+ template <typename ElementAcc, typename ElementD, typename TileShape >
48
48
struct ScaledEpilogueBase {
49
49
protected:
50
50
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
51
51
52
52
template <typename T>
53
53
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
54
- 0 /* Stages*/ , typename EpilogueDescriptor::TileShape, T,
55
- Stride<Int<1 >, Int<0 >, Int<0 >>>;
54
+ 0 /* Stages*/ , TileShape, T, Stride<Int<1 >, Int<0 >, Int<0 >>>;
56
55
57
56
template <typename T>
58
57
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
59
- 0 /* Stages*/ , typename EpilogueDescriptor::TileShape, T,
60
- Stride<Int<0 >, Int<1 >, Int<0 >>>;
58
+ 0 /* Stages*/ , TileShape, T, Stride<Int<0 >, Int<1 >, Int<0 >>>;
61
59
62
60
// Don't want to support nullptr by default
63
61
template <typename T, bool EnableNullPtr = false >
64
62
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
65
- 0 /* Stages*/ , typename EpilogueDescriptor:: TileShape, T, T,
66
- Stride<Int< 1 >, Int< 0 >, Int< 0 >>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
63
+ 0 /* Stages*/ , TileShape, T, T, Stride<Int< 1 >, Int< 0 >, Int< 0 >> ,
64
+ 128 / sizeof_bits_v<T>, EnableNullPtr>;
67
65
68
66
// Don't want to support nullptr by default
69
67
template <typename T, bool EnableNullPtr = false >
70
68
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
71
- 0 /* Stages*/ , typename EpilogueDescriptor:: TileShape, T, T,
72
- Stride<Int< 0 >, Int< 1 >, Int< 0 >>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
69
+ 0 /* Stages*/ , TileShape, T, T, Stride<Int< 0 >, Int< 1 >, Int< 0 >> ,
70
+ 128 / sizeof_bits_v<T>, EnableNullPtr>;
73
71
74
72
// This utility function constructs the arguments for the load descriptors
75
73
// from a tensor. It can handle both row and column, as well as row/column or
@@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
116
114
the A and B operands respectively. These scales may be either per-tensor or
117
115
per row or column.
118
116
*/
119
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor >
117
+ template <typename ElementAcc, typename ElementD, typename TileShape >
120
118
struct ScaledEpilogue
121
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor > {
119
+ : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape > {
122
120
private:
123
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor >;
121
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape >;
124
122
using Accum = typename SUPER::Accum;
125
123
using ScaleA = typename SUPER::template ColOrScalarLoad<float >;
126
124
using ScaleB = typename SUPER::template RowOrScalarLoad<float >;
@@ -160,11 +158,11 @@ struct ScaledEpilogue
160
158
* The bias tensor must be per-output channel.
161
159
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
162
160
*/
163
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor >
161
+ template <typename ElementAcc, typename ElementD, typename TileShape >
164
162
struct ScaledEpilogueBias
165
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor > {
163
+ : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape > {
166
164
private:
167
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor >;
165
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape >;
168
166
using Accum = typename SUPER::Accum;
169
167
using ScaleA = typename SUPER::template ColOrScalarLoad<float >;
170
168
using ScaleB = typename SUPER::template RowOrScalarLoad<float >;
@@ -203,11 +201,11 @@ struct ScaledEpilogueBias
203
201
* bias is a column vector instead of a row vector. Useful e.g. if we are
204
202
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
205
203
*/
206
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor >
204
+ template <typename ElementAcc, typename ElementD, typename TileShape >
207
205
struct ScaledEpilogueColumnBias
208
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor > {
206
+ : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape > {
209
207
private:
210
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor >;
208
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape >;
211
209
using Accum = typename SUPER::Accum;
212
210
using ScaleA = typename SUPER::template ColOrScalarLoad<float >;
213
211
using ScaleB = typename SUPER::template RowOrScalarLoad<float >;
@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
249
247
*
250
248
* This epilogue also supports bias, which remains per-channel.
251
249
*/
252
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor >
250
+ template <typename ElementAcc, typename ElementD, typename TileShape >
253
251
struct ScaledEpilogueBiasAzp
254
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor > {
252
+ : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape > {
255
253
private:
256
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor >;
254
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape >;
257
255
using Accum = typename SUPER::Accum;
258
256
using ScaleA = typename SUPER::template ColOrScalarLoad<float >;
259
257
using ScaleB = typename SUPER::template RowOrScalarLoad<float >;
@@ -313,11 +311,11 @@ struct ScaledEpilogueBiasAzp
313
311
*
314
312
* This epilogue also supports bias, which remains per-channel.
315
313
*/
316
- template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor >
314
+ template <typename ElementAcc, typename ElementD, typename TileShape >
317
315
struct ScaledEpilogueBiasAzpToken
318
- : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor > {
316
+ : private ScaledEpilogueBase<ElementAcc, ElementD, TileShape > {
319
317
private:
320
- using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor >;
318
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape >;
321
319
using Accum = typename SUPER::Accum;
322
320
using ScaleA = typename SUPER::template ColOrScalarLoad<float >;
323
321
using ScaleB = typename SUPER::template RowOrScalarLoad<float >;
0 commit comments