Skip to content

Commit a0d508b

Browse files
kushanampathorn
authored andcommitted
Apply vllm-project#13798 add cutlass support for blackwell fp8 gemm
1 parent 52e631b commit a0d508b

File tree

9 files changed

+173
-64
lines changed

9 files changed

+173
-64
lines changed

CMakeLists.txt

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
279279
# Only build Marlin kernels if we are building for at least some compatible archs.
280280
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
281281
# are not supported by Machete yet.
282-
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1" "${CUDA_ARCHS}")
282+
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
283283
if (MARLIN_ARCHS)
284284
set(MARLIN_SRCS
285285
"csrc/quantization/fp8/fp8_marlin.cu"
@@ -301,7 +301,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
301301

302302
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
303303
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
304-
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a" "${CUDA_ARCHS}")
304+
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
305305
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
306306
set(SRCS
307307
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
@@ -336,7 +336,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
336336
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
337337
# kernels for the remaining archs that are not already built for 3x.
338338
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
339-
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1" "${CUDA_ARCHS}")
339+
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
340340
# subtract out the archs that are already built for 3x
341341
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
342342
if (SCALED_MM_2X_ARCHS)
@@ -361,7 +361,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
361361
# 2:4 Sparse Kernels
362362

363363
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
364-
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
364+
# require CUDA 12.2 or later (and only work on Hopper and Blackwell).
365365
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
366366
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
367367
set_gencode_flags_for_srcs(
@@ -399,6 +399,19 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
399399
set(FP4_ARCHS)
400400
endif()
401401

402+
# FP8 Blackwell Archs
403+
cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}")
404+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS)
405+
set(SRCS
406+
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
407+
)
408+
list(APPEND VLLM_EXT_SRC "${SRCS}")
409+
message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}")
410+
else()
411+
# clear BLACKWELL_ARCHS
412+
set(BLACKWELL_ARCHS)
413+
endif()
414+
402415
#
403416
# Machete kernels
404417

@@ -507,7 +520,7 @@ set_gencode_flags_for_srcs(
507520
CUDA_ARCHS "${CUDA_ARCHS}")
508521

