Skip to content

Commit dd2e3ea

Browse files
ProExpertProg, cyang49, and LucasWilkinson
authored and committed
[Kernel] Adding bias epilogue support for cutlass_scaled_mm (vllm-project#5560)
Co-authored-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 9cdb22d commit dd2e3ea

File tree

8 files changed

+383
-134
lines changed

8 files changed

+383
-134
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.21)
22

33
project(vllm_extensions LANGUAGES CXX)
44

5-
option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
5+
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
6+
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
67

78
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
89
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")

csrc/ops.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
9696

9797
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
9898
torch::Tensor const& b, torch::Tensor const& a_scales,
99-
torch::Tensor const& b_scales);
99+
torch::Tensor const& b_scales,
100+
c10::optional<torch::Tensor> const& bias);
100101

101102
#endif
102103

csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu

Lines changed: 170 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -77,31 +77,45 @@ struct enable_sm89_to_sm90 : Kernel {
7777
};
7878

7979
/*
80-
This epilogue function defines a quantized GEMM operation similar to
81-
torch._scaled_mm.
82-
83-
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
84-
per-row. B can be quantized per-tensor or per-column.
85-
Any combination of per-tensor and per-row or column is supported.
86-
A and B must have symmetric quantization (zero point == 0).
87-
88-
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
89-
scales are applied elementwise with numpy-style broadcasting.
90-
91-
ScaleA and ScaleB define the epilogue functions that apply the scales for
92-
the A and B operands respectively. These scales may be either per-tensor or
93-
per row or column.
94-
*/
80+
* This class provides the common ScaleA and ScaleB descriptors for the
81+
* ScaledEpilogue and ScaledEpilogueBias classes.
82+
*/
9583
template <typename ElementD, typename OutputTileThreadMap>
96-
struct ScaledEpilogue {
97-
private:
84+
struct ScaledEpilogueBase {
85+
protected:
9886
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
9987

10088
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
10189
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
10290

10391
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
10492
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
93+
};
94+
95+
/*
96+
This epilogue function defines a quantized GEMM operation similar to
97+
torch._scaled_mm.
98+
99+
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
100+
per-row. B can be quantized per-tensor or per-column.
101+
Any combination of per-tensor and per-row or column is supported.
102+
A and B must have symmetric quantization (zero point == 0).
103+
104+
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
105+
scales are applied elementwise with numpy-style broadcasting.
106+
107+
ScaleA and ScaleB define the epilogue functions that apply the scales for
108+
the A and B operands respectively. These scales may be either per-tensor or
109+
per row or column.
110+
*/
111+
template <typename ElementD, typename OutputTileThreadMap>
112+
struct ScaledEpilogue
113+
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
114+
private:
115+
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
116+
using Accum = typename SUPER::Accum;
117+
using ScaleA = typename SUPER::ScaleA;
118+
using ScaleB = typename SUPER::ScaleB;
105119

106120
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
107121
cutlass::multiplies, float, float,
@@ -134,6 +148,53 @@ struct ScaledEpilogue {
134148
}
135149
};
136150

151+
template <typename ElementD, typename OutputTileThreadMap>
152+
struct ScaledEpilogueBias
153+
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
154+
private:
155+
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
156+
using Accum = typename SUPER::Accum;
157+
using ScaleA = typename SUPER::ScaleA;
158+
using ScaleB = typename SUPER::ScaleB;
159+
160+
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
161+
cutlass::multiplies, float, float,
162+
cutlass::FloatRoundStyle::round_to_nearest>;
163+
164+
using EVTCompute0 =
165+
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
166+
167+
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
168+
cutlass::multiply_add, ElementD, float,
169+
cutlass::FloatRoundStyle::round_to_nearest>;
170+
171+
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
172+
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
173+
174+
public:
175+
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
176+
EVTCompute0, Bias>;
177+
using ArgumentType = typename EVTCompute::Arguments;
178+
179+
static ArgumentType prepare_args(torch::Tensor const& a_scales,
180+
torch::Tensor const& b_scales,
181+
torch::Tensor const& bias) {
182+
using ScaleAArgs = typename ScaleA::Arguments;
183+
using ScaleBArgs = typename ScaleB::Arguments;
184+
using BiasArgs = typename Bias::Arguments;
185+
186+
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
187+
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
188+
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
189+
190+
typename EVTCompute0::Arguments evt0_compute_args{b_args};
191+
192+
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
193+
bias_args};
194+
return evt_compute_args;
195+
}
196+
};
197+
137198
template <typename Arch, template <typename> typename ArchGuard,
138199
typename ElementAB_, typename ElementD_,
139200
template <typename, typename> typename Epilogue_, typename TileShape,
@@ -168,13 +229,13 @@ struct cutlass_2x_gemm {
168229
// clang-format off
169230
using RowMajor = typename cutlass::layout::RowMajor;
170231
using ColumnMajor = typename cutlass::layout::ColumnMajor;
171-
using KernelType =
232+
using KernelType =
172233
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
173-
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
174-
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
234+
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
235+
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
175236
float, cutlass::layout::RowMajor, 4,
176-
ElementAcc, float, cutlass::arch::OpClassTensorOp,
177-
Arch,
237+
ElementAcc, float, cutlass::arch::OpClassTensorOp,
238+
Arch,
178239
TileShape, WarpShape, InstructionShape,
179240
EVTD,
180241
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
@@ -404,14 +465,13 @@ void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
404465
}
405466
}
406467

