Commit b74460b

Merge pull request #19 from umiswing/fa2_mask
Additional mask support on FA2
2 parents 2e133ca + 9ebf2af commit b74460b

9 files changed: +314 −135 lines

csrc/capi/flash_attn.cu

Lines changed: 131 additions & 94 deletions
Large diffs are not rendered by default.

csrc/capi/flash_attn.h

Lines changed: 16 additions & 4 deletions
@@ -26,12 +26,15 @@ bool flash_attn_fwd(const void * const q, // batch_size x seqlen_q x num
                     const int head_size_rounded,
                     const float p_dropout,
                     const float softmax_scale,
+                    const float softmax_unscale,
                     const bool is_causal,
                     const bool return_softmax,
                     const bool is_bf16,
                     cudaStream_t stream,
                     uint64_t seed,
-                    uint64_t offset);
+                    uint64_t offset,
+                    const void * const attn_mask,
+                    const int64_t * const mask_dims);

 bool flash_attn_varlen_fwd(const void * const q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                            const void * const k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -53,12 +56,15 @@ bool flash_attn_varlen_fwd(const void * const q, // total_q x num_heads x head_
                            const int head_size_rounded,
                            const float p_dropout,
                            const float softmax_scale,
+                           const float softmax_unscale,
                            const bool is_causal,
                            const bool return_softmax,
                            const bool is_bf16,
                            cudaStream_t stream,
                            uint64_t seed,
-                           uint64_t offset);
+                           uint64_t offset,
+                           const void * const attn_mask,
+                           const void * const mask_dims);

 bool flash_attn_bwd(const void * const dout, // batch_size x seqlen_q x num_heads, x head_size_og
                     const void * const q, // batch_size x seqlen_q x num_heads x head_size
@@ -83,12 +89,15 @@ bool flash_attn_bwd(const void * const dout, // batch_size x seqlen_q x num_hea
                     const int head_size_rounded,
                     const float p_dropout, // probability to drop
                     const float softmax_scale,
+                    const float softmax_unscale,
                     const bool is_causal,
                     const bool is_bf16,
                     const int num_splits,
                     cudaStream_t stream,
                     uint64_t seed,
-                    uint64_t offset);
+                    uint64_t offset,
+                    const void * const attn_mask,
+                    const int64_t * const mask_dims);

 bool flash_attn_varlen_bwd(const void * const dout, // total_q x num_heads, x head_size
                            const void * const q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
@@ -115,12 +124,15 @@ bool flash_attn_varlen_bwd(const void * const dout, // total_q x num_heads, x h
                            const int head_size_rounded,
                            const float p_dropout, // probability to drop
                            const float softmax_scale,
+                           const float softmax_unscale,
                            const bool is_causal,
                            const bool is_bf16,
                            const int num_splits,
                            cudaStream_t stream,
                            uint64_t seed,
-                           uint64_t offset);
+                           uint64_t offset,
+                           const void * attn_mask,
+                           const int64_t * const mask_dims);

 bool flash_attn_fwd_with_bias_and_mask(const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                                        const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
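
Note on the API change above: every entry point gains a softmax_unscale factor plus a trailing attn_mask / mask_dims pair (the varlen_fwd overload declares mask_dims as const void * const while the others use const int64_t * const). Judging by run_flash_bwd_seqk_parallel further down, passing attn_mask == nullptr keeps the previous unmasked behavior. A minimal host-side sketch of how a caller might package the new arguments, assuming a dense mask laid out as [batch, heads_or_1, seqlen_q_or_1, seqlen_k] and softmax_unscale being the reciprocal of softmax_scale; both are assumptions read off the kernels, not guarantees made by this header:

#include <array>
#include <cmath>
#include <cstdint>

// Hypothetical helper, not part of this commit: bundles the values a caller
// would hand to the new trailing parameters of flash_attn_fwd/_bwd.
struct AttnMaskArgs {
    const void *attn_mask;            // device pointer to the mask, or nullptr to disable masking
    std::array<int64_t, 4> mask_dims; // assumed [batch, heads (or 1), seqlen_q (or 1), seqlen_k]
    float softmax_scale;
    float softmax_unscale;
};

inline AttnMaskArgs make_attn_mask_args(const void *mask_dev_ptr,
                                        int64_t batch, int64_t heads,
                                        int64_t seqlen_q, int64_t seqlen_k,
                                        int head_size) {
    AttnMaskArgs args;
    args.attn_mask = mask_dev_ptr;
    args.mask_dims = {batch, heads, seqlen_q, seqlen_k};
    args.softmax_scale = 1.0f / std::sqrt(static_cast<float>(head_size));
    // Assumed relation: the kernels use the inverse scale to fold the mask into
    // the unscaled scores (see the backward-kernel call below).
    args.softmax_unscale = 1.0f / args.softmax_scale;
    return args;
}

In this sketch, args.attn_mask would go into the attn_mask parameter and args.mask_dims.data() into mask_dims.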

csrc/flash_attn/src/flash.h

Lines changed: 6 additions & 0 deletions
@@ -76,6 +76,7 @@ struct Flash_fwd_params : public Qkv_params {
     // The scaling factors for the kernel.
     float scale_softmax;
     float scale_softmax_log2;
+    float unscale_softmax;

     // array of length b+1 holding starting offset of each sequence.
     int * __restrict__ cu_seqlens_q;
@@ -101,6 +102,11 @@ struct Flash_fwd_params : public Qkv_params {

     bool is_bf16;
     bool is_causal;
+
+    // The attn mask matrix
+    void * __restrict__ attn_mask_ptr;
+    int mask_head_mod_size;
+    int mask_seq_q_mod_size;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
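
The two *_mod_size fields are what the backward kernel below indexes with: the head index and query-row index are reduced modulo mask_head_mod_size and mask_seq_q_mod_size before the mask offset is formed, so a mask whose head or seqlen_q extent is 1 is effectively broadcast over that dimension. A small host-side sketch of the same offset arithmetic, assuming a contiguous mask layout of [batch, mask_head_mod_size, mask_seq_q_mod_size, seqlen_k] (read off row_offset_mask in flash_bwd_kernel.h; illustrative only):

#include <cstdint>

// Linear element offset of mask[b][h][q][k] under the assumed contiguous
// layout [batch, mask_head_mod_size, mask_seq_q_mod_size, seqlen_k].
// The % on the head and query indices is what lets a size-1 dimension be
// reused (broadcast) across all heads / all query rows.
inline int64_t mask_elem_offset(int64_t b, int64_t h, int64_t q, int64_t k,
                                int mask_head_mod_size, int mask_seq_q_mod_size,
                                int64_t seqlen_k) {
    return ((b * mask_head_mod_size + (h % mask_head_mod_size)) * mask_seq_q_mod_size
            + (q % mask_seq_q_mod_size)) * seqlen_k
           + k;
}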

csrc/flash_attn/src/flash_bwd_kernel.h

Lines changed: 31 additions & 5 deletions
@@ -424,7 +424,7 @@ inline __device__ void convert_dKV(const Params &params) {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Is_attn_mask, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

     using Element = typename Kernel_traits::Element;
@@ -448,7 +448,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;

-    int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
+    // umiswing: residue is for predication of additional mask gmem access.
+    // Additional mask for varlen qkv is supported, but a varlen mask is not supported.
+    const int m_residue = params.seqlen_q % kBlockM ? params.seqlen_q % kBlockM : kBlockM;
+    const int n_residue = params.seqlen_k % kBlockN ? params.seqlen_k % kBlockN : kBlockN;
+
+    const int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
+    const int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);

     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
@@ -469,6 +475,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
         + (m_block_max - 1) * kBlockM;

+    const index_t row_offset_mask = ((bidb * params.mask_head_mod_size
+        + (bidh % params.mask_head_mod_size)) * params.mask_seq_q_mod_size
+        + ((m_block_max - 1) * kBlockM % params.mask_seq_q_mod_size)) * params.seqlen_k
+        + n_block * kBlockN;
+
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.q_row_stride, _1{}));
@@ -494,6 +505,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                                 Shape<Int<kBlockM>>{}, Stride<_1>{});
     Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                 Shape<Int<kBlockM>>{}, Stride<_1>{});
+    Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
+                               Shape<Int<kBlockM>, Int<kBlockN>>{},
+                               make_stride(params.seqlen_k, _1{}));

     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutQdO{});
@@ -558,6 +572,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // }

     typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
+    auto gmem_thr_copy_P = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);
+    Tensor tPgMask = gmem_thr_copy_P.partition_D(gMask);
+    Tensor cMask = make_identity_tensor(shape(gMask));
+    Tensor tPcMask = gmem_thr_copy_P.partition_D(cMask);
+
     auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
     Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K)
     Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
@@ -813,6 +832,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // However, it's possible that the values in acc_s are so large that they overflow
     // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
     // So we need to mask out the elements beyond actual_seqlen_k.
+    if (Is_attn_mask) {
+        flash::apply_attn_mask<Kernel_traits::TiledMmaSdP>(scores, tPgMask, tPcMask,
+                                                           m_block == m_block_max - 1 ? m_residue : params.seqlen_q,
+                                                           n_block == n_block_max - 1 ? n_residue : params.seqlen_k,
+                                                           params.unscale_softmax);
+        tPgMask.data() = tPgMask.data() + (-kBlockM * params.seqlen_k);
+    }
     if (!Is_causal) {
         if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
             flash::apply_mask(scores, binfo.actual_seqlen_k,
@@ -1550,7 +1576,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_attn_mask, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

     const int n_block = blockIdx.x;
@@ -1562,11 +1588,11 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
     if (params.num_splits == 1) { // means grid.x = 1, blockIdx.x = 0;
         int loop_step_x = 0;
         for(int i = 0; i < params.seqlen_k; i+= kBlockN) {
-            compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, loop_step_x);
+            compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, loop_step_x);
             loop_step_x += 1;
         }
     } else {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
     }
 }
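
flash::apply_attn_mask itself is not among the rendered diffs (it lives in one of the other changed files), but the call above shows its contract: it receives the score tile, the thread-partitioned gmem mask tile plus a matching coordinate tensor for predication, the valid row/column counts for the current tile (m_residue / n_residue on the last block in each direction, otherwise the full extent), and unscale_softmax. A plain reference-style sketch of what such a predicated additive mask amounts to, under the assumption that the mask is scaled by unscale_softmax and added to the not-yet-scaled scores so the later softmax effectively sees softmax_scale * QK^T + mask:

// Reference-style sketch (CPU, one tile): add the mask into the pre-softmax
// score tile, but only for coordinates inside the valid region so we never
// read mask elements past the end of the (unpadded) mask tensor.
template <int kBlockM, int kBlockN>
void apply_attn_mask_ref(float (&scores)[kBlockM][kBlockN],
                         const float *mask_tile,      // start of this tile in the mask
                         int64_t mask_row_stride,     // seqlen_k for the layout used above
                         int max_row, int max_col,    // residue limits on the last block
                         float unscale_softmax) {
    for (int m = 0; m < kBlockM; ++m) {
        for (int n = 0; n < kBlockN; ++n) {
            if (m < max_row && n < max_col) {
                scores[m][n] += unscale_softmax * mask_tile[m * mask_row_stride + n];
            }
        }
    }
}

After the call, the kernel steps tPgMask back by one row block (kBlockM * seqlen_k elements) because the backward pass walks the m-blocks from last to first.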

csrc/flash_attn/src/flash_bwd_launch_template.h

Lines changed: 13 additions & 10 deletions
@@ -26,9 +26,9 @@ __global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
     flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
 }

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_attn_mask>
 __global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
-    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K>(params);
+    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_attn_mask>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
@@ -61,18 +61,21 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
+    const bool is_attn_mask = params.attn_mask_ptr != nullptr;
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
-                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
-                if (smem_size_dq_dk_dv >= 48 * 1024) {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-                }
-                kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                BOOL_SWITCH(is_attn_mask, Is_attn_mask, [&] {
+                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst, Is_attn_mask && !IsCausalConst>;
+                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
+                    if (smem_size_dq_dk_dv >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                    }
+                    kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                });
             });
         });
     });
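
The extra BOOL_SWITCH turns the runtime check params.attn_mask_ptr != nullptr into a compile-time template argument, so the mask loads and adds in compute_dq_dk_dv_1colblock are compiled out entirely when no mask is supplied; the Is_attn_mask && !IsCausalConst argument also means the mask path is only instantiated for the non-causal kernels. A minimal sketch of that dispatch pattern, with a hypothetical launch function standing in for the real kernel launch (the real BOOL_SWITCH macro lives in static_switch.h):

// Minimal re-statement of the BOOL_SWITCH idea: branch once on the runtime
// value, and inside each branch expose it as a constexpr bool that can be
// used as a template argument.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)   \
    [&] {                                           \
        if (COND) {                                 \
            constexpr bool CONST_NAME = true;       \
            return __VA_ARGS__();                   \
        } else {                                    \
            constexpr bool CONST_NAME = false;      \
            return __VA_ARGS__();                   \
        }                                           \
    }()

template <bool IsAttnMask>
void launch_sketch() { /* hypothetical stand-in for the real kernel launch */ }

void dispatch(bool is_causal, const void *attn_mask_ptr) {
    const bool is_attn_mask = attn_mask_ptr != nullptr;
    BOOL_SWITCH_SKETCH(is_causal, IsCausalConst, [&] {
        BOOL_SWITCH_SKETCH(is_attn_mask, Is_attn_mask, [&] {
            // Mask support is compiled in only for the non-causal instantiations.
            launch_sketch<Is_attn_mask && !IsCausalConst>();
        });
    });
}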
