[Build/BugFix] Fix hopper 12.8 build #14354

Merged · 10 commits · Mar 8, 2025
78 changes: 45 additions & 33 deletions CMakeLists.txt
@@ -333,36 +333,64 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures, or CUDA not >= 12.0")
endif()


set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()

# clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
# build any 3x kernels
set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()

#
@@ -395,16 +423,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper and Blackwell).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
"if you intend on running FP8 sparse quantized models on Hopper.")
@@ -432,22 +460,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(FP4_ARCHS)
endif()

# FP8 Blackwell Archs
cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${BLACKWELL_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}")
else()
# clear BLACKWELL_ARCHS
set(BLACKWELL_ARCHS)
endif()

#
# Machete kernels

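A note on how the CMake changes above connect to the CUDA sources below: the -DENABLE_SCALED_MM_SM90=1 and -DENABLE_SCALED_MM_SM100=1 flags appended to VLLM_GPU_FLAGS are consumed by a defined-and-truthy preprocessor guard in each .cu file. A minimal illustration of that guard pattern (not code from the PR itself):

// Illustration only: the macro must be both defined and expand to a non-zero
// value, so leaving the flag out (or passing =0) compiles the block away.
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
// Hopper-only (sm90) entry points are compiled into the extension here.
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
// Blackwell-only (sm100) entry points are compiled into the extension here.
#endif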
34 changes: 34 additions & 0 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
@@ -0,0 +1,34 @@
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"

#include "cuda_utils.h"

/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell).
*/

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100

void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

int M = a.size(0), N = b.size(1), K = a.size(1);
TORCH_CHECK(
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell");

// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
"Currently, only fp8 gemm is implemented for Blackwell");
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
}

#endif
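For orientation, the new translation unit above only performs the dtype and scale-shape checks; the actual kernel launch is delegated to vllm::cutlass_scaled_mm_sm100_fp8, which the included c3x/scaled_mm_kernels.hpp is assumed to declare. A minimal sketch of that declaration, inferred from the call site rather than copied from the header:

// Sketch only: inferred from the call in scaled_mm_c3x_sm100.cu above, not
// copied from c3x/scaled_mm_kernels.hpp.
#include <optional>
#include <torch/all.h>

namespace vllm {

// FP8 (e4m3fn) GEMM for sm100 (Blackwell) with per-tensor, per-token, or
// per-channel scales:
//   a: [M, K] fp8, b: [K, N] fp8, c: [M, N] output
//   a_scales: float32, numel() == 1 (per-tensor) or M (per-token)
//   b_scales: float32, numel() == 1 (per-tensor) or N (per-channel)
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);

}  // namespace vllm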
csrc/quantization/cutlass_w8a8/{scaled_mm_c3x.cu → scaled_mm_c3x_sm90.cu}
@@ -5,9 +5,11 @@

/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
NVIDIA GPUs with sm90a (Hopper).
*/

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90

void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
@@ -72,27 +74,4 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
azp, bias);
}

#if defined CUDA_VERSION && CUDA_VERSION >= 12080

void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

int M = a.size(0), N = b.size(1), K = a.size(1);
TORCH_CHECK(
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell");

// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
"Currently, only fp8 gemm is implemented for Blackwell");
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
}

#endif
30 changes: 14 additions & 16 deletions csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -23,12 +23,15 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
@@ -60,7 +63,7 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
@@ -121,26 +124,21 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,

at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Hopper

// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X

#if defined CUDA_VERSION && CUDA_VERSION < 12080
if (version_num >= 90 && version_num < 100) {
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
if (version_num >= 100) {
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
return;
}
#else
#endif

// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90 && version_num < 100) {
// Hopper
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
} else if (version_num >= 100) {
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
return;
}
#endif

#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
@@ -211,7 +209,7 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,

int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
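Because the dispatch hunk above interleaves removed and added lines, here is a consolidated sketch of how the cutlass_scaled_mm dispatch reads after this patch (simplified; argument checks omitted and not verbatim): the Blackwell path is tried first when ENABLE_SCALED_MM_SM100 was compiled in, then the Hopper path, and older architectures such as sm89 fall through to the existing CUTLASS 2.x handling.

// Simplified sketch of the post-patch dispatch in
// csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu.
int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
if (version_num >= 100) {
  // Blackwell (sm100 and newer)
  cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
  return;
}
#endif

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90 && version_num < 100) {
  // Hopper (sm90)
  cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
  return;
}
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
// Older architectures (e.g. sm89) fall through to the CUTLASS 2.x kernels,
// which this PR leaves unchanged.
#endif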