44#include " cutlass/float8.h"
55
66/* *
7- * This file defines Gemm kernel configurations for SM89 based on the Gemm
7+ * This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
88 * shape.
99 */
1010
1111namespace vllm {
1212
1313template <typename InType, typename OutType,
1414 template <typename , typename > typename Epilogue>
15- struct sm89_fallback_gemm {
15+ struct sm89_fp8_fallback_gemm {
1616 // Shared Memory required by this Gemm - 61440 bytes
1717 static_assert (std::is_same<InType, cutlass::float_e4m3_t >());
1818 using TileShape = typename cutlass::gemm::GemmShape<64 , 128 , 64 >;
@@ -25,7 +25,7 @@ struct sm89_fallback_gemm {
2525 FP8MathOperator>;
2626};
2727
28- struct sm89_config_default {
28+ struct sm89_fp8_config_default {
2929 // M in (256, inf)
3030 using WarpShape = typename cutlass::gemm::GemmShape<64 , 64 , 64 >;
3131 using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
@@ -40,7 +40,8 @@ struct sm89_config_default {
4040 TORCH_CHECK (a.dtype () == torch::kFloat8_e4m3fn );
4141
4242 using FallbackGemm =
43- typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
43+ typename sm89_fp8_fallback_gemm<InType, OutType,
44+ Epilogue>::Cutlass2xGemm;
4445
4546 uint32_t const n = out.size (1 );
4647 uint32_t const np2 = next_pow_2 (n);
@@ -74,7 +75,7 @@ struct sm89_config_default {
7475 }
7576};
7677
77- struct sm89_config_M256 {
78+ struct sm89_fp8_config_M256 {
7879 // M in (128, 256]
7980 using WarpShape = typename cutlass::gemm::GemmShape<64 , 64 , 64 >;
8081 using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
@@ -89,7 +90,8 @@ struct sm89_config_M256 {
8990 TORCH_CHECK (a.dtype () == torch::kFloat8_e4m3fn );
9091
9192 using FallbackGemm =
92- typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
93+ typename sm89_fp8_fallback_gemm<InType, OutType,
94+ Epilogue>::Cutlass2xGemm;
9395
9496 uint32_t const n = out.size (1 );
9597 uint32_t const np2 = next_pow_2 (n);
@@ -114,7 +116,7 @@ struct sm89_config_M256 {
114116 }
115117};
116118
117- struct sm89_config_M128 {
119+ struct sm89_fp8_config_M128 {
118120 // M in (64, 128]
119121 using WarpShape = typename cutlass::gemm::GemmShape<64 , 64 , 64 >;
120122 using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
@@ -129,7 +131,8 @@ struct sm89_config_M128 {
129131 TORCH_CHECK (a.dtype () == torch::kFloat8_e4m3fn );
130132
131133 using FallbackGemm =
132- typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
134+ typename sm89_fp8_fallback_gemm<InType, OutType,
135+ Epilogue>::Cutlass2xGemm;
133136
134137 uint32_t const n = out.size (1 );
135138 uint32_t const np2 = next_pow_2 (n);
@@ -163,7 +166,7 @@ struct sm89_config_M128 {
163166 }
164167};
165168
166- struct sm89_config_M64 {
169+ struct sm89_fp8_config_M64 {
167170 // M in (32, 64]
168171 using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
169172
@@ -176,7 +179,8 @@ struct sm89_config_M64 {
176179 TORCH_CHECK (a.dtype () == torch::kFloat8_e4m3fn );
177180
178181 using FallbackGemm =
179- typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
182+ typename sm89_fp8_fallback_gemm<InType, OutType,
183+ Epilogue>::Cutlass2xGemm;
180184
181185 uint32_t const n = out.size (1 );
182186 uint32_t const np2 = next_pow_2 (n);
@@ -215,7 +219,7 @@ struct sm89_config_M64 {
215219 }
216220};
217221
218- struct sm89_config_M32 {
222+ struct sm89_fp8_config_M32 {
219223 // M in (16, 32]
220224 using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
221225 using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
@@ -229,7 +233,8 @@ struct sm89_config_M32 {
229233 TORCH_CHECK (a.dtype () == torch::kFloat8_e4m3fn );
230234
231235 using FallbackGemm =
232- typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
236+ typename sm89_fp8_fallback_gemm<InType, OutType,
237+ Epilogue>::Cutlass2xGemm;
233238
234239 uint32_t const n = out.size (1 );
235240 uint32_t const np2 = next_pow_2 (n);
@@ -265,7 +270,7 @@ struct sm89_config_M32 {
265270 }
266271};
267272
268- struct sm89_config_M16 {
273+ struct sm89_fp8_config_M16 {
269274 // M in [1, 16]
270275 using WarpShape = typename cutlass::gemm::GemmShape<16 , 64 , 64 >;
271276 using InstructionShape = typename cutlass::gemm::GemmShape<16 , 8 , 32 >;
@@ -281,7 +286,8 @@ struct sm89_config_M16 {
281286 TORCH_CHECK (a.dtype () == torch::kFloat8_e4m3fn );
282287
283288 using FallbackGemm =
284- typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
289+ typename sm89_fp8_fallback_gemm<InType, OutType,
290+ Epilogue>::Cutlass2xGemm;
285291
286292 uint32_t const n = out.size (1 );
287293 uint32_t const np2 = next_pow_2 (n);
@@ -320,10 +326,10 @@ struct sm89_config_M16 {
320326template <typename InType, typename OutType,
321327 template <typename , typename > typename Epilogue,
322328 typename ... EpilogueArgs>
323- inline void cutlass_gemm_sm89_dispatch (torch::Tensor& out,
324- torch::Tensor const & a,
325- torch::Tensor const & b,
326- EpilogueArgs&&... args) {
329+ inline void cutlass_gemm_sm89_fp8_dispatch (torch::Tensor& out,
330+ torch::Tensor const & a,
331+ torch::Tensor const & b,
332+ EpilogueArgs&&... args) {
327333 static_assert (std::is_same<InType, cutlass::float_e4m3_t >());
328334 TORCH_CHECK (a.dtype () == torch::kFloat8_e4m3fn );
329335 TORCH_CHECK (b.dtype () == torch::kFloat8_e4m3fn );
@@ -334,27 +340,27 @@ inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,
334340
335341 if (mp2 <= 16 ) {
336342 // M in [1, 16]
337- return sm89_config_M16 ::dispatch<InType, OutType, Epilogue>(
343+ return sm89_fp8_config_M16 ::dispatch<InType, OutType, Epilogue>(
338344 out, a, b, std::forward<EpilogueArgs>(args)...);
339345 } else if (mp2 <= 32 ) {
340346 // M in (16, 32]
341- return sm89_config_M32 ::dispatch<InType, OutType, Epilogue>(
347+ return sm89_fp8_config_M32 ::dispatch<InType, OutType, Epilogue>(
342348 out, a, b, std::forward<EpilogueArgs>(args)...);
343349 } else if (mp2 <= 64 ) {
344350 // M in (32, 64]
345- return sm89_config_M64 ::dispatch<InType, OutType, Epilogue>(
351+ return sm89_fp8_config_M64 ::dispatch<InType, OutType, Epilogue>(
346352 out, a, b, std::forward<EpilogueArgs>(args)...);
347353 } else if (mp2 <= 128 ) {
348354 // M in (64, 128]
349- return sm89_config_M128 ::dispatch<InType, OutType, Epilogue>(
355+ return sm89_fp8_config_M128 ::dispatch<InType, OutType, Epilogue>(
350356 out, a, b, std::forward<EpilogueArgs>(args)...);
351357 } else if (mp2 <= 256 ) {
352358 // M in (128, 256]
353- return sm89_config_M256 ::dispatch<InType, OutType, Epilogue>(
359+ return sm89_fp8_config_M256 ::dispatch<InType, OutType, Epilogue>(
354360 out, a, b, std::forward<EpilogueArgs>(args)...);
355361 } else {
356362 // M in (256, inf)
357- return sm89_config_default ::dispatch<InType, OutType, Epilogue>(
363+ return sm89_fp8_config_default ::dispatch<InType, OutType, Epilogue>(
358364 out, a, b, std::forward<EpilogueArgs>(args)...);
359365 }
360366}
0 commit comments