Skip to content

Commit f9f1dc3

Browse files
remove overlap in device architectures for cutlass_scaled_mm
1 parent 475e57a commit f9f1dc3

File tree

2 files changed

+79
-33
lines changed

2 files changed

+79
-33
lines changed

CMakeLists.txt

+31-11
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
239239
"csrc/quantization/gguf/gguf_kernel.cu"
240240
"csrc/custom_all_reduce.cu"
241241
"csrc/permute_cols.cu"
242-
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
243-
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
242+
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
244243

245244
set_gencode_flags_for_srcs(
246245
SRCS "${VLLM_EXT_SRC}"
@@ -268,26 +267,47 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
268267
endif()
269268

270269
#
271-
# The CUTLASS kernels for Hopper require sm90a to be enabled.
272-
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
273-
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
274-
# Only build scaled_mm_c3x if we are building for something compatible with sm90a
275-
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
276-
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
277-
270+
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
271+
# CUDA 12.0 or later (and only works on Hopper, 9.0/9.0a for now).
272+
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
273+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
278274
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
279275
set_gencode_flags_for_srcs(
280276
SRCS "${SRCS}"
281-
CUDA_ARCHS "${SCALED_MM_ARCHS}")
277+
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
282278
list(APPEND VLLM_EXT_SRC "${SRCS}")
283279
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
284-
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
280+
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
285281
else()
282+
# clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
283+
# build any 3x kernels
284+
set(SCALED_MM_3X_ARCHS)
286285
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
287286
"in CUDA target architectures or CUDA Compiler version is "
288287
"not >= 12.0")
289288
endif()
290289

290+
#
291+
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
292+
# kernels for the remaining archs that are not already built for 3x.
293+
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
294+
"7.5;8.0;8.6;8.9;9.0;9.0a" "${CUDA_ARCHS}")
295+
# subtract out the archs that are already built for 3x
296+
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
297+
if (SCALED_MM_2X_ARCHS)
298+
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
299+
set_gencode_flags_for_srcs(
300+
SRCS "${SRCS}"
301+
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
302+
list(APPEND VLLM_EXT_SRC "${SRCS}")
303+
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
304+
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
305+
else()
306+
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
307+
"in CUDA target architectures (or archs are already built "
308+
"for 3x)")
309+
endif()
310+
291311

292312
#
293313
# Machete kernels

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

+48-22
Original file line numberDiff line numberDiff line change
@@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
114114

115115
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
116116
int32_t version_num = get_sm_version_num();
117-
if (version_num >= 90) {
118-
// Hopper
117+
// Hopper
119118

120-
// Guard against compilation issues for sm90 kernels
119+
// Guard against compilation issues for sm90 kernels
121120
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
121+
if (version_num >= 90) {
122122
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
123-
#else
124-
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
123+
return;
124+
}
125125
#endif
126-
} else if (version_num == 89) {
126+
127+
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
128+
if (version_num == 89) {
127129
// Ada Lovelace
128130
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
129-
} else if (version_num >= 80) {
131+
return;
132+
}
133+
134+
if (version_num >= 80) {
130135
// Ampere
131136
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
132-
} else {
133-
// Turing
134-
TORCH_CHECK(version_num >= 75);
135-
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
137+
return;
136138
}
139+
140+
// Turing
141+
TORCH_CHECK(version_num >= 75);
142+
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
143+
#endif
144+
145+
TORCH_CHECK_NOT_IMPLEMENTED(
146+
false,
147+
"No compiled cutlass_scaled_mm for a compute capability less than "
148+
"CUDA device capability: ",
149+
version_num);
137150
}
138151

139152
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
@@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
174187
"currently bias dtype must match output dtype ", c.dtype());
175188

176189
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
190+
177191
int32_t version_num = get_sm_version_num();
178-
if (version_num >= 90) {
179-
// Hopper
180192

181-
// Guard against compilation issues for sm90 kernels
182193
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
194+
if (version_num >= 90) {
183195
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
184-
#else
185-
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
196+
return;
197+
}
186198
#endif
187-
} else if (version_num == 89) {
199+
200+
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
201+
if (version_num == 89) {
188202
// Ada Lovelace
189203
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
190-
} else if (version_num >= 80) {
204+
return;
205+
}
206+
207+
if (version_num >= 80) {
191208
// Ampere
192209
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
193-
} else {
194-
// Turing
195-
TORCH_CHECK(version_num >= 75);
196-
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
210+
return;
197211
}
212+
213+
// Turing
214+
TORCH_CHECK(version_num >= 75);
215+
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
216+
return;
217+
#endif
218+
219+
TORCH_CHECK_NOT_IMPLEMENTED(
220+
false,
221+
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
222+
"CUDA device capability: ",
223+
version_num);
198224
}

0 commit comments

Comments
 (0)