Skip to content

Commit c426443

Browse files
committed
Corrects mask smem type; unifies copy/reduce
Uses a dedicated mask element type with aligned shared memory, separating mask typing from shared buffers to prevent misalignment and aliasing. Replaces combined mask copy+reduce with a generic copy, explicit barrier, and a separate OR-reduction to ensure accurate activity detection. Unifies bias/mask transfers via generic copy utilities and updates the dot-product threading trait, improving correctness across mixed element types and preparing for varied mask formats.
1 parent d226164 commit c426443

File tree

1 file changed

+33
-38
lines changed

1 file changed

+33
-38
lines changed

csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, b
8080
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
8181

8282
using Element = typename Kernel_traits::Element;
83+
using ElementMask = typename Kernel_traits::ElementMask;
8384
using ElementAccum = typename Kernel_traits::ElementAccum;
8485
using index_t = typename Kernel_traits::index_t;
8586

@@ -245,25 +246,29 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
245246
sBias.data(),
246247
typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
247248
);
248-
Tensor sMask = make_tensor(
249+
Tensor sMaskPlace = make_tensor(
249250
sBias.data() + size(sBias),
250251
typename Kernel_traits::SmemLayoutMaskBiasPdS{}
252+
); // For pointer alignment only
253+
Tensor sMask = make_tensor(
254+
make_smem_ptr(reinterpret_cast<ElementMask *>(sMaskPlace.data().get())),
255+
typename Kernel_traits::SmemLayoutMaskBiasPdS{}
251256
);
252257
Tensor sP = make_tensor(
253-
sMask.data(),
258+
sMaskPlace.data(),
254259
typename Kernel_traits::SmemLayoutMaskBiasPdS{}
255260
);
256261
Tensor sPt = make_tensor(
257-
sMask.data(),
262+
sMaskPlace.data(),
258263
typename Kernel_traits::SmemLayoutPdStransposed{}
259264
);
260265
Tensor sPtNoSwizzle = make_tensor(
261-
sMask.data(),
266+
sMaskPlace.data(),
262267
typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
263268
);
264269
// sMask, sP and sdQ share the same memory so be careful
265270
Tensor sdQ = make_tensor(
266-
sMask.data(),
271+
sMaskPlace.data(),
267272
typename Kernel_traits::SmemLayoutdQ{}
268273
);
269274

@@ -572,24 +577,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
572577
// // if (cute::thread(1, 0)) { print(tKrK); }
573578

574579
if constexpr (Has_mask) {
575-
// FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
576-
// gmem_tiled_copy_Mask,
577-
// tMaskgMask, tMasksMask,
578-
// tMaskcMask, tMaskpMask,
579-
// binfo.actual_seqlen_q - m_block * kBlockM
580-
// );
581-
// cute::cp_async_fence();
582-
// FLASH_NAMESPACE::cp_async_wait<0>();
583-
// // Do OR-reduce on the mask to see if any active threads
584-
585-
FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
580+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
586581
gmem_tiled_copy_Mask,
587582
tMaskgMask, tMasksMask,
588-
any_active,
589583
tMaskcMask, tMaskpMask,
590584
binfo.actual_seqlen_q - m_block * kBlockM
591585
);
592-
// We don't need to syncthreads here because copy_mask is already or_syncthreads.
586+
__syncthreads();
587+
// Do OR-reduce on the mask to see if any active threads for the current iteration.
588+
FLASH_NAMESPACE::mask_or_reduce(
589+
tMasksMask,
590+
any_active,
591+
smem_thr_copy_Mask
592+
);
593593
}
594594

595595
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
@@ -601,7 +601,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
601601

602602
if (any_active) {
603603
if constexpr (Has_bias) {
604-
FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
604+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
605605
gmem_tiled_copy_Bias,
606606
tBiasgBias, tBiassBias,
607607
tBiascBias, tBiaspBias,
@@ -623,9 +623,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
623623
// if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
624624
if (Is_first) {
625625
cute::copy(tdOrdO, tdOsdO);
626-
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
626+
dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
627627
tdOrdO, tdOrO, gdPsum,
628-
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
628+
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
629629
);
630630
}
631631

@@ -848,7 +848,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
848848
__syncthreads();
849849
if constexpr (Has_bias) {
850850
// Write dS to dBias
851-
FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/false>(
851+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
852852
gmem_tiled_copy_dBias,
853853
tBiassBias, tdBiasgdBias,
854854
tBiascBias, tBiaspBias,
@@ -879,24 +879,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
879879
if constexpr (Has_mask) {
880880
// Advance gMask
881881
tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
882-
// FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
883-
// gmem_tiled_copy_Mask,
884-
// tMaskgMask, tMasksMask,
885-
// tMaskcMask, tMaskpMask,
886-
// binfo.actual_seqlen_q - (m_block - 1) * kBlockM
887-
// );
888-
// FLASH_NAMESPACE::cp_async_fence();
889-
// FLASH_NAMESPACE::cp_async_wait<0>();
890-
// // Do OR-reduce on the mask to see if any active threads for next iteration
891-
892-
FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
882+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
893883
gmem_tiled_copy_Mask,
894884
tMaskgMask, tMasksMask,
895-
any_active_next,
896885
tMaskcMask, tMaskpMask,
897886
binfo.actual_seqlen_q - (m_block - 1) * kBlockM
898887
);
899-
// We don't need to syncthreads here because copy_mask is already or_syncthreads.
888+
__syncthreads();
889+
// Do OR-reduce on the mask to see if any active threads for next iteration
890+
FLASH_NAMESPACE::mask_or_reduce(
891+
tMasksMask,
892+
any_active_next,
893+
smem_thr_copy_Mask
894+
);
900895
}
901896

902897
// Advance gdO
@@ -1000,7 +995,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
1000995
tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
1001996
tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
1002997
if (any_active_next) {
1003-
FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
998+
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
1004999
gmem_tiled_copy_Bias,
10051000
tBiasgBias, tBiassBias,
10061001
tBiascBias, tBiaspBias,
@@ -1014,9 +1009,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
10141009

10151010
if (Is_first && m_block > m_block_min) {
10161011
cute::copy(tdOrdO, tdOsdO);
1017-
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
1012+
dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
10181013
tdOrdO, tdOrO, gdPsum,
1019-
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
1014+
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
10201015
);
10211016
}
10221017

0 commit comments

Comments
 (0)