
Commit 36750cc

Varun Sundar Rabindranath authored and committed
[Kernel] Tuned int8 kernels for Ada Lovelace (vllm-project#6848)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 9da9b2d commit 36750cc

4 files changed: +395 -43 lines changed


csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu

Lines changed: 10 additions & 17 deletions
@@ -4,7 +4,8 @@
 
 #include "scaled_mm_c2x.cuh"
 #include "scaled_mm_c2x_sm80_dispatch.cuh"
-#include "scaled_mm_c2x_sm89_dispatch.cuh"
+#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
+#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
 
 /*
    This file defines quantized GEMM operations using the CUTLASS 2.x API, for
@@ -98,39 +99,31 @@ template <template <typename, typename> typename Epilogue,
 void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                      torch::Tensor const& b,
                                      EpilogueArgs&&... epilogue_args) {
-  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
-  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
-
   if (a.dtype() == torch::kInt8) {
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
-      return vllm::cutlass_gemm_caller<
-          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
-                                int8_t, cutlass::bfloat16_t, Epilogue,
-                                TileShape, WarpShape, InstructionShape, 5>>(
+      return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
+                                                   Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       assert(out.dtype() == torch::kFloat16);
-      return vllm::cutlass_gemm_caller<
-          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
-                                int8_t, cutlass::half_t, Epilogue, TileShape,
-                                WarpShape, InstructionShape, 5>>(
+      return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
+                                                   Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
     if (out.dtype() == torch::kBFloat16) {
-      return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
-                                              cutlass::bfloat16_t, Epilogue>(
+      return vllm::cutlass_gemm_sm89_fp8_dispatch<
+          cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
-                                              cutlass::half_t, Epilogue>(
+      return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
+                                                  cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   }
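
The hard-coded TileShape/WarpShape/InstructionShape triple (with 5 pipeline stages) is dropped in favor of per-dtype dispatch headers that pick a tuned configuration from the problem shape. The new scaled_mm_c2x_sm89_int8_dispatch.cuh is one of the four changed files but is not shown in this view; judging from its fp8 counterpart below, it buckets the batch dimension M at the next power of two and selects a config struct per bucket. A minimal, self-contained sketch of that selection pattern (select_sm89_int8_config and its bucket boundaries are assumptions here, not the file's actual contents):

// Sketch only: illustrates the next_pow_2 bucketing used by the sm89
// dispatch headers. The real file defines one tuned CUTLASS config struct
// per bucket; this enum merely stands in for those structs.
#include <cstdint>
#include <cstdio>

static inline uint32_t next_pow_2(uint32_t n) {
  if (n <= 1) return n;
  return 1u << (32 - __builtin_clz(n - 1));  // smallest power of two >= n
}

enum class Sm89Int8Config { M16, M32, M64, M128, M256, Default };

Sm89Int8Config select_sm89_int8_config(uint32_t m) {
  uint32_t const mp2 = next_pow_2(m);
  if (mp2 <= 16) return Sm89Int8Config::M16;    // M in [1, 16]
  if (mp2 <= 32) return Sm89Int8Config::M32;    // M in (16, 32]
  if (mp2 <= 64) return Sm89Int8Config::M64;    // M in (32, 64]
  if (mp2 <= 128) return Sm89Int8Config::M128;  // M in (64, 128]
  if (mp2 <= 256) return Sm89Int8Config::M256;  // M in (128, 256]
  return Sm89Int8Config::Default;               // M in (256, inf)
}

int main() {
  // M = 100 rounds up to 128, so the (64, 128] config is chosen.
  std::printf("bucket for M=100: %d\n",
              static_cast<int>(select_sm89_int8_config(100)));
}

Bucketing on next_pow_2(m) keeps the dispatch down to a handful of comparisons while still matching each batch size to a tile shape tuned for it.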

csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_dispatch.cuh renamed to csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh

Lines changed: 30 additions & 24 deletions
@@ -4,15 +4,15 @@
 #include "cutlass/float8.h"
 
 /**
- * This file defines Gemm kernel configurations for SM89 based on the Gemm
+ * This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
  * shape.
  */
 
 namespace vllm {
 
 template <typename InType, typename OutType,
           template <typename, typename> typename Epilogue>
-struct sm89_fallback_gemm {
+struct sm89_fp8_fallback_gemm {
   // Shared Memory required by this Gemm - 61440 bytes
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
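
The "61440 bytes" figure in the fallback comment is consistent with the 64x128x64 tile above staged five deep with one-byte fp8 operands. A quick arithmetic check (the stage count of 5 is an assumption; it is not visible in this hunk):

#include <cstddef>
#include <cstdio>

int main() {
  constexpr std::size_t kM = 64, kN = 128, kK = 64;  // TileShape<64, 128, 64>
  constexpr std::size_t kStages = 5;                 // assumed pipeline depth
  constexpr std::size_t kBytes = 1;                  // sizeof(float_e4m3_t)
  // Each stage buffers one M*K tile of A and one N*K tile of B.
  constexpr std::size_t smem = kStages * kBytes * (kM * kK + kN * kK);
  static_assert(smem == 61440, "matches the comment in the source");
  std::printf("shared memory per CTA: %zu bytes\n", smem);
}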
@@ -25,7 +25,7 @@ struct sm89_fallback_gemm {
                       FP8MathOperator>;
 };
 
-struct sm89_config_default {
+struct sm89_fp8_config_default {
   // M in (256, inf)
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -40,7 +40,8 @@ struct sm89_config_default {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -74,7 +75,7 @@ struct sm89_config_default {
   }
 };
 
-struct sm89_config_M256 {
+struct sm89_fp8_config_M256 {
   // M in (128, 256]
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -89,7 +90,8 @@ struct sm89_config_M256 {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -114,7 +116,7 @@ struct sm89_config_M256 {
   }
 };
 
-struct sm89_config_M128 {
+struct sm89_fp8_config_M128 {
   // M in (64, 128]
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -129,7 +131,8 @@ struct sm89_config_M128 {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -163,7 +166,7 @@ struct sm89_config_M128 {
   }
 };
 
-struct sm89_config_M64 {
+struct sm89_fp8_config_M64 {
   // M in (32, 64]
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
 
@@ -176,7 +179,8 @@ struct sm89_config_M64 {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -215,7 +219,7 @@ struct sm89_config_M64 {
   }
 };
 
-struct sm89_config_M32 {
+struct sm89_fp8_config_M32 {
   // M in (16, 32]
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
   using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
@@ -229,7 +233,8 @@ struct sm89_config_M32 {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -265,7 +270,7 @@ struct sm89_config_M32 {
   }
 };
 
-struct sm89_config_M16 {
+struct sm89_fp8_config_M16 {
   // M in [1, 16]
   using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -281,7 +286,8 @@ struct sm89_config_M16 {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
 
     using FallbackGemm =
-        typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+        typename sm89_fp8_fallback_gemm<InType, OutType,
+                                        Epilogue>::Cutlass2xGemm;
 
     uint32_t const n = out.size(1);
     uint32_t const np2 = next_pow_2(n);
@@ -320,10 +326,10 @@ struct sm89_config_M16 {
 template <typename InType, typename OutType,
           template <typename, typename> typename Epilogue,
           typename... EpilogueArgs>
-inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,
-                                       torch::Tensor const& a,
-                                       torch::Tensor const& b,
-                                       EpilogueArgs&&... args) {
+inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           EpilogueArgs&&... args) {
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
@@ -334,27 +340,27 @@ inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,
 
   if (mp2 <= 16) {
     // M in [1, 16]
-    return sm89_config_M16::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 32) {
     // M in (16, 32]
-    return sm89_config_M32::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 64) {
     // M in (32, 64]
-    return sm89_config_M64::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 128) {
     // M in (64, 128]
-    return sm89_config_M128::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   } else if (mp2 <= 256) {
     // M in (128, 256]
-    return sm89_config_M256::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   } else {
     // M in (256, inf)
-    return sm89_config_default::dispatch<InType, OutType, Epilogue>(
+    return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
   }
 }
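
Each bucket's dispatch() pulls in the conservative sm89_fp8_fallback_gemm (the FallbackGemm alias in the hunks above), presumably so problem shapes the tuned kernel cannot service still execute. A self-contained sketch of that try-then-fallback idea (the structs and the alignment rule below are illustrative stand-ins, not vLLM's actual types):

// Sketch only: each tuned config is attempted first, and a conservative
// fallback handles shapes the tuned kernel rejects.
#include <iostream>

struct TunedGemm {
  // Stand-in for a CUTLASS-style feasibility check.
  static bool can_implement(int, int n, int) {
    return n % 128 == 0;  // hypothetical alignment requirement
  }
  static void run(int m, int n, int k) {
    std::cout << "tuned kernel: " << m << "x" << n << "x" << k << "\n";
  }
};

struct FallbackGemm {
  static bool can_implement(int, int, int) { return true; }  // always fits
  static void run(int m, int n, int k) {
    std::cout << "fallback kernel: " << m << "x" << n << "x" << k << "\n";
  }
};

template <typename Primary, typename Fallback>
void gemm_with_fallback(int m, int n, int k) {
  if (Primary::can_implement(m, n, k)) {
    Primary::run(m, n, k);
  } else {
    Fallback::run(m, n, k);
  }
}

int main() {
  gemm_with_fallback<TunedGemm, FallbackGemm>(100, 128, 64);  // tuned path
  gemm_with_fallback<TunedGemm, FallbackGemm>(100, 100, 64);  // fallback path
}

The fallback configuration is fixed at a single 64x128x64 tile precisely so there is always one kernel that fits on any sm89 part.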
