
Commit 7caff01

[Build/BugFix] Fix hopper 12.8 build (#14354)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
1 parent (be0b399) · commit 7caff01

File tree: 4 files changed (+96, -73 lines)

CMakeLists.txt

Lines changed: 45 additions & 33 deletions
@@ -333,36 +333,64 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   " in CUDA target architectures, or CUDA not >= 12.0")
   endif()
 
+
+  set(SCALED_MM_3X_ARCHS)
   # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
-  # CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
-  cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
+  # CUDA 12.0 or later
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
+      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
     list(APPEND VLLM_EXT_SRC "${SRCS}")
-    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
-    message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
+    # Let scaled_mm_c2x know it doesn't need to build these arches
+    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
+    message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
   else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-      message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
                      "later if you intend on running FP8 quantized models on "
                      "Hopper.")
     else()
-      message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
+      message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
                      "in CUDA target architectures")
     endif()
+  endif()
 
-  # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
-  # build any 3x kernels
-  set(SCALED_MM_3X_ARCHS)
+  # The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require
+  # CUDA 12.8 or later
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
+    set(SRCS
+      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
+      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
+    )
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
+    # Let scaled_mm_c2x know it doesn't need to build these arches
+    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
+    message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
+                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
+                     "later if you intend on running FP8 quantized models on "
+                     "Blackwell.")
+    else()
+      message(STATUS "Not building scaled_mm_c3x_sm100 as no compatible archs found "
+                     "in CUDA target architectures")
+    endif()
   endif()
 
   #
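Net effect of the hunk above: the single c3x block is split into an sm90 block (CUDA >= 12.0) and a new sm100 block (CUDA >= 12.8), each defining its own ENABLE_SCALED_MM_SM90 / ENABLE_SCALED_MM_SM100 flag, while SCALED_MM_3X_ARCHS becomes an accumulator that tells the scaled_mm_c2x step which arches are already covered. As a rough mental model of the helper driving the gating (an assumption about its semantics, not vLLM's actual CMake implementation, and ignoring whatever "loose" matching it does for suffixed arches like 9.0a):

#include <algorithm>
#include <string>
#include <vector>

// Hypothetical model of cuda_archs_loose_intersection(OUT, SRC, TGT):
// keep the arches from SRC that the user's CUDA_ARCHS (TGT) covers.
// If the result is empty, the corresponding kernel block is skipped.
std::vector<std::string> loose_intersection(
    const std::vector<std::string>& src_archs,
    const std::vector<std::string>& tgt_archs) {
  std::vector<std::string> out;
  for (const std::string& arch : src_archs) {
    if (std::find(tgt_archs.begin(), tgt_archs.end(), arch) !=
        tgt_archs.end()) {
      out.push_back(arch);  // e.g. "9.0a" survives only if Hopper is targeted
    }
  }
  return out;
}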
@@ -395,16 +423,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
   # require CUDA 12.2 or later (and only work on Hopper and Blackwell).
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
     set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
     list(APPEND VLLM_EXT_SRC "${SRCS}")
     list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
-    message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
+    message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
   else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
       message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
                      "not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
                      "if you intend on running FP8 sparse quantized models on Hopper.")
@@ -432,22 +460,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(FP4_ARCHS)
   endif()
 
-  # FP8 Blackwell Archs
-  cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS)
-    set(SRCS
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
-    )
-    set_gencode_flags_for_srcs(
-      SRCS "${SRCS}"
-      CUDA_ARCHS "${BLACKWELL_ARCHS}")
-    list(APPEND VLLM_EXT_SRC "${SRCS}")
-    message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}")
-  else()
-    # clear BLACKWELL_ARCHS
-    set(BLACKWELL_ARCHS)
-  endif()
-
   #
   # Machete kernels

csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu (new file)

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
+#include <cudaTypedefs.h>
+#include "c3x/scaled_mm_kernels.hpp"
+
+#include "cuda_utils.h"
+
+/*
+   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
+   NVIDIA GPUs with sm100 (Blackwell).
+*/
+
+#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
+
+void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
+                             torch::Tensor const& b,
+                             torch::Tensor const& a_scales,
+                             torch::Tensor const& b_scales,
+                             std::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  int M = a.size(0), N = b.size(1), K = a.size(1);
+  TORCH_CHECK(
+      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
+      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
+
+  // Standard per-tensor/per-token/per-channel scaling
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
+              "Currently, only fp8 gemm is implemented for Blackwell");
+  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
+}
+
+#endif
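For reference, a hypothetical caller sketch (not code from this PR) showing the scale shapes that pass the checks above for an M x K fp8 `a` and a K x N fp8 `b`; anything else, such as 128x128 block scales, trips the first TORCH_CHECK:

#include <torch/torch.h>

// Hypothetical example: shapes accepted by cutlass_scaled_mm_sm100.
void example_scales() {
  const int64_t M = 16, N = 32, K = 64;
  auto fp8 = torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA);
  auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  torch::Tensor a = torch::empty({M, K}, fp8);
  torch::Tensor b = torch::empty({K, N}, fp8);
  // Per-token / per-channel scaling; numel() == 1 would mean per-tensor.
  torch::Tensor a_scales = torch::ones({M}, f32);  // a_scales.numel() == M
  torch::Tensor b_scales = torch::ones({N}, f32);  // b_scales.numel() == N
}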

csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu renamed to csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu

Lines changed: 3 additions & 24 deletions
@@ -5,9 +5,11 @@
 
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
-   NVIDIA GPUs with sm90a (Hopper) or later.
+   NVIDIA GPUs with sm90a (Hopper).
 */
 
+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
+
 void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
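A note on the guard idiom introduced here (my reading, not spelled out in the PR): `#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90` is false both when CMake never defines the flag and when it is defined as 0, and true only for -DENABLE_SCALED_MM_SM90=1. A bare `#if ENABLE_SCALED_MM_SM90` would evaluate the same way, since undefined macros are treated as 0 in preprocessor conditionals, but can trigger -Wundef warnings; the explicit defined test avoids that.

/* Minimal illustration of the guard idiom (preprocessor only): */
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
/* compiled only when the build passed -DENABLE_SCALED_MM_SM90=1 */
#endif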
@@ -72,27 +74,4 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                           azp, bias);
 }
 
-#if defined CUDA_VERSION && CUDA_VERSION >= 12080
-
-void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
-                             torch::Tensor const& b,
-                             torch::Tensor const& a_scales,
-                             torch::Tensor const& b_scales,
-                             std::optional<torch::Tensor> const& bias) {
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-  int M = a.size(0), N = b.size(1), K = a.size(1);
-  TORCH_CHECK(
-      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
-      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
-  // Standard per-tensor/per-token/per-channel scaling
-  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
-              "Currently, only fp8 gemm is implemented for Blackwell");
-  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
-}
-
 #endif

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

Lines changed: 14 additions & 16 deletions
@@ -23,12 +23,15 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);
 
-#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
 void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);
+#endif
+
+#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
 void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& b,
                              torch::Tensor const& a_scales,
@@ -60,7 +63,7 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                 std::optional<torch::Tensor> const& azp,
                                 std::optional<torch::Tensor> const& bias);
 
-#if defined CUDA_VERSION && CUDA_VERSION >= 12000
+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
 void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
@@ -121,26 +124,21 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
 
   at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
   int32_t version_num = get_sm_version_num();
-  // Hopper
-
-  // Guard against compilation issues for sm90 kernels
-#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
 
-#if defined CUDA_VERSION && CUDA_VERSION < 12080
-  if (version_num >= 90 && version_num < 100) {
-    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
+#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
+  if (version_num >= 100) {
+    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
     return;
   }
-#else
+#endif
+
+  // Guard against compilation issues for sm90 kernels
+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
   if (version_num >= 90 && version_num < 100) {
+    // Hopper
     cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
     return;
-  } else if (version_num >= 100) {
-    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
-    return;
   }
-#endif
-
 #endif
 
 #if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
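The restructured dispatch now tries the newest architecture first, with each generation behind its own preprocessor guard. A simplified sketch of the resulting control flow (a hypothetical helper, not code from the PR): because the sm100 call only exists when ENABLE_SCALED_MM_SM100 was set, a CUDA < 12.8 build that compiles only the sm90 kernels, such as the Hopper build named in the commit title, no longer references a cutlass_scaled_mm_sm100 symbol that was never compiled.

#include <cstdint>

// Sketch of cutlass_scaled_mm's dispatch after this change; the c2x
// fallback and the actual kernel arguments are elided.
void dispatch_sketch(int32_t version_num) {
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
  if (version_num >= 100) {
    // Blackwell path: cutlass_scaled_mm_sm100(c, a, b, ...)
    return;
  }
#endif
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  if (version_num >= 90 && version_num < 100) {
    // Hopper path: cutlass_scaled_mm_sm90(c, a, b, ...)
    return;
  }
#endif
  // Otherwise fall through to the older c2x kernels (guarded separately).
}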
@@ -211,7 +209,7 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
 
   int32_t version_num = get_sm_version_num();
 
-#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
   if (version_num >= 90) {
     cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
     return;
