@@ -80,6 +80,7 @@ template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, b
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
 
     using Element = typename Kernel_traits::Element;
+    using ElementMask = typename Kernel_traits::ElementMask;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
 
@@ -245,25 +246,29 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         sBias.data(),
         typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
     );
-    Tensor sMask = make_tensor(
+    Tensor sMaskPlace = make_tensor(
         sBias.data() + size(sBias),
         typename Kernel_traits::SmemLayoutMaskBiasPdS{}
+    );  // For pointer alignment only
+    Tensor sMask = make_tensor(
+        make_smem_ptr(reinterpret_cast<ElementMask *>(sMaskPlace.data().get())),
+        typename Kernel_traits::SmemLayoutMaskBiasPdS{}
     );
     Tensor sP = make_tensor(
-        sMask.data(),
+        sMaskPlace.data(),
         typename Kernel_traits::SmemLayoutMaskBiasPdS{}
     );
     Tensor sPt = make_tensor(
-        sMask.data(),
+        sMaskPlace.data(),
         typename Kernel_traits::SmemLayoutPdStransposed{}
     );
     Tensor sPtNoSwizzle = make_tensor(
-        sMask.data(),
+        sMaskPlace.data(),
         typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
     );
     // sMask, sP and sdQ share the same memory so be careful
     Tensor sdQ = make_tensor(
-        sMask.data(),
+        sMaskPlace.data(),
         typename Kernel_traits::SmemLayoutdQ{}
     );
 
@@ -278,6 +283,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     GmemTiledCopydO gmem_tiled_copy_dO;
     auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
     typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
+    typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
+    auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);
     auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
     using GmemLayoutAtomdQaccum = std::conditional_t<
         !Seq_parallel,
@@ -300,7 +307,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask);
     Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias);    // (BiasCPY, BiasCPY_M, BiasCPY_N)
     Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias);
-    Tensor tdBiasgdBias = gmem_thr_copy_Bias.partition_D(gdBias);
+    Tensor tdBiasgdBias = gmem_thr_copy_dBias.partition_D(gdBias);
 
     Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom, AtomNum), ATOM_M, ATOM_N)
     Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
@@ -350,20 +357,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); }
     Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);
 
-    // auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp);
-    // auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
-    // Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
-    // auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp);
-    // auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
-    // Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
-
     // Partition sP and sdS to match the accumulator partitioning
     // This has to be tiled_mma_sdp, not tiled_mma_dkv
     // auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
     auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
     auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
-    Tensor tSsMask = smem_thr_copy_PdS.partition_S(sMask);
-    Tensor tSsBias = smem_thr_copy_PdS.partition_S(sBias);
+    auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp);
+    auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
+    auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp);
+    auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
+    Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
+    Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
     Tensor tPsP = smem_thr_copy_PdS.partition_D(sP);    // ((Atom, AtomNum), PIPE_M, PIPE_N)
     Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS);    // ((Atom, AtomNum), PIPE_M, PIPE_N)
 
@@ -573,24 +577,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // // if (cute::thread(1, 0)) { print(tKrK); }
 
     if constexpr (Has_mask) {
-        // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-        //     gmem_tiled_copy_Mask,
-        //     tMaskgMask, tMasksMask,
-        //     tMaskcMask, tMaskpMask,
-        //     binfo.actual_seqlen_q - m_block * kBlockM
-        // );
-        // cute::cp_async_fence();
-        // FLASH_NAMESPACE::cp_async_wait<0>();
-        // // Do OR-reduce on the mask to see if any active threads
-
-        FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_Mask,
             tMaskgMask, tMasksMask,
-            any_active,
             tMaskcMask, tMaskpMask,
             binfo.actual_seqlen_q - m_block * kBlockM
         );
-        // We don't need to syncthreads here because copy_mask is already or_syncthreads.
+        __syncthreads();
+        // Do OR-reduce on the mask to see if there are any active threads for the current iteration.
+        FLASH_NAMESPACE::mask_or_reduce(
+            tMasksMask,
+            any_active,
+            smem_thr_copy_Mask
+        );
     }
 
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
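(A rough sketch of the OR-reduce idea behind the new mask_or_reduce call above; the helper's real implementation is not part of this diff, and the names below are illustrative only. Each thread ORs the mask elements of its own shared-memory partition, then the block combines the per-thread flags so every thread agrees on whether the tile has any active position.)

// Hypothetical sketch, not FLASH_NAMESPACE::mask_or_reduce itself.
// Assumes the cute headers already included by this file.
template <typename TensorSMask, typename ThrCopy>
__device__ void mask_or_reduce_sketch(TensorSMask const &tMasksMask, bool &any_active,
                                      ThrCopy const & /*smem_thr_copy_Mask, unused in this sketch*/) {
    bool thread_active = false;
    #pragma unroll
    for (int i = 0; i < size(tMasksMask); ++i) {
        thread_active |= (tMasksMask(i) != 0);    // any nonzero mask element counts as active
    }
    // __syncthreads_or both synchronizes the block and OR-reduces the per-thread flags,
    // so all threads observe the same any_active value.
    any_active = __syncthreads_or(thread_active) != 0;
}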
@@ -602,15 +601,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
     if (any_active) {
         if constexpr (Has_bias) {
-            FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_Bias,
                 tBiasgBias, tBiassBias,
                 tBiascBias, tBiaspBias,
                 binfo.actual_seqlen_q - m_block * kBlockM
             );
-            // Because copy_bias currently uses scalar loads, we need to sync here.
-            // TODO: Remove sync after fixing to vectorized loads.
-            __syncthreads();
         }
     }
 
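(copy_MN replaces the former copy_bias at these call sites; its definition lives outside this diff. Based only on the arguments passed here, it is assumed to be a predicated tile copy of roughly the following shape, where the identity and predicate tensors guard rows beyond actual_seqlen_q and Clear_OOB_MN zero-fills whatever falls outside; names and details below are illustrative.)

// Illustrative sketch of a predicated M/N-bounded copy, not the library's copy_MN.
// Assumes cute is in scope, as in the rest of this file.
template <bool Is_even_MN, bool Clear_OOB_MN,
          typename TiledCopy, typename TensorS, typename TensorD,
          typename TensorC, typename TensorP>
__device__ void copy_MN_sketch(TiledCopy const &tiled_copy,
                               TensorS const &S, TensorD &D,
                               TensorC const &identity, TensorP const &predicate,
                               const int max_M) {
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity(0, m, 0)) < max_M) {    // row still within the sequence?
            #pragma unroll
            for (int n = 0; n < size<2>(S); ++n) {
                if (Is_even_MN || predicate(n)) {                 // column chunk within bounds?
                    cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
                } else if (Clear_OOB_MN) {
                    cute::clear(D(_, m, n));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));                              // zero-fill out-of-bounds rows
        }
    }
}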
@@ -627,9 +623,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
     if (Is_first) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
+        dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
             tdOrdO, tdOrO, gdPsum,
-            Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
+            Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
         );
     }
 
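(The only change above is the renamed kGmemThreadsPerRowQKVO trait; dot_do_o itself is untouched. As a reminder of the quantity it produces, here is a plain scalar reference, illustrative only and not the kernel's vectorized implementation: for each query row, dP_sum is the row-wise dot product of dO and O, the correction term used in the softmax backward.)

// Scalar reference for the per-row dO·O sums written to gdPsum (illustration only).
inline void dot_do_o_reference(const float *dO, const float *O, float *dP_sum,
                               int rows, int head_dim) {
    for (int i = 0; i < rows; ++i) {
        float acc = 0.0f;
        for (int d = 0; d < head_dim; ++d) {
            acc += dO[i * head_dim + d] * O[i * head_dim + d];    // accumulate dO[i] · O[i]
        }
        dP_sum[i] = acc;
    }
}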
@@ -852,15 +848,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     __syncthreads();
     if constexpr (Has_bias) {
         // Write dS to dBias
-        FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/false>(
-            gmem_tiled_copy_Bias,
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
+            gmem_tiled_copy_dBias,
             tBiassBias, tdBiasgdBias,
             tBiascBias, tBiaspBias,
             binfo.actual_seqlen_q - m_block * kBlockM
         );
-        // Because copy_bias currently uses scalar loads, we need to sync here.
-        // TODO: Remove sync after fixing to vectorized loads.
-        __syncthreads();
     }
 
     // if (cute::thread0()) { print(tPrP); }
@@ -886,24 +879,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if constexpr (Has_mask) {
         // Advance gMask
         tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
-        // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-        //     gmem_tiled_copy_Mask,
-        //     tMaskgMask, tMasksMask,
-        //     tMaskcMask, tMaskpMask,
-        //     binfo.actual_seqlen_q - (m_block - 1) * kBlockM
-        // );
-        // FLASH_NAMESPACE::cp_async_fence();
-        // FLASH_NAMESPACE::cp_async_wait<0>();
-        // // Do OR-reduce on the mask to see if any active threads for next iteration
-
-        FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_Mask,
             tMaskgMask, tMasksMask,
-            any_active_next,
             tMaskcMask, tMaskpMask,
             binfo.actual_seqlen_q - (m_block - 1) * kBlockM
         );
-        // We don't need to syncthreads here because copy_mask is already or_syncthreads.
+        __syncthreads();
+        // Do OR-reduce on the mask to see if there are any active threads for the next iteration.
+        FLASH_NAMESPACE::mask_or_reduce(
+            tMasksMask,
+            any_active_next,
+            smem_thr_copy_Mask
+        );
     }
 
     // Advance gdO
@@ -1007,15 +995,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
             tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
             if (any_active_next) {
-                FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+                FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                     gmem_tiled_copy_Bias,
                     tBiasgBias, tBiassBias,
                     tBiascBias, tBiaspBias,
                     binfo.actual_seqlen_q - (m_block - 1) * kBlockM
                 );
-                // Because copy_bias currently uses scalar loads, we need to sync here.
-                // TODO: Remove sync after fixing to vectorized loads.
-                __syncthreads();
             }
         }
     }
@@ -1024,9 +1009,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
     if (Is_first && m_block > m_block_min) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
+        dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
             tdOrdO, tdOrO, gdPsum,
-            Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
+            Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
         );
     }
 