@@ -77,31 +77,45 @@ struct enable_sm89_to_sm90 : Kernel {
};

/*
-   This epilogue function defines a quantized GEMM operation similar to
-   torch._scaled_mm.
-
-   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
-   per-row. B can be quantized per-tensor or per-column.
-   Any combination of per-tensor and per-row or column is supported.
-   A and B must have symmetric quantization (zero point == 0).
-
-   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
-   scales are applied elementwise with numpy-style broadcasting.
-
-   ScaleA and ScaleB define the epilogue functions that apply the scales for
-   the A and B operands respectively. These scales may be either per-tensor or
-   per row or column.
-*/
+ * This class provides the common ScaleA and ScaleB descriptors for the
+ * ScaledEpilogue and ScaledEpilogueBias classes.
+ */
template <typename ElementD, typename OutputTileThreadMap>
-struct ScaledEpilogue {
- private:
+struct ScaledEpilogueBase {
+ protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
      OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;

  using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
      OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
+};
+
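A note on the two stride layouts, since the convention is easy to misread. This is my interpretation of the CUTLASS broadcast visitors, stated as an assumption rather than something spelled out in the diff:

// Assumed reading, for illustration only:
// ScaleA: Stride<Int<1>, Int<0>, Int<0>> -> indexed along M, broadcast along
//         N; one scale per output row (per-row quantization of A), or a
//         single scalar in the per-tensor case ("ColOrScalar").
// ScaleB: Stride<Int<0>, Int<1>, Int<0>> -> indexed along N, broadcast along
//         M; one scale per output column (per-column quantization of B), or
//         a scalar ("RowOrScalar").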
+/*
+   This epilogue function defines a quantized GEMM operation similar to
+   torch._scaled_mm.
+
+   A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor or
+   per-row. B can be quantized per-tensor or per-column.
+   Any combination of per-tensor and per-row or per-column scales is supported.
+   A and B must have symmetric quantization (zero point == 0).
+
+   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+   scales are applied elementwise with numpy-style broadcasting.
+
+   ScaleA and ScaleB define the epilogue functions that apply the scales for
+   the A and B operands respectively. These scales may be either per-tensor,
+   or per-row for A and per-column for B.
+*/
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogue
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::ScaleA;
+  using ScaleB = typename SUPER::ScaleB;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
@@ -134,6 +148,53 @@ struct ScaledEpilogue {
  }
};
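To make the comment's broadcast semantics concrete, here is a minimal host-side reference in plain C++. This is a hedged sketch: the function name, the flat std::vector layout, and the int32 accumulation for int8 inputs are my assumptions for illustration, not part of this diff.

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> scaled_mm_reference(
    std::vector<int8_t> const& A,        // [M, K], row-major
    std::vector<int8_t> const& B,        // [K, N], row-major here for clarity
    std::vector<float> const& a_scales,  // size 1 (per-tensor) or M (per-row)
    std::vector<float> const& b_scales,  // size 1 (per-tensor) or N (per-col)
    int M, int N, int K) {
  std::vector<float> D(static_cast<std::size_t>(M) * N);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int32_t acc = 0;  // integer accumulation, as in an int8 tensor-op GEMM
      for (int k = 0; k < K; ++k)
        acc += int32_t{A[m * K + k]} * int32_t{B[k * N + n]};
      // numpy-style broadcast: scalar if one scale, else per-row / per-column
      float const sa = a_scales.size() == 1 ? a_scales[0] : a_scales[m];
      float const sb = b_scales.size() == 1 ? b_scales[0] : b_scales[n];
      D[m * N + n] = sa * sb * static_cast<float>(acc);
    }
  }
  return D;
}

The per-tensor case is just the size-1 branch; any mix of per-tensor/per-row a_scales and per-tensor/per-column b_scales falls out of the same indexing.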

+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBias
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::ScaleA;
+  using ScaleB = typename SUPER::ScaleB;
+
+  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute0 =
+      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
+
+  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
+
+ public:
+  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
+                                                             EVTCompute0, Bias>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& bias) {
+    using ScaleAArgs = typename ScaleA::Arguments;
+    using ScaleBArgs = typename ScaleB::Arguments;
+    using BiasArgs = typename Bias::Arguments;
+
+    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
+    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
+    BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
+
+    typename EVTCompute0::Arguments evt0_compute_args{b_args};
+
+    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
+                                                    bias_args};
+    return evt_compute_args;
+  }
+};
+
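Read inside-out, the visitor tree above is EVTCompute = Compute1(ScaleA, EVTCompute0(ScaleB, Accum), Bias). Per output element that amounts to the following scalar paraphrase (my reading of the assumed visitor semantics, not code from the diff):

// Scalar sketch of what ScaledEpilogueBias computes per accumulator value.
inline float scaled_epilogue_bias_ref(float acc, float a_scale, float b_scale,
                                      float bias) {
  float t = b_scale * acc;     // EVTCompute0: cutlass::multiplies
  return a_scale * t + bias;   // EVTCompute:  cutlass::multiply_add
}

Note how prepare_args encodes the per-tensor versus per-row/per-column choice: the numel() != 1 boolean passed alongside each scale pointer is what selects the vector broadcast over the scalar fallback in the OrScalarBroadcast visitors.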
template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
@@ -168,13 +229,13 @@ struct cutlass_2x_gemm {
  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
-  using KernelType =
+  using KernelType =
     ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
-      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
+      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
+      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
      float, cutlass::layout::RowMajor, 4,
-      ElementAcc, float, cutlass::arch::OpClassTensorOp,
-      Arch,
+      ElementAcc, float, cutlass::arch::OpClassTensorOp,
+      Arch,
      TileShape, WarpShape, InstructionShape,
      EVTD,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
@@ -404,14 +465,13 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
  }
}

-void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
@@ -420,78 +480,130 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_caller<cutlass_2x_gemm<
        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_caller<cutlass_2x_gemm<
        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

-void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
+void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
-  TORCH_CHECK(a.dtype() == torch::kInt8);
-  TORCH_CHECK(b.dtype() == torch::kInt8);
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
+
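For orientation, a hypothetical caller of the sm75 entry point might look like the following (the function name, shapes, and device handling are my assumptions, not part of this diff; the sm80 and sm89 wrappers below take the same arguments):

// Hypothetical usage sketch, assuming it lives in the same translation unit.
void example_scaled_mm_call(torch::Tensor& out,        // [M, N], fp16 or bf16
                            torch::Tensor const& a,    // [M, K], int8, CUDA
                            torch::Tensor const& b) {  // [K, N], int8, CUDA
  auto f32 = torch::dtype(torch::kFloat32).device(a.device());
  torch::Tensor a_scales = torch::ones({a.size(0), 1}, f32);  // per-row of A
  torch::Tensor b_scales = torch::ones({1, b.size(1)}, f32);  // per-col of B
  torch::Tensor bias = torch::zeros({b.size(1)}, out.options());
  // Passing a bias selects ScaledEpilogueBias; pass c10::nullopt instead to
  // take the plain ScaledEpilogue path.
  cutlass_scaled_mm_sm75(out, a, b, a_scales, b_scales, bias);
}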
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
-    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
-                                      ScaledEpilogue>(out, a, b, a_scales,
-                                                      b_scales);
+    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, ScaledEpilogue>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

-void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
+
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_caller<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
-          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      assert(out.dtype() == torch::kFloat16);
      return cutlass_gemm_caller<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
-          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
-      return cutlass_gemm_caller<cutlass_2x_gemm<
-          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
-          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_caller<
+          cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
+                          cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue,
+                          TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_gemm_caller<cutlass_2x_gemm<
-          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
-          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_caller<
+          cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
+                          cutlass::float_e4m3_t, cutlass::half_t, Epilogue,
+                          TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}
+
+void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}