Get flash_attn to compile for CUDA 11.6 linux nightly build (pytorch#84941)

This PR only attempts to get this code to compile for all archs so that we can dispatch to it in pytorch#84653.
Pull Request resolved: pytorch#84941
Approved by: https://github.com/drisspg, https://github.com/malfet
cpuhrsch authored and pytorchmergebot committed Sep 26, 2022
1 parent 1543532 commit 6a04df3
Showing 14 changed files with 36 additions and 28 deletions.
5 changes: 4 additions & 1 deletion CMakeLists.txt
@@ -725,7 +725,10 @@ set(BUILD_ONEDNN_GRAPH OFF)
include(cmake/Dependencies.cmake)

# Moved this cmake set option down here because CMAKE_CUDA_COMPILER_VERSION is not available until now
-option(USE_FLASH_ATTENTION "Whether to build the flash_attention kernel for scaled dot product attention" OFF)
+cmake_dependent_option(
+  USE_FLASH_ATTENTION
+  "Whether to build the flash_attention kernel for scaled dot product attention" ON
+  "USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
if(USE_FLASH_ATTENTION)
ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION)
ENDIF()
@@ -63,7 +63,7 @@ struct FMHAEpilogue {
Element, ElementC, /*ElementsPerAccess=*/4, ThreadblockShape, typename WarpMma::Shape,
typename WarpMma::Policy::Operator::Shape, typename OutputTileThreadMap::CompactedThreadMap>;
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
-static_assert(WarpTileIterator::kIterations == kIterationsStore);
+static_assert(WarpTileIterator::kIterations == kIterationsStore, "");
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
using OutputFragment = typename SharedLoadIterator::Fragment;

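Most of the churn in this commit is the pattern shown in the hunk above: every one-argument static_assert(cond) gains an explicit empty message. The message-less form is only standard from C++17 on, while static_assert(cond, "") is valid since C++11, so adding the "" keeps these headers building under toolchains that compile the CUDA sources in pre-C++17 mode (presumably the configuration hit by the CUDA 11.6 nightly build this PR targets). A minimal stand-alone illustration, not taken from the repository:

#include <type_traits>

template <typename T>
void requires_trivial() {
  // Two-argument form: accepted in C++11/14 as well as C++17.
  static_assert(std::is_trivial<T>::value, "");
#if __cplusplus >= 201703L
  // One-argument form: only valid from C++17 on, so older build modes reject it.
  static_assert(std::is_trivial<T>::value);
#endif
}

int main() {
  requires_trivial<int>();
  return 0;
}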
2 changes: 1 addition & 1 deletion aten/src/ATen/native/transformers/cuda/flash_attn/fmha.h
@@ -151,4 +151,4 @@ struct Launch_params{

////////////////////////////////////////////////////////////////////////////////////////////////////

-void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
+TORCH_API void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
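The only change to fmha.h is the TORCH_API attribute on run_fmha_fprop (mha_fwd in a later hunk gets the same treatment). It marks the function as part of the library's exported interface so that, once these sources are linked into the CUDA library, the symbol stays visible to callers in other translation units and across the shared-library boundary. TORCH_API itself expands to compiler-specific export/visibility attributes; the sketch below only shows the general pattern and is not PyTorch's actual definition:

// Illustrative export macro in the spirit of TORCH_API (all names made up).
#define MYLIB_BUILD_MAIN_LIB 1  // pretend we are building the library itself
#if defined(_WIN32)
#  if defined(MYLIB_BUILD_MAIN_LIB)
#    define MYLIB_API __declspec(dllexport)
#  else
#    define MYLIB_API __declspec(dllimport)
#  endif
#else
#  define MYLIB_API __attribute__((visibility("default")))
#endif

// A header would declare the entry point like this...
MYLIB_API void run_attention_kernel(int device_index);

// ...and the file that defines it repeats the attribute, as the
// run_fmha_fprop definition does later in this diff.
MYLIB_API void run_attention_kernel(int device_index) { (void)device_index; }

int main() { run_attention_kernel(0); return 0; }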
@@ -26,6 +26,7 @@
*
******************************************************************************/

+#ifdef USE_FLASH_ATTENTION
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/NativeFunctions.h>
@@ -241,3 +242,4 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
return result;
}
} // namespace fmha
+#endif
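Wrapping all of fmha_api.cpp in #ifdef USE_FLASH_ATTENTION means the translation unit compiles to an empty object whenever the new CMake option is off, which is what lets the file be listed unconditionally in build_variables.bzl (last hunk of this diff) without breaking ROCm, MSVC, or CUDA < 11.6 builds. A stand-alone sketch of the same gating idea follows; the fallback branch is illustrative only, since the PR itself simply compiles the file out and leaves dispatch to pytorch#84653:

// Sketch of gating an implementation behind a build definition such as
// -DUSE_FLASH_ATTENTION. Names are illustrative, not PyTorch's.
#include <cstdio>
#include <stdexcept>

#ifdef USE_FLASH_ATTENTION
void fused_attention_forward() {
  std::puts("flash-attention implementation would run here");
}
#else
// Without the definition, the real implementation does not exist in the
// object file at all; this stub only keeps the sketch self-contained.
void fused_attention_forward() {
  throw std::runtime_error("built without flash attention support");
}
#endif

int main() {
  try {
    fused_attention_forward();
  } catch (const std::exception &e) {
    std::printf("unsupported: %s\n", e.what());
  }
  return 0;
}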
@@ -6,6 +6,7 @@

namespace fmha {

+TORCH_API
std::vector<at::Tensor>
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -102,7 +102,7 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {

static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
// If V is stored in shared memory, we can't load K using the same shared memory.
-static_assert(Kernel_traits::V_IN_REGS);
+static_assert(Kernel_traits::V_IN_REGS, "");

static constexpr size_t SMEM_OFFSET_O = Kernel_traits::BYTES_PER_SMEM_Q;
static constexpr size_t SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + sizeof(typename Smem_O::SharedStorage);
@@ -161,7 +161,7 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {

static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
-static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V);
+static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V, "");

static constexpr size_t SMEM_OFFSET_V = Kernel_traits::BYTES_PER_SMEM_Q + (SHARE_SMEM_FOR_K_AND_V ? 0 : Kernel_traits::BYTES_PER_SMEM_K);
static constexpr size_t SMEM_OFFSET_O = SMEM_OFFSET_V + Kernel_traits::BYTES_PER_SMEM_V;
@@ -298,7 +298,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);

// Wind gmem tiles to the correct position.
-static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
+static_assert(Cta_tile_p::N % Cta_tile_p::M == 0, "");
const int begin_og = begin;
begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
const int steps_og = steps;
@@ -428,7 +428,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
const int warp_idx = threadIdx.x / 32;
iter_V.add_tile_offset({kIterationsPV * warp_idx, 0});
typename WarpIteratorV::Fragment frag_v[kIterationsPV];
-static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N );
+static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N, "");
#pragma unroll
for( int ki = 0; ki < kIterationsPV; ++ki ) {
iter_V.load(frag_v[ki]);
@@ -463,8 +463,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
gemm_q_k(mma_qk, acc_p);

typename Smem_O::OutputFragment out[Smem_O::kIterationsStore];
-static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore);
-static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore);
+static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore, "");
+static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore, "");
if (!Is_first) {
#pragma unroll
for (int iter = 0; iter < GmemIteratorOAccum::kIterations; ++iter) {
@@ -536,8 +536,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
}

-static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
-static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
+static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M, "");
+static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N, "");
softmax.pack_noconvert(acc_p);
cutlass::NumericArrayConverter<Element, ElementAccum, decltype(acc_p)::kElements, cutlass::FloatRoundStyle::round_to_nearest> convert_p;
auto frag_p = convert_p(acc_p);
@@ -558,13 +558,13 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Declare the accumulators for the 2nd gemm.
WarpMmaPV mma_pv;
typename WarpMmaPV::FragmentC acc_o;
-static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8);
+static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8, "");
acc_o.clear();

// For some reason, WarpMmaPV::FragmentA has length K * N * (8|4) instead of just N * (8|4).
// We have to first cast frag_p to be array of k x (N * (8|4)), then cast each row to be
// an array of WarpMmaPV::FragmentA (which is what mma_pv expects).
-static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements);
+static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements, "");
const auto frag_p_reshaped = reinterpret_cast<const cutlass::Array<Element, WarpMmaPV::FragmentA::kElements> (&)[kIterationsPV]>(frag_p);
#pragma unroll
for( int ki = 0; ki < kIterationsPV; ++ki ) {
@@ -589,7 +589,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
}

softmax.reduce_max_after_sync_(p_max_o, rows);
-static_assert(Mma_tile_o::MMAS_M == 1);
+static_assert(Mma_tile_o::MMAS_M == 1, "");
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
p_max_o[jj][0] *= params.scale_bmm1;
}
@@ -601,7 +601,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Make sure the data is in shared memory.
__syncthreads();

-static_assert(Mma_tile_o::MMAS_M == 1);
+static_assert(Mma_tile_o::MMAS_M == 1, "");
float p_sum_o[kOutputRowsPerThread][Mma_tile_o::MMAS_M];
softmax.reduce_sum_after_sync_(p_sum_o, rows);
if (!Is_first) {
@@ -625,7 +625,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i

// Load from shared memory.
using ArrayTypeO = cutlass::Array<ElementAccum, OutputTileThreadMap::kElementsPerAccess>;
-static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements);
+static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements, "");
cutlass::multiplies<ArrayTypeO> multiply_fragments;
if (!Is_first) {
auto out_reshaped = reinterpret_cast<ArrayTypeO (&)[kOutputRowsPerThread]>(out);
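One non-obvious piece of the kernel above is the comment about frag_p in the @@ -558 hunk: WarpMmaPV::FragmentA only covers a single K-slice, so the flat fragment is first viewed as kIterationsPV rows and each row is then handed to mma_pv. A small stand-alone sketch of that array-view trick, with made-up sizes rather than the kernel's real shapes:

#include <cstdio>

int main() {
  constexpr int K = 2;  // stands in for kIterationsPV
  constexpr int N = 4;  // stands in for WarpMmaPV::FragmentA::kElements
  float flat[K * N];
  for (int i = 0; i < K * N; ++i) flat[i] = static_cast<float>(i);

  // Same storage, reinterpreted as K rows of N elements, so each row can be
  // passed to an API that expects exactly one N-element fragment.
  const auto &rows = reinterpret_cast<const float (&)[K][N]>(flat);
  for (int k = 0; k < K; ++k) {
    std::printf("row %d starts with %.0f\n", k, rows[k][0]);
  }
  return 0;
}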
@@ -88,7 +88,7 @@ void run_fmha_loop_(Launch_params<FMHA_fprop_params> &launch_params,
});
}

-void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params,
+TORCH_API void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params,
const bool configure) {
BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
using elem_type = std::conditional<IsBf16Const, cutlass::bfloat16_t, cutlass::half_t>::type;
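run_fmha_fprop, now also exported with TORCH_API at its definition, forwards to run_fmha_loop_ through BOOL_SWITCH: the runtime params.is_bf16 flag becomes a compile-time constant so std::conditional can pick cutlass::bfloat16_t or cutlass::half_t inside the lambda. A simplified, self-contained sketch of that idiom (the macro below is mine, not the repository's):

#include <cstdio>
#include <type_traits>

// Turns a runtime bool into a constexpr constant visible inside the lambda
// passed as the trailing arguments (commas inside the lambda are glued back
// together by __VA_ARGS__, which is why the macro is variadic).
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...) \
  do {                                            \
    if (COND) {                                   \
      constexpr bool CONST_NAME = true;           \
      __VA_ARGS__();                              \
    } else {                                      \
      constexpr bool CONST_NAME = false;          \
      __VA_ARGS__();                              \
    }                                             \
  } while (0)

struct Bf16Kernel { static void run() { std::puts("bf16 path"); } };
struct Fp16Kernel { static void run() { std::puts("fp16 path"); } };

void launch(bool is_bf16) {
  BOOL_SWITCH_SKETCH(is_bf16, kIsBf16, [&] {
    // kIsBf16 is a compile-time constant here, just like IsBf16Const above.
    using Kernel = std::conditional_t<kIsBf16, Bf16Kernel, Fp16Kernel>;
    Kernel::run();
  });
}

int main() {
  launch(true);
  launch(false);
  return 0;
}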
@@ -118,7 +118,7 @@ struct Gmem_tile_mma_s : public Base {
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
-static_assert(Fragment::kStorageElements == 4);
+static_assert(Fragment::kStorageElements == 4, "");
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
@@ -85,7 +85,7 @@ struct FMHA_kernel_traits {
#endif
using ElementAccum = float;

-static_assert(WARPS_M == 1);
+static_assert(WARPS_M == 1, "");
using ThreadblockShapeQK = cutlass::gemm::GemmShape<STEP, S, D>;
using WarpCountQK = cutlass::gemm::GemmShape<WARPS_M, WARPS_N, 1>;
using WarpShapeQK = cutlass::gemm::GemmShape<
@@ -144,7 +144,7 @@ struct FMHA_kernel_traits {
static constexpr size_t BYTES_PER_SMEM_Q = ThreadblockShapeQK::kM * ThreadblockShapeQK::kK * sizeof(Element);
static constexpr size_t BYTES_PER_SMEM_K = ThreadblockShapeQK::kN * ThreadblockShapeQK::kK * sizeof(Element);
static constexpr size_t BYTES_PER_SMEM_V = ThreadblockShapePV::kN * ThreadblockShapePV::kK * sizeof(Element);
-static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V);
+static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V, "");
static constexpr size_t BYTES_PER_SMEM_QK = BYTES_PER_SMEM_Q + BYTES_PER_SMEM_K;
// The extra amount of shared memory needed to load V.
static constexpr size_t BYTES_PER_SMEM_V_EXTRA = SHARE_SMEM_FOR_K_AND_V ? 0u : BYTES_PER_SMEM_V;
6 changes: 3 additions & 3 deletions aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h
@@ -78,7 +78,7 @@ struct Smem_tile_reduce {

static constexpr int ROWS = WARPS_M * MMAS_M * 16;
static constexpr int COLS = WARPS_N;
-static_assert(COLS == 4 || COLS == 8);
+static_assert(COLS == 4 || COLS == 8, "");
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
static constexpr int ELTS_PER_TILE = ROWS * COLS;
@@ -263,7 +263,7 @@ struct Softmax_base {
};
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
-static_assert(MMAS_N % 2 == 0);
+static_assert(MMAS_N % 2 == 0, "");
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni += 2 ) {
uint4 random_uint4 = ph0();
@@ -319,7 +319,7 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
static constexpr int MMAS_N = Base::MMAS_N;

using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
-static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
+static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N, "");
// Ctor.
template<typename Params>
inline __device__ Softmax(const Params &params, void *smem, int tidx)
@@ -12,7 +12,7 @@ template<int kRows, int kRowsPerMma, int kWarpCountM>
struct Smem_tile_softmax_lse {

static constexpr int kMmaM = (kRows / kWarpCountM) / kRowsPerMma;
-static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows);
+static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows, "");
// static_assert(kWarpCountM == 1);
// Otherwise we might need to check warp_idx / kWarpCountM == 0 instead of just warp_idx == 0

2 changes: 1 addition & 1 deletion aten/src/ATen/native/transformers/cuda/flash_attn/utils.h
@@ -332,7 +332,7 @@ __device__ inline T operator()(T const & x, T const & y) { return x + y; }

template<int THREADS>
struct Allreduce {
-static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4, "");
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
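The hunk above only shows the head of Allreduce<THREADS> (the static_assert and OFFSET = THREADS / 2); the rest of the helper, outside this view, is a warp-level reduction. Assuming it follows the usual shuffle-based butterfly pattern that OFFSET suggests, a self-contained CUDA sketch of that idiom looks like this (names and the host driver are illustrative, not the repository's code):

#include <cstdio>
#include <cuda_runtime.h>

struct SumOp {
  __device__ float operator()(float a, float b) const { return a + b; }
};

// Recursive halving: each step combines a lane with its partner OFFSET lanes
// away (via XOR), so after log2(THREADS) steps every lane holds the result.
template <int THREADS>
struct AllreduceSketch {
  static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4, "");
  template <typename T, typename Operator>
  static __device__ T run(T x, Operator &op) {
    constexpr int OFFSET = THREADS / 2;
    x = op(x, __shfl_xor_sync(0xffffffffu, x, OFFSET));
    return AllreduceSketch<OFFSET>::run(x, op);
  }
};

template <>
struct AllreduceSketch<2> {
  template <typename T, typename Operator>
  static __device__ T run(T x, Operator &op) {
    return op(x, __shfl_xor_sync(0xffffffffu, x, 1));
  }
};

__global__ void warp_sum(float *out) {
  SumOp op;
  float v = AllreduceSketch<32>::run(static_cast<float>(threadIdx.x), op);
  if (threadIdx.x == 0) {
    *out = v;  // 0 + 1 + ... + 31 = 496
  }
}

int main() {
  float *d_out = nullptr;
  float h_out = 0.0f;
  cudaMalloc(&d_out, sizeof(float));
  warp_sum<<<1, 32>>>(d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("warp sum = %.0f\n", h_out);
  cudaFree(d_out);
  return 0;
}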
7 changes: 4 additions & 3 deletions aten/src/ATen/native/transformers/transformer.cpp
@@ -107,13 +107,14 @@ Tensor transformer_encoder_layer_forward(
if (norm_first) {
x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor);
}
-#if USE_FLASH_ATTENTION
+
+#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION
if (x.is_nested() && x.is_cuda() && x.dtype() == at::kHalf && !mask.has_value() &&
(embed_dim / num_heads == 16 ||
embed_dim / num_heads == 32 ||
embed_dim / num_heads == 64 ||
embed_dim / num_heads == 128)) {
TORCH_WARN_ONCE("USING FLASH ATTENTION WITH NT");
TORCH_WARN_ONCE("transformer_encoder_layer_forward is using flash attention.");
x = at::linear(x, qkv_weight, qkv_bias);
x = x.view({x.size(0), -1, 3, num_heads, embed_dim / num_heads});
x = flash_attention_helper(x, x, x, 0.0, false);
@@ -135,7 +136,7 @@ Tensor transformer_encoder_layer_forward(
false /* need_weights */,
true /* average_attn_weights */,
mask_type));
-#if USE_FLASH_ATTENTION
+#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION
}
#endif
add_in_place(x, src, use_nested_tensor);
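Two things change in transformer_encoder_layer_forward: the guard macro is renamed from USE_FLASH_ATTENTION to BETTER_TRANSFORMER_USE_FLASH_ATTENTION, presumably so this experimental BetterTransformer path is not switched on as a side effect of the build-wide USE_FLASH_ATTENTION definition that CMake now sets, and the TORCH_WARN_ONCE message is made descriptive. The runtime gate itself is untouched; restated as a hypothetical helper (not part of the PR) it reads:

#include <cstdint>

// Hypothetical restatement of the condition guarding the flash-attention
// branch above: nested half-precision CUDA input, no mask, and a head
// dimension of 16, 32, 64, or 128.
bool can_take_flash_path(bool is_nested, bool is_cuda, bool is_half,
                         bool has_mask, int64_t embed_dim, int64_t num_heads) {
  const int64_t head_dim = embed_dim / num_heads;
  const bool supported_head_dim =
      head_dim == 16 || head_dim == 32 || head_dim == 64 || head_dim == 128;
  return is_nested && is_cuda && is_half && !has_mask && supported_head_dim;
}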
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -1461,6 +1461,7 @@ aten_cuda_cu_source_list = [
"aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp",
"aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.cpp",
"aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp",
"aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp",
]

# Files using thrust::sort_by_key need to be linked last