509522
if(VLLM_GPU_LANG STREQUAL "CUDA")
510-
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1" "${CUDA_ARCHS}")
523+
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
511524
if (MARLIN_MOE_ARCHS)
512525
set(MARLIN_MOE_SRC
513526
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"

csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct identity {
2222
T operator()(T lhs) const { return lhs; }
2323
};
2424

25-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
25+
template <typename ElementAcc, typename ElementD, typename TileShape>
2626
struct TrivialEpilogue {
2727
private:
2828
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
@@ -44,32 +44,30 @@ struct TrivialEpilogue {
4444
* This class provides the common load descriptors for the
4545
* ScaledEpilogue[...] classes
4646
*/
47-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
47+
template <typename ElementAcc, typename ElementD, typename TileShape>
4848
struct ScaledEpilogueBase {
4949
protected:
5050
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
5151

5252
template <typename T>
5353
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
54-
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
55-
Stride<Int<1>, Int<0>, Int<0>>>;
54+
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
5655

5756
template <typename T>
5857
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
59-
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
60-
Stride<Int<0>, Int<1>, Int<0>>>;
58+
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
6159

6260
// Don't want to support nullptr by default
6361
template <typename T, bool EnableNullPtr = false>
6462
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
65-
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
66-
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
63+
0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
64+
128 / sizeof_bits_v<T>, EnableNullPtr>;
6765

6866
// Don't want to support nullptr by default
6967
template <typename T, bool EnableNullPtr = false>
7068
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
71-
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
72-
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
69+
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
70+
128 / sizeof_bits_v<T>, EnableNullPtr>;
7371

7472
// This utility function constructs the arguments for the load descriptors
7573
// from a tensor. It can handle both row and column, as well as row/column or
@@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
116114
the A and B operands respectively. These scales may be either per-tensor or
117115
per row or column.
118116
*/
119-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
117+
template <typename ElementAcc, typename ElementD, typename TileShape>
120118
struct ScaledEpilogue
121-
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
119+
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
122120
private:
123-
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
121+
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
124122
using Accum = typename SUPER::Accum;
125123
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
126124
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -160,11 +158,11 @@ struct ScaledEpilogue
160158
* The bias tensor must be per-output channel.
161159
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
162160
*/
163-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
161+
template <typename ElementAcc, typename ElementD, typename TileShape>
164162
struct ScaledEpilogueBias
165-
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
163+
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
166164
private:
167-
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
165+
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
168166
using Accum = typename SUPER::Accum;
169167
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
170168
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -203,11 +201,11 @@ struct ScaledEpilogueBias
203201
* bias is a column vector instead of a row vector. Useful e.g. if we are
204202
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
205203
*/
206-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
204+
template <typename ElementAcc, typename ElementD, typename TileShape>
207205
struct ScaledEpilogueColumnBias
208-
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
206+
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
209207
private:
210-
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
208+
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
211209
using Accum = typename SUPER::Accum;
212210
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
213211
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
249247
*
250248
* This epilogue also supports bias, which remains per-channel.
251249
*/
252-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
250+
template <typename ElementAcc, typename ElementD, typename TileShape>
253251
struct ScaledEpilogueBiasAzp
254-
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
252+
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
255253
private:
256-
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
254+
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
257255
using Accum = typename SUPER::Accum;
258256
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
259257
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -313,11 +311,11 @@ struct ScaledEpilogueBiasAzp
313311
*
314312
* This epilogue also supports bias, which remains per-channel.
315313
*/
316-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
314+
template <typename ElementAcc, typename ElementD, typename TileShape>
317315
struct ScaledEpilogueBiasAzpToken
318-
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
316+
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
319317
private:
320-
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
318+
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
321319
using Accum = typename SUPER::Accum;
322320
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
323321
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "cutlass/gemm/kernel/gemm_universal.hpp"
1717
#include "cutlass/epilogue/collective/collective_builder.hpp"
1818
#include "cutlass/gemm/collective/collective_builder.hpp"
19+
#include "cutlass/util/packed_stride.hpp"
1920

2021
#include "core/math.hpp"
2122
#include "cutlass_extensions/common.hpp"
@@ -58,33 +59,40 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
5859
torch::Tensor const& b,
5960
EpilogueArgs&&... epilogue_params) {
6061
using ElementAB = typename Gemm::ElementAB;
62+
using ElementC = typename Gemm::ElementC;
6163
using ElementD = typename Gemm::ElementD;
6264
using GemmKernel = typename Gemm::GemmKernel;
6365

64-
int64_t lda = a.stride(0);
65-
int64_t ldb = b.stride(1);
66-
int64_t ldc = out.stride(0);
67-
68-
using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
69-
using StrideB = cute::Stride<int64_t, cute::Int<1>, int64_t>;
70-
using StrideC = typename Gemm::StrideC;
71-
72-
StrideA a_stride{lda, cute::Int<1>{}, 0};
73-
StrideB b_stride{ldb, cute::Int<1>{}, 0};
74-
StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
66+
using StrideA = typename Gemm::GemmKernel::StrideA;
67+
using StrideB = typename Gemm::GemmKernel::StrideB;
68+
using StrideC = typename Gemm::GemmKernel::StrideC;
69+
using StrideD = StrideC;
70+
using StrideAux = StrideC;
7571

7672
typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
73+
auto [M, N, K, L] = prob_shape;
74+
75+
StrideA a_stride =
76+
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
77+
StrideB b_stride =
78+
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
79+
StrideC c_stride =
80+
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
81+
StrideD d_stride =
82+
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
83+
StrideAux aux_stride = d_stride;
7784

7885
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
7986
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
8087
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
8188
b_stride};
8289

8390
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
91+
// auto d_ptr = static_cast<ElementC*>(out.data_ptr());
8492
typename GemmKernel::EpilogueArguments epilogue_args{
8593
Gemm::Epilogue::prepare_args(
8694
std::forward<EpilogueArgs>(epilogue_params)...),
87-
c_ptr, c_stride, c_ptr, c_stride};
95+
c_ptr, c_stride, c_ptr, d_stride};
8896

8997
cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
9098
epilogue_args);

csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,7 @@ struct cutlass_3x_gemm {
4040
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
4141
float>::type;
4242

43-
using EpilogueDescriptor =
44-
cutlass::epilogue::collective::detail::EpilogueDescriptor<
45-
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
46-
ElementD, EpilogueSchedule>;
47-
48-
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
43+
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
4944

5045
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
5146
using ElementC = void;
@@ -88,4 +83,65 @@ struct cutlass_3x_gemm {
8883
struct GemmKernel : public KernelType {};
8984
};
9085

86+
template <typename ElementAB_, typename ElementD_,
87+
template <typename, typename, typename> typename Epilogue_,
88+
typename TileShape, typename ClusterShape, typename KernelSchedule,
89+
typename EpilogueSchedule>
90+
struct cutlass_3x_gemm_sm100 {
91+
using ElementAB = ElementAB_;
92+
using LayoutA = cutlass::layout::RowMajor;
93+
static constexpr int AlignmentA =
94+
128 / cutlass::sizeof_bits<ElementAB>::value;
95+
96+
using LayoutB = cutlass::layout::ColumnMajor;
97+
static constexpr int AlignmentB =
98+
128 / cutlass::sizeof_bits<ElementAB>::value;
99+
100+
using ElementC = void;
101+
using LayoutC = cutlass::layout::RowMajor;
102+
static constexpr int AlignmentC =
103+
128 / cutlass::sizeof_bits<ElementD_>::value;
104+
105+
using ElementD = ElementD_;
106+
using LayoutD = cutlass::layout::RowMajor;
107+
static constexpr int AlignmentD = AlignmentC;
108+
109+
using ElementAcc =
110+
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
111+
float>::type;
112+
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
113+
114+
// MMA type
115+
using ElementAccumulator = float;
116+
117+
// Epilogue types
118+
using ElementBias = cutlass::half_t;
119+
using ElementCompute = float;
120+
using ElementAux = ElementD;
121+
using LayoutAux = LayoutD;
122+
using ElementAmax = float;
123+
124+
using EVTCompute = typename Epilogue::EVTCompute;
125+
126+
using CollectiveEpilogue =
127+
typename cutlass::epilogue::collective::CollectiveBuilder<
128+
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
129+
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
130+
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
131+
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
132+
EVTCompute>::CollectiveOp;
133+
134+
using CollectiveMainloop =
135+
typename cutlass::gemm::collective::CollectiveBuilder<
136+
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
137+
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
138+
ElementAccumulator, TileShape, ClusterShape,
139+
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
140+
sizeof(typename CollectiveEpilogue::SharedStorage))>,
141+
KernelSchedule>::CollectiveOp;
142+
143+
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
144+
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
145+
};
146+
91147
} // namespace vllm

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,10 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
3636
torch::Tensor const& a_scales,
3737
torch::Tensor const& b_scales);
3838

39+
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
40+
torch::Tensor const& b,
41+
torch::Tensor const& a_scales,
42+
torch::Tensor const& b_scales,
43+
std::optional<torch::Tensor> const& bias);
44+
3945
} // namespace vllm

csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,28 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
119119
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
120120
azp, bias);
121121
}
122+
123+
#if defined CUDA_VERSION && CUDA_VERSION >= 12800
124+
125+
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
126+
torch::Tensor const& b,
127+
torch::Tensor const& a_scales,
128+
torch::Tensor const& b_scales,
129+
std::optional<torch::Tensor> const& bias) {
130+
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
131+
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
132+
133+
int M = a.size(0), N = b.size(1), K = a.size(1);
134+
TORCH_CHECK(
135+
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
136+
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
137+
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
138+
139+
// Standard per-tensor/per-token/per-channel scaling
140+
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
141+
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
142+
"Currently, only fp8 gemm is implemented for Blackwell");
143+
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
144+
}
145+
146+
#endif

0 commit comments

Comments
 (0)