@@ -424,7 +424,7 @@ inline __device__ void convert_dKV(const Params ¶ms) {
424424
425425// //////////////////////////////////////////////////////////////////////////////////////////////////
426426
427- template <typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false , typename Params>
427+ template <typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Is_attn_mask, bool Seq_parallel=false , typename Params>
428428inline __device__ void compute_dq_dk_dv_1colblock (const Params ¶ms, const int bidb, const int bidh, const int n_block) {
429429
430430 using Element = typename Kernel_traits::Element;
@@ -448,7 +448,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
448448 const BlockInfo</* Varlen=*/ !Is_even_MN> binfo (params, bidb);
449449 if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0 ) return ;
450450
451- int m_block_max = cute::ceil_div (binfo.actual_seqlen_q , kBlockM );
451+ // umiswing: residue is for predication of additional mask gmem access.
452+ // Additional mask for varlen qkv is supported, but a varlen mask is not supported.
453+ const int m_residue = params.seqlen_q % kBlockM ? params.seqlen_q % kBlockM : kBlockM ;
454+ const int n_residue = params.seqlen_k % kBlockN ? params.seqlen_k % kBlockN : kBlockN ;
455+
456+ const int m_block_max = cute::ceil_div (binfo.actual_seqlen_q , kBlockM );
457+ const int n_block_max = cute::ceil_div (binfo.actual_seqlen_k , kBlockN );
452458
453459 const index_t row_offset_q = binfo.q_offset (params.q_batch_stride , params.q_row_stride , bidb)
454460 + (m_block_max - 1 ) * kBlockM * params.q_row_stride + bidh * params.q_head_stride ;
@@ -469,6 +475,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
469475 const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
470476 + (m_block_max - 1 ) * kBlockM ;
471477
478+ const index_t row_offset_mask = ((bidb * params.mask_head_mod_size
479+ + (bidh % params.mask_head_mod_size )) * params.mask_seq_q_mod_size
480+ + ((m_block_max - 1 ) * kBlockM % params.mask_seq_q_mod_size )) * params.seqlen_k
481+ + n_block * kBlockN ;
482+
472483 Tensor gQ = make_tensor (make_gmem_ptr (reinterpret_cast <Element *>(params.q_ptr ) + row_offset_q),
473484 Shape<Int<kBlockM >, Int<kHeadDim >>{},
474485 make_stride (params.q_row_stride , _1{}));
@@ -494,6 +505,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
494505 Shape<Int<kBlockM >>{}, Stride<_1>{});
495506 Tensor gdPsum = make_tensor (make_gmem_ptr (reinterpret_cast <ElementAccum *>(params.dsoftmax_sum ) + row_offset_dpsum),
496507 Shape<Int<kBlockM >>{}, Stride<_1>{});
508+ Tensor gMask = make_tensor (make_gmem_ptr (reinterpret_cast <Element *>(params.attn_mask_ptr ) + row_offset_mask),
509+ Shape<Int<kBlockM >, Int<kBlockN >>{},
510+ make_stride (params.seqlen_k , _1{}));
497511
498512 Tensor sQ = make_tensor (make_smem_ptr (reinterpret_cast <Element *>(smem_)),
499513 typename Kernel_traits::SmemLayoutQdO{});
@@ -558,6 +572,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
558572 // }
559573
560574 typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
575+ auto gmem_thr_copy_P = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice (tidx);
576+ Tensor tPgMask = gmem_thr_copy_P.partition_D (gMask );
577+ Tensor cMask = make_identity_tensor (shape (gMask ));
578+ Tensor tPcMask = gmem_thr_copy_P.partition_D (cMask);
579+
561580 auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice (tidx);
562581 Tensor tSrQ = thr_mma_sdp.partition_fragment_A (sQ ); // (MMA,MMA_N,MMA_K)
563582 Tensor tSrK = thr_mma_sdp.partition_fragment_B (sK ); // (MMA,MMA_N,MMA_K)
@@ -813,6 +832,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
813832 // However, it's possible that the values in acc_s are so large that they overflow
814833 // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
815834 // So we need to mask out the elements beyond actual_seqlen_k.
835+ if (Is_attn_mask) {
836+ flash::apply_attn_mask<Kernel_traits::TiledMmaSdP>(scores, tPgMask, tPcMask,
837+ m_block == m_block_max - 1 ? m_residue : params.seqlen_q ,
838+ n_block == n_block_max - 1 ? n_residue : params.seqlen_k ,
839+ params.unscale_softmax );
840+ tPgMask.data () = tPgMask.data () + (-kBlockM * params.seqlen_k );
841+ }
816842 if (!Is_causal) {
817843 if (!Is_even_MN && (n_block + 1 ) * kBlockN >= binfo.actual_seqlen_k ) {
818844 flash::apply_mask (scores, binfo.actual_seqlen_k ,
@@ -1550,7 +1576,7 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) {
15501576
15511577// //////////////////////////////////////////////////////////////////////////////////////////////////
15521578
1553- template <typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
1579+ template <typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_attn_mask, typename Params>
15541580inline __device__ void compute_dq_dk_dv_seqk_parallel (const Params ¶ms) {
15551581
15561582 const int n_block = blockIdx.x ;
@@ -1562,11 +1588,11 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) {
15621588 if (params.num_splits == 1 ) { // means grid.x = 1, blockIdx.x = 0;
15631589 int loop_step_x = 0 ;
15641590 for (int i = 0 ; i < params.seqlen_k ; i+= kBlockN ) {
1565- compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false , false , /* Seq_parallel=*/ true >(params, bidb, bidh, loop_step_x);
1591+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false , false , Is_attn_mask, /* Seq_parallel=*/ true >(params, bidb, bidh, loop_step_x);
15661592 loop_step_x += 1 ;
15671593 }
15681594 } else {
1569- compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false , false , /* Seq_parallel=*/ true >(params, bidb, bidh, n_block);
1595+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false , false , Is_attn_mask, /* Seq_parallel=*/ true >(params, bidb, bidh, n_block);
15701596 }
15711597}
15721598
0 commit comments