Commit c3b93bc

[BUG FIX] Fix mask/bias memory access and vectorization issues in kernels
2 parents: 7a546fe + c426443

6 files changed: +258 −270 lines

csrc/flash_dmattn/src/flash_bwd_kernel.h (43 additions, 58 deletions)
```diff
@@ -80,6 +80,7 @@ template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, b
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

     using Element = typename Kernel_traits::Element;
+    using ElementMask = typename Kernel_traits::ElementMask;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
```
```diff
@@ -245,25 +246,29 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         sBias.data(),
         typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
     );
-    Tensor sMask = make_tensor(
+    Tensor sMaskPlace = make_tensor(
         sBias.data() + size(sBias),
         typename Kernel_traits::SmemLayoutMaskBiasPdS{}
+    );  // For pointer alignment only
+    Tensor sMask = make_tensor(
+        make_smem_ptr(reinterpret_cast<ElementMask *>(sMaskPlace.data().get())),
+        typename Kernel_traits::SmemLayoutMaskBiasPdS{}
     );
     Tensor sP = make_tensor(
-        sMask.data(),
+        sMaskPlace.data(),
         typename Kernel_traits::SmemLayoutMaskBiasPdS{}
     );
     Tensor sPt = make_tensor(
-        sMask.data(),
+        sMaskPlace.data(),
         typename Kernel_traits::SmemLayoutPdStransposed{}
     );
     Tensor sPtNoSwizzle = make_tensor(
-        sMask.data(),
+        sMaskPlace.data(),
         typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
     );
     // sMask, sP and sdQ share the same memory so be careful
     Tensor sdQ = make_tensor(
-        sMask.data(),
+        sMaskPlace.data(),
         typename Kernel_traits::SmemLayoutdQ{}
     );
```
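This hunk is the core of the mask memory-access fix: the mask now has its own element type (`ElementMask`) instead of reusing `Element`, so the old `sMask` tensor is split into `sMaskPlace`, an `Element`-typed placeholder that keeps the region's pointer arithmetic and alignment (and that the aliasing `sP`/`sPt`/`sdQ` views hang off), and `sMask`, a view that reinterprets the same bytes with the mask's type. Below is a minimal standalone sketch of this two-view pattern; the buffer size, types, and names are illustrative, not the kernel's actual code.

```cuda
#include <cuda_fp16.h>
#include <cstdio>

__global__ void alias_views_demo() {
    // One shared buffer, sized and aligned for the wider Element type;
    // this plays the role of sMaskPlace (alignment/offsets only).
    __shared__ __align__(16) __half smem_place[128];
    // Reinterpret the same bytes as a narrower mask element type (sMask's role).
    unsigned char *smem_mask = reinterpret_cast<unsigned char *>(smem_place);

    if (threadIdx.x == 0) {
        smem_mask[0] = 1;                    // write through the byte-typed mask view
        smem_place[1] = __float2half(0.5f);  // the Element view still addresses the region
        printf("mask[0] = %d\n", int(smem_mask[0]));
    }
}

int main() {
    alias_views_demo<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}
```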
```diff
@@ -278,6 +283,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     GmemTiledCopydO gmem_tiled_copy_dO;
     auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
     typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
+    typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
+    auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);
     auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
     using GmemLayoutAtomdQaccum = std::conditional_t<
         !Seq_parallel,
```
```diff
@@ -300,7 +307,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask);
     Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias);  // (BiasCPY, BiasCPY_M, BiasCPY_N)
     Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias);
-    Tensor tdBiasgdBias = gmem_thr_copy_Bias.partition_D(gdBias);
+    Tensor tdBiasgdBias = gmem_thr_copy_dBias.partition_D(gdBias);

     Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);  // ((Atom, AtomNum), ATOM_M, ATOM_N)
     Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
```
```diff
@@ -350,20 +357,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); }
     Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);

-    // auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp);
-    // auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
-    // Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
-    // auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp);
-    // auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
-    // Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
-
     // Partition sP and sdS to match the accumulator partitioning
     // This has to be tiled_mma_sdp, not tiled_mma_dkv
     // auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
     auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
     auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
-    Tensor tSsMask = smem_thr_copy_PdS.partition_S(sMask);
-    Tensor tSsBias = smem_thr_copy_PdS.partition_S(sBias);
+    auto smem_tiled_copy_Mask = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma_sdp);
+    auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
+    auto smem_tiled_copy_Bias = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma_sdp);
+    auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
+    Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
+    Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
     Tensor tPsP = smem_thr_copy_PdS.partition_D(sP);     // ((Atom, AtomNum), PIPE_M, PIPE_N)
     Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS);  // ((Atom, AtomNum), PIPE_M, PIPE_N)
```
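With the mask and bias given their own copy atoms (`SmemCopyAtomMask`, `SmemCopyAtomBias`), their shared-memory reads no longer piggyback on the `PdS` partitioning. One plausible reason is width arithmetic: a fixed-size vectorized access covers a different number of elements for each element type, so a per-thread tiling built for `Element` does not line up with a differently-typed mask tile. A hedged back-of-the-envelope illustration (plain host code, not CuTe):

```cuda
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

int main() {
    constexpr int kVecBytes  = 16;                           // one 128-bit copy instruction
    constexpr int kHalfElems = kVecBytes / sizeof(__half);   // 8 half elements per access
    constexpr int kMaskElems = kVecBytes / sizeof(uint8_t);  // 16 one-byte mask elements per access
    printf("elements per 128-bit access: half=%d, byte mask=%d\n", kHalfElems, kMaskElems);
    return 0;
}
```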
```diff
@@ -573,24 +577,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // // if (cute::thread(1, 0)) { print(tKrK); }

     if constexpr (Has_mask) {
-        // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-        //     gmem_tiled_copy_Mask,
-        //     tMaskgMask, tMasksMask,
-        //     tMaskcMask, tMaskpMask,
-        //     binfo.actual_seqlen_q - m_block * kBlockM
-        // );
-        // cute::cp_async_fence();
-        // FLASH_NAMESPACE::cp_async_wait<0>();
-        // // Do OR-reduce on the mask to see if any active threads
-
-        FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_Mask,
             tMaskgMask, tMasksMask,
-            any_active,
             tMaskcMask, tMaskpMask,
             binfo.actual_seqlen_q - m_block * kBlockM
         );
-        // We don't need to syncthreads here because copy_mask is already or_syncthreads.
+        __syncthreads();
+        // Do OR-reduce on the mask to see if any active threads for the current iteration.
+        FLASH_NAMESPACE::mask_or_reduce(
+            tMasksMask,
+            any_active,
+            smem_thr_copy_Mask
+        );
     }

     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
```
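Here the fused `copy_mask_with_or_reduce` helper is replaced by the generic vectorized `copy_MN`, an explicit `__syncthreads()`, and a separate `mask_or_reduce` that sets `any_active`. The sketch below shows the same copy-barrier-reduce shape in plain CUDA, using the `__syncthreads_or` intrinsic for the block-wide OR; the names and tile sizes are illustrative, not the library's.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kThreads = 128;
constexpr int kTile    = 1024;   // mask elements per block tile

__global__ void mask_tile_any_active(const unsigned char *gmask, int *out) {
    __shared__ unsigned char smask[kTile];

    // Copy this block's tile from global to shared memory.
    for (int i = threadIdx.x; i < kTile; i += blockDim.x) {
        smask[i] = gmask[blockIdx.x * kTile + i];
    }
    __syncthreads();   // the whole tile must be visible before the reduce

    // Each thread ORs its slice; __syncthreads_or combines the per-thread
    // predicates and returns nonzero iff any thread saw a nonzero mask value.
    int local_any = 0;
    for (int i = threadIdx.x; i < kTile; i += blockDim.x) {
        local_any |= smask[i];
    }
    int any_active = __syncthreads_or(local_any);

    if (threadIdx.x == 0) { out[blockIdx.x] = (any_active != 0); }
}

int main() {
    unsigned char *gmask; int *out;
    cudaMalloc(&gmask, kTile); cudaMalloc(&out, sizeof(int));
    cudaMemset(gmask, 0, kTile);   // an all-zero, fully masked tile
    mask_tile_any_active<<<1, kThreads>>>(gmask, out);
    int h; cudaMemcpy(&h, out, sizeof(int), cudaMemcpyDeviceToHost);
    printf("any_active = %d\n", h);   // prints 0: the block's work can be skipped
    return 0;
}
```

When the reduce comes back false, the kernel can skip the work guarded by `if (any_active)`, as the surrounding code does.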
```diff
@@ -602,15 +601,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

     if (any_active) {
         if constexpr (Has_bias) {
-            FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_Bias,
                 tBiasgBias, tBiassBias,
                 tBiascBias, tBiaspBias,
                 binfo.actual_seqlen_q - m_block * kBlockM
             );
-            // Because copy_bias currently uses scalar loads, we need to sync here.
-            // TODO: Remove sync after fixing to vectorized loads.
-            __syncthreads();
         }
     }
```
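The deleted comments record why the old path needed an extra barrier: `copy_bias` issued scalar loads. Routing the bias through the shared `copy_MN` path removes that TODO, which lines up with the "vectorization issues" named in the commit title. A hedged sketch of the scalar-versus-vectorized distinction (illustrative only; the kernel itself goes through cute tiled copies):

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Vectorized tile copy: one 16-byte uint4 transaction moves 8 halves per
// thread per step. Assumes n is a multiple of 8 and both pointers are
// 16-byte aligned; otherwise a kernel must fall back to the scalar loop
// shown in the comment.
__global__ void copy_tile_vectorized(const __half *src, __half *dst, int n) {
    // Scalar fallback (one 2-byte transaction per element):
    // for (int i = threadIdx.x; i < n; i += blockDim.x) dst[i] = src[i];
    const uint4 *src4 = reinterpret_cast<const uint4 *>(src);
    uint4 *dst4 = reinterpret_cast<uint4 *>(dst);
    for (int i = threadIdx.x; i < n / 8; i += blockDim.x) {
        dst4[i] = src4[i];
    }
}

int main() {
    constexpr int n = 4096;
    __half *src, *dst;
    cudaMalloc(&src, n * sizeof(__half));
    cudaMalloc(&dst, n * sizeof(__half));
    copy_tile_vectorized<<<1, 128>>>(src, dst, n);
    cudaDeviceSynchronize();
    return 0;
}
```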
```diff
@@ -627,9 +623,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
     if (Is_first) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
+        dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
             tdOrdO, tdOrO, gdPsum,
-            Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
+            Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
         );
     }
```
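`dot_do_o` computes `dP_sum`, the row-wise dot product of dO and O that the backward pass reuses. Its template argument is the number of threads cooperating on one row, so `kNThreads / kGmemThreadsPerRowQKVO` is the number of rows covered per pass; the rename to `kGmemThreadsPerRowQKVO` binds it to the QKVO-specific row width, presumably because the mask/bias paths now use a different one. A sketch of this row-group reduction, with hypothetical shapes and names:

```cuda
#include <cuda_runtime.h>

// kThreadsPerRow threads cooperate on one row; a block of kNThreads threads
// therefore covers kNThreads / kThreadsPerRow rows per pass.
template <int kThreadsPerRow>
__global__ void dot_do_o_sketch(const float *dO, const float *O, float *dP_sum, int row_len) {
    const int row  = threadIdx.x / kThreadsPerRow;   // which row this thread helps with
    const int lane = threadIdx.x % kThreadsPerRow;   // position inside the row group

    float acc = 0.f;
    for (int j = lane; j < row_len; j += kThreadsPerRow) {
        acc += dO[row * row_len + j] * O[row * row_len + j];
    }
    // Butterfly reduction across the row group (assumes kThreadsPerRow is a
    // power of two and the group does not straddle a warp boundary).
    for (int off = kThreadsPerRow / 2; off > 0; off /= 2) {
        acc += __shfl_xor_sync(0xffffffffu, acc, off);
    }
    if (lane == 0) { dP_sum[row] = acc; }
}

int main() {
    constexpr int kThreadsPerRow = 8, kNThreads = 128, kRowLen = 64;
    constexpr int kRows = kNThreads / kThreadsPerRow;   // 16 rows per pass
    float *dO, *O, *dP_sum;
    cudaMalloc(&dO, kRows * kRowLen * sizeof(float));
    cudaMalloc(&O,  kRows * kRowLen * sizeof(float));
    cudaMalloc(&dP_sum, kRows * sizeof(float));
    dot_do_o_sketch<kThreadsPerRow><<<1, kNThreads>>>(dO, O, dP_sum, kRowLen);
    cudaDeviceSynchronize();
    return 0;
}
```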
```diff
@@ -852,15 +848,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     __syncthreads();
     if constexpr (Has_bias) {
         // Write dS to dBias
-        FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/false>(
-            gmem_tiled_copy_Bias,
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
+            gmem_tiled_copy_dBias,
             tBiassBias, tdBiasgdBias,
             tBiascBias, tBiaspBias,
             binfo.actual_seqlen_q - m_block * kBlockM
         );
-        // Because copy_bias currently uses scalar loads, we need to sync here.
-        // TODO: Remove sync after fixing to vectorized loads.
-        __syncthreads();
     }

     // if (cute::thread0()) { print(tPrP); }
```
```diff
@@ -886,24 +879,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if constexpr (Has_mask) {
         // Advance gMask
         tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
-        // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-        //     gmem_tiled_copy_Mask,
-        //     tMaskgMask, tMasksMask,
-        //     tMaskcMask, tMaskpMask,
-        //     binfo.actual_seqlen_q - (m_block - 1) * kBlockM
-        // );
-        // FLASH_NAMESPACE::cp_async_fence();
-        // FLASH_NAMESPACE::cp_async_wait<0>();
-        // // Do OR-reduce on the mask to see if any active threads for next iteration
-
-        FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_Mask,
             tMaskgMask, tMasksMask,
-            any_active_next,
             tMaskcMask, tMaskpMask,
             binfo.actual_seqlen_q - (m_block - 1) * kBlockM
         );
-        // We don't need to syncthreads here because copy_mask is already or_syncthreads.
+        __syncthreads();
+        // Do OR-reduce on the mask to see if any active threads for next iteration
+        FLASH_NAMESPACE::mask_or_reduce(
+            tMasksMask,
+            any_active_next,
+            smem_thr_copy_Mask
+        );
     }

     // Advance gdO
```
```diff
@@ -1007,15 +995,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
             tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
             if (any_active_next) {
-                FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+                FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                     gmem_tiled_copy_Bias,
                     tBiasgBias, tBiassBias,
                     tBiascBias, tBiaspBias,
                     binfo.actual_seqlen_q - (m_block - 1) * kBlockM
                 );
-                // Because copy_bias currently uses scalar loads, we need to sync here.
-                // TODO: Remove sync after fixing to vectorized loads.
-                __syncthreads();
             }
         }
     }
```
```diff
@@ -1024,9 +1009,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

     if (Is_first && m_block > m_block_min) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
+        dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
             tdOrdO, tdOrO, gdPsum,
-            Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
+            Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
         );
     }
```
csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h (2 additions, 2 deletions)
```diff
@@ -148,9 +148,9 @@ inline __device__ void compute_dot_do_o(
         tdOcdO, tdOpdO,
         binfo.actual_seqlen_q - m_block * kBlockM
     );
-    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
+    dot_do_o<Kernel_traits::kGmemThreadsPerRowQKVO>(
         tdOrdO, tdOrO, dP_sum,
-        Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow)
+        Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRowQKVO)
     );
     if (Clear_dQaccum) {
         // We're actually not zero'ing out all of dQaccum, but only the part that we're going to
```
