
Commit 994e966

Merge remote-tracking branch 'upstream/main' into lwilkinson/aux-fast-api-clean
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
2 parents: 2566c79 + 6dbc6e0

File tree: 9 files changed (+357, -57 lines)

CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -251,7 +251,7 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
     hopper/flash_api_torch_lib.cpp
     ${FA3_GEN_SRCS}
   COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
-  ARCHITECTURES ${VLLM_FA_GPU_ARCHES}
+  ARCHITECTURES "" # LucasW: this is ignored for cuda and set on a per-file basis
   USE_SABI 3
   WITH_SOABI)
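Editorial note (not part of the diff): leaving ARCHITECTURES empty works because gencode flags are then attached per source file rather than per target, presumably via set_gencode_flag_for_srcs in cmake/utils.cmake. A minimal sketch of that mechanism, with a hypothetical file name:

  # Sketch only: attach an arch-specific gencode flag to one CUDA source so the
  # target-level ARCHITECTURES property can stay empty.
  set_property(
    SOURCE hopper/flash_fwd_sm90.cu   # hypothetical file name
    APPEND PROPERTY COMPILE_OPTIONS
    "-gencode" "arch=compute_90a,code=sm_90a")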

cmake/utils.cmake

Lines changed: 82 additions & 25 deletions

@@ -62,8 +62,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
   #
   set(SRCS ${ORIG_SRCS})
   set(CXX_SRCS ${ORIG_SRCS})
-  list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
-  list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
+  list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
+  list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)|(hip)$")

   #
   # Generate ROCm/HIP source file names from CUDA file names.
@@ -80,7 +80,7 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
   set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc)
   add_custom_target(
     hipify${NAME}
-    COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
+    COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
     DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
     BYPRODUCTS ${HIP_SRCS}
     COMMENT "Running hipify on ${NAME} extension source files.")
@@ -232,11 +232,26 @@ macro(set_gencode_flags_for_srcs)
     "${multiValueArgs}" ${ARGN} )

   foreach(_ARCH ${arg_CUDA_ARCHS})
-    string(REPLACE "." "" _ARCH "${_ARCH}")
-    set_gencode_flag_for_srcs(
-      SRCS ${arg_SRCS}
-      ARCH "compute_${_ARCH}"
-      CODE "sm_${_ARCH}")
+    # handle +PTX suffix: generate both sm and ptx codes if requested
+    string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
+    if(NOT _HAS_PTX EQUAL -1)
+      string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
+      string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "sm_${_STRIPPED_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "compute_${_STRIPPED_ARCH}")
+    else()
+      string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
+      set_gencode_flag_for_srcs(
+        SRCS ${arg_SRCS}
+        ARCH "compute_${_STRIPPED_ARCH}"
+        CODE "sm_${_STRIPPED_ARCH}")
+    endif()
   endforeach()

   if (${arg_BUILD_PTX_FOR_ARCH})
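Editorial aside (not part of the commit): a hedged usage sketch of the macro above. An entry like "8.0+PTX" now emits both a SASS and a PTX gencode pair, while plain entries such as "9.0a" keep emitting SASS only; ${EXAMPLE_SRCS} is a hypothetical source list.

  set_gencode_flags_for_srcs(
    SRCS "${EXAMPLE_SRCS}"              # hypothetical list of .cu files
    CUDA_ARCHS "8.0+PTX;9.0a")
  # Roughly the resulting per-file NVCC flags:
  #   -gencode arch=compute_80,code=sm_80
  #   -gencode arch=compute_80,code=compute_80   (embedded PTX, JIT-able on newer GPUs)
  #   -gencode arch=compute_90a,code=sm_90a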
@@ -255,15 +270,18 @@ endmacro()
 #
 # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
 # `<major>.<minor>[letter]` compute the "loose intersection" with the
-# `TGT_CUDA_ARCHS` list of gencodes.
+# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
+# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
+# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
+# architecture in `SRC_CUDA_ARCHS`.
 # The loose intersection is defined as:
 #   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
 # where `<=` is the version comparison operator.
 # In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
 # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
-# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
-# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
-# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
+# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
+# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
+# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
 # The result is stored in `OUT_CUDA_ARCHS`.
 #
 # Example:
@@ -272,36 +290,63 @@ endmacro()
 # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
 #   OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
 #
+# Example With PTX:
+#   SRC_CUDA_ARCHS="8.0+PTX"
+#   TGT_CUDA_ARCHS="9.0"
+# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
+#   OUT_CUDA_ARCHS="8.0+PTX"
+#
 function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
-  list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
-  set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
+  set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
+  set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
+
+  # handle +PTX suffix: separate base arch for matching, record PTX requests
+  set(_PTX_ARCHS)
+  foreach(_arch ${_SRC_CUDA_ARCHS})
+    if(_arch MATCHES "\\+PTX$")
+      string(REPLACE "+PTX" "" _base "${_arch}")
+      list(APPEND _PTX_ARCHS "${_base}")
+      list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
+      list(APPEND _SRC_CUDA_ARCHS "${_base}")
+    endif()
+  endforeach()
+  list(REMOVE_DUPLICATES _PTX_ARCHS)
+  list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)

-  # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
-  # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
+  # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
+  # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
   set(_CUDA_ARCHS)
-  if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
-    list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
-    if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
-      list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
+  if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
+    if ("9.0" IN_LIST TGT_CUDA_ARCHS)
+      list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
       set(_CUDA_ARCHS "9.0a")
     endif()
   endif()

-  list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
+  if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
+    if ("10.0" IN_LIST TGT_CUDA_ARCHS)
+      list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
+      set(_CUDA_ARCHS "10.0a")
+    endif()
+  endif()
+
+  list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)

   # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
   # is less or equal to ARCH (but has the same major version since SASS binary
   # compatibility is only forward compatible within the same major version).
-  foreach(_ARCH ${TGT_CUDA_ARCHS_})
+  foreach(_ARCH ${_TGT_CUDA_ARCHS})
     set(_TMP_ARCH)
     # Extract the major version of the target arch
     string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
-    foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
+    foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
       # Extract the major version of the source arch
       string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
-      # Check major-version match AND version-less-or-equal
+      # Check version-less-or-equal, and allow PTX arches to match across majors
       if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
-        if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
+        if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
           set(_TMP_ARCH "${_SRC_ARCH}")
         endif()
       else()
@@ -317,6 +362,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
   endforeach()

   list(REMOVE_DUPLICATES _CUDA_ARCHS)
+
+  # reapply +PTX suffix to architectures that requested PTX
+  set(_FINAL_ARCHS)
+  foreach(_arch ${_CUDA_ARCHS})
+    if(_arch IN_LIST _PTX_ARCHS)
+      list(APPEND _FINAL_ARCHS "${_arch}+PTX")
+    else()
+      list(APPEND _FINAL_ARCHS "${_arch}")
+    endif()
+  endforeach()
+  set(_CUDA_ARCHS ${_FINAL_ARCHS})
+
   set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
 endfunction()
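Editorial aside (not part of the commit): the PTX example in the comment block above can be exercised directly. A sketch with illustrative variable names, where the expected result follows the documented semantics:

  set(EXAMPLE_SRC_ARCHS "7.5;8.0+PTX;9.0a")   # arches the kernels are written for
  set(EXAMPLE_TGT_ARCHS "8.6;9.0")            # arches the user asked to build for
  cuda_archs_loose_intersection(EXAMPLE_OUT "${EXAMPLE_SRC_ARCHS}" "${EXAMPLE_TGT_ARCHS}")
  # Expected: EXAMPLE_OUT holds "9.0a" (the x.0a special case consumes 9.0) and
  # "8.0+PTX" (8.0 is the highest source arch <= 8.6; the +PTX request is
  # re-applied on output).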

csrc/flash_attn/flash_api_sparse.cpp

Lines changed: 7 additions & 13 deletions

@@ -157,17 +157,14 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea
                std::optional<at::Generator> gen_) {

     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
-    bool is_sm8x = cc_major == 8 && cc_minor >= 0;
-    bool is_sm90 = cc_major == 9 && cc_minor == 0;
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
-    // We will support Turing in the near future
-    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
+    bool is_sm8x_min = cc_major >= 8;
+    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm8x_min, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
@@ -342,17 +339,14 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
                       std::optional<at::Generator> gen_) {

     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
-    bool is_sm8x = cc_major == 8 && cc_minor >= 0;
-    bool is_sm90 = cc_major == 9 && cc_minor == 0;
-    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
-    // We will support Turing in the near future
-    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
+    bool is_sm8x_min = cc_major >= 8;
+    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
     if (q_dtype == torch::kBFloat16) {
-        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+        TORCH_CHECK(is_sm8x_min, "bfloat16 is only supported on Ampere GPUs or newer");
     }
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
@@ -528,4 +522,4 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
     return {out, softmax_lse};
 }

-} // namespace FLASH_NAMESPACE
+} // namespace FLASH_NAMESPACE

hopper/flash_api_torch_lib.cpp

Lines changed: 4 additions & 2 deletions

@@ -51,7 +51,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         std::optional<at::Tensor> &scheduler_metadata_, // (b + 1)
         int num_splits,
         std::optional<bool> pack_gqa_,
-        int const sm_margin
+        int const sm_margin,
+        std::optional<const at::Tensor> &s_aux_
         );

 // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
@@ -118,7 +119,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
             " Tensor? scheduler_metadata,"
             " int num_splits,"
             " bool? pack_gqa,"
-            " int sm_margin) -> Tensor[]");
+            " int sm_margin,"
+            " Tensor? s_aux) -> Tensor[]");
   ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

   ops.def("get_scheduler_metadata("