407-
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
408-
torch::Tensor const& b,
409-
torch::Tensor const& a_scales,
410-
torch::Tensor const& b_scales) {
468+
template <template <typename, typename> typename Epilogue,
469+
typename... EpilogueArgs>
470+
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
471+
torch::Tensor const& b,
472+
EpilogueArgs&&... epilogue_args) {
411473
TORCH_CHECK(a.dtype() == torch::kInt8);
412474
TORCH_CHECK(b.dtype() == torch::kInt8);
413-
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
414-
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
415475

416476
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
417477
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
@@ -420,78 +480,130 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
420480
if (out.dtype() == torch::kBFloat16) {
421481
return cutlass_gemm_caller<cutlass_2x_gemm<
422482
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
423-
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
424-
out, a, b, a_scales, b_scales);
483+
Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
484+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
425485
} else {
426486
TORCH_CHECK(out.dtype() == torch::kFloat16);
427487
return cutlass_gemm_caller<cutlass_2x_gemm<
428488
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
429-
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
430-
out, a, b, a_scales, b_scales);
489+
Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
490+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
431491
}
432492
}
433493

434-
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
494+
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
435495
torch::Tensor const& b,
436496
torch::Tensor const& a_scales,
437-
torch::Tensor const& b_scales) {
438-
TORCH_CHECK(a.dtype() == torch::kInt8);
439-
TORCH_CHECK(b.dtype() == torch::kInt8);
497+
torch::Tensor const& b_scales,
498+
c10::optional<torch::Tensor> const& bias) {
440499
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
441500
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
501+
if (bias) {
502+
TORCH_CHECK(bias->dtype() == out.dtype(),
503+
"currently bias dtype must match output dtype ", out.dtype());
504+
return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueBias>(
505+
out, a, b, a_scales, b_scales, *bias);
506+
} else {
507+
return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales,
508+
b_scales);
509+
}
510+
}
511+
512+
template <template <typename, typename> typename Epilogue,
513+
typename... EpilogueArgs>
514+
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
515+
torch::Tensor const& b,
516+
EpilogueArgs&&... epilogue_args) {
517+
TORCH_CHECK(a.dtype() == torch::kInt8);
518+
TORCH_CHECK(b.dtype() == torch::kInt8);
442519

443520
if (out.dtype() == torch::kBFloat16) {
444-
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
445-
ScaledEpilogue>(out, a, b, a_scales,
446-
b_scales);
521+
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
522+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
447523
} else {
448524
TORCH_CHECK(out.dtype() == torch::kFloat16);
449-
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, ScaledEpilogue>(
450-
out, a, b, a_scales, b_scales);
525+
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
526+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
451527
}
452528
}
453529

454-
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
530+
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
455531
torch::Tensor const& b,
456532
torch::Tensor const& a_scales,
457-
torch::Tensor const& b_scales) {
533+
torch::Tensor const& b_scales,
534+
c10::optional<torch::Tensor> const& bias) {
535+
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
536+
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
537+
if (bias) {
538+
TORCH_CHECK(bias->dtype() == out.dtype(),
539+
"currently bias dtype must match output dtype ", out.dtype());
540+
return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogueBias>(
541+
out, a, b, a_scales, b_scales, *bias);
542+
} else {
543+
return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogue>(out, a, b, a_scales,
544+
b_scales);
545+
}
546+
}
547+
548+
template <template <typename, typename> typename Epilogue,
549+
typename... EpilogueArgs>
550+
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
551+
torch::Tensor const& b,
552+
EpilogueArgs&&... epilogue_args) {
458553
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
459554
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
460555
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
461556

462-
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
463-
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
464-
465557
if (a.dtype() == torch::kInt8) {
466558
TORCH_CHECK(b.dtype() == torch::kInt8);
467559

468560
if (out.dtype() == torch::kBFloat16) {
469561
return cutlass_gemm_caller<cutlass_2x_gemm<
470562
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
471-
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
472-
out, a, b, a_scales, b_scales);
563+
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
564+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
473565
} else {
474566
assert(out.dtype() == torch::kFloat16);
475567
return cutlass_gemm_caller<cutlass_2x_gemm<
476568
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
477-
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
478-
out, a, b, a_scales, b_scales);
569+
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
570+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
479571
}
480572
} else {
481573
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
482574
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
483575

484576
if (out.dtype() == torch::kBFloat16) {
485-
return cutlass_gemm_caller<cutlass_2x_gemm<
486-
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
487-
cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
488-
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
577+
return cutlass_gemm_caller<
578+
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
579+
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue,
580+
TileShape, WarpShape, InstructionShape, 5>>(
581+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
489582
} else {
490583
TORCH_CHECK(out.dtype() == torch::kFloat16);
491-
return cutlass_gemm_caller<cutlass_2x_gemm<
492-
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
493-
cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
494-
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
584+
return cutlass_gemm_caller<
585+
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
586+
cutlass::float_e4m3_t, cutlass::half_t, Epilogue,
587+
TileShape, WarpShape, InstructionShape, 5>>(
588+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
495589
}
496590
}
497591
}
592+
593+
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
594+
torch::Tensor const& b,
595+
torch::Tensor const& a_scales,
596+
torch::Tensor const& b_scales,
597+
c10::optional<torch::Tensor> const& bias) {
598+
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
599+
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
600+
if (bias) {
601+
TORCH_CHECK(bias->dtype() == out.dtype(),
602+
"currently bias dtype must match output dtype ", out.dtype());
603+
return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogueBias>(
604+
out, a, b, a_scales, b_scales, *bias);
605+
} else {
606+
return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogue>(out, a, b, a_scales,
607+
b_scales);
608+
}
609+
}

0 commit comments

Comments
 (0)