@@ -80,6 +80,7 @@ template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, b
8080inline __device__ void compute_dq_dk_dv_1colblock (const Params &params, const int bidb, const int bidh, const int n_block) {
8181
8282 using Element = typename Kernel_traits::Element;
83+ using ElementMask = typename Kernel_traits::ElementMask;
8384 using ElementAccum = typename Kernel_traits::ElementAccum;
8485 using index_t = typename Kernel_traits::index_t ;
8586
@@ -245,25 +246,29 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
245246 sBias .data (),
246247 typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
247248 );
248- Tensor sMask = make_tensor (
249+ Tensor sMaskPlace = make_tensor (
249250 sBias .data () + size (sBias ),
250251 typename Kernel_traits::SmemLayoutMaskBiasPdS{}
252+ ); // For pointers alignment only
253+ Tensor sMask = make_tensor (
254+ make_smem_ptr (reinterpret_cast <ElementMask *>(sMaskPlace .data ().get ())),
255+ typename Kernel_traits::SmemLayoutMaskBiasPdS{}
251256 );
252257 Tensor sP = make_tensor (
253- sMask .data (),
258+ sMaskPlace .data (),
254259 typename Kernel_traits::SmemLayoutMaskBiasPdS{}
255260 );
256261 Tensor sPt = make_tensor (
257- sMask .data (),
262+ sMaskPlace .data (),
258263 typename Kernel_traits::SmemLayoutPdStransposed{}
259264 );
260265 Tensor sPtNoSwizzle = make_tensor (
261- sMask .data (),
266+ sMaskPlace .data (),
262267 typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
263268 );
264269 // sMask, sP and sdQ share the same memory so be careful
265270 Tensor sdQ = make_tensor (
266- sMask .data (),
271+ sMaskPlace .data (),
267272 typename Kernel_traits::SmemLayoutdQ{}
268273 );
269274
@@ -572,24 +577,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
572577 // // if (cute::thread(1, 0)) { print(tKrK); }
573578
574579 if constexpr (Has_mask) {
575- // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
576- // gmem_tiled_copy_Mask,
577- // tMaskgMask, tMasksMask,
578- // tMaskcMask, tMaskpMask,
579- // binfo.actual_seqlen_q - m_block * kBlockM
580- // );
581- // cute::cp_async_fence();
582- // FLASH_NAMESPACE::cp_async_wait<0>();
583- // // Do OR-reduce on the mask to see if any active threads
584-
585- FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /* Clear_OOB_MN=*/ true , /* To_type=*/ Element>(
580+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
586581 gmem_tiled_copy_Mask,
587582 tMaskgMask, tMasksMask,
588- any_active,
589583 tMaskcMask, tMaskpMask,
590584 binfo.actual_seqlen_q - m_block * kBlockM
591585 );
592- // We don't need to syncthreads here because copy_mask is already or_syncthreads.
586+ __syncthreads ();
587+        // Do OR-reduce on the mask to see if any active threads for current iteration.
588+ FLASH_NAMESPACE::mask_or_reduce (
589+ tMasksMask,
590+ any_active,
591+ smem_thr_copy_Mask
592+ );
593593 }
594594
595595 FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /* Clear_OOB_MN=*/ true >(
@@ -601,7 +601,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
601601
602602 if (any_active) {
603603 if constexpr (Has_bias) {
604- FLASH_NAMESPACE::copy_bias <Is_even_MN, /* Clear_OOB_MN=*/ true >(
604+ FLASH_NAMESPACE::copy_MN <Is_even_MN, /* Clear_OOB_MN=*/ true >(
605605 gmem_tiled_copy_Bias,
606606 tBiasgBias, tBiassBias,
607607 tBiascBias, tBiaspBias,
@@ -623,9 +623,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
623623 // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
624624 if (Is_first) {
625625 cute::copy (tdOrdO, tdOsdO);
626- dot_do_o<Kernel_traits::kGmemThreadsPerRow >(
626+ dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO >(
627627 tdOrdO, tdOrO, gdPsum,
628- Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow )
628+ Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO )
629629 );
630630 }
631631
@@ -848,7 +848,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
848848 __syncthreads ();
849849 if constexpr (Has_bias) {
850850 // Write dS to dBias
851- FLASH_NAMESPACE::copy_bias <Is_even_MN, /* Clear_OOB_MN=*/ false >(
851+ FLASH_NAMESPACE::copy_MN <Is_even_MN, /* Clear_OOB_MN=*/ false >(
852852 gmem_tiled_copy_dBias,
853853 tBiassBias, tdBiasgdBias,
854854 tBiascBias, tBiaspBias,
@@ -879,24 +879,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
879879 if constexpr (Has_mask) {
880880 // Advance gMask
881881 tMaskgMask.data () = tMaskgMask.data () + (-int (kBlockM * params.mask_row_stride ));
882- // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
883- // gmem_tiled_copy_Mask,
884- // tMaskgMask, tMasksMask,
885- // tMaskcMask, tMaskpMask,
886- // binfo.actual_seqlen_q - (m_block - 1) * kBlockM
887- // );
888- // FLASH_NAMESPACE::cp_async_fence();
889- // FLASH_NAMESPACE::cp_async_wait<0>();
890- // // Do OR-reduce on the mask to see if any active threads for next iteration
891-
892- FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /* Clear_OOB_MN=*/ true , /* To_type=*/ Element>(
882+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
893883 gmem_tiled_copy_Mask,
894884 tMaskgMask, tMasksMask,
895- any_active_next,
896885 tMaskcMask, tMaskpMask,
897886 binfo.actual_seqlen_q - (m_block - 1 ) * kBlockM
898887 );
899- // We don't need to syncthreads here because copy_mask is already or_syncthreads.
888+ __syncthreads ();
889+ // Do OR-reduce on the mask to see if any active threads for next iteration
890+ FLASH_NAMESPACE::mask_or_reduce (
891+ tMasksMask,
892+ any_active_next,
893+ smem_thr_copy_Mask
894+ );
900895 }
901896
902897 // Advance gdO
@@ -1000,7 +995,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
1000995 tBiasgBias.data () = tBiasgBias.data () + (-int (kBlockM * params.bias_row_stride ));
1001996 tdBiasgdBias.data () = tdBiasgdBias.data () + (-int (kBlockM * params.dbias_row_stride ));
1002997 if (any_active_next) {
1003- FLASH_NAMESPACE::copy_bias <Is_even_MN, /* Clear_OOB_MN=*/ true >(
998+ FLASH_NAMESPACE::copy_MN <Is_even_MN, /* Clear_OOB_MN=*/ true >(
1004999 gmem_tiled_copy_Bias,
10051000 tBiasgBias, tBiassBias,
10061001 tBiascBias, tBiaspBias,
@@ -1014,9 +1009,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
10141009
10151010 if (Is_first && m_block > m_block_min) {
10161011 cute::copy (tdOrdO, tdOsdO);
1017- dot_do_o<Kernel_traits::kGmemThreadsPerRow >(
1012+ dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO >(
10181013 tdOrdO, tdOrO, gdPsum,
1019- Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow )
1014+ Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO )
10201015 );
10211016 }
10221017
0 commit comments