
Commit a04dd1b

Merge pull request #68 from SmallDoges:Support-backward
Refactors variable declarations for better readability
2 parents bf4dddc + 6755b4d · commit a04dd1b

File tree: 3 files changed (+37, -37 lines)


csrc/flash_api.cpp

Lines changed: 4 additions & 4 deletions
@@ -232,7 +232,7 @@ std::tuple&lt;at::Tensor, at::Tensor&gt; set_params_splitkv(
 ) {
 
     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+    const int block_n = head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);
     const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
@@ -259,9 +259,9 @@ std::tuple&lt;at::Tensor, at::Tensor&gt; set_params_splitkv(
     // See: https://github.com/SmallDoges/flash-dmattn/issues/47
     // Regardless of how it is set externally, always set num_splits back to 1.
     // This is to avoid the extra memory overhead of Split-KV.
-    params.num_splits = 1;
-    softmax_lse_accum.reset();
-    out_accum.reset();
+    // params.num_splits = 1;
+    // softmax_lse_accum.reset();
+    // out_accum.reset();
 
     return std::make_tuple(softmax_lse_accum, out_accum);
 }
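Taken together, these two hunks shrink the Split-KV key tile and stop forcing num_splits back to 1: the reset and the accumulator releases are now commented out, so a caller-requested num_splits is no longer overridden (note the explanatory comment above the now-disabled lines is left in place). As a quick sanity check of the new tile selection, here is a minimal standalone sketch; ceil_div and block_n_for are illustrative helpers, not names from the repository:

    #include <cstdio>

    // Stand-in for the (max_seqlen_k + block_n - 1) / block_n rounding above.
    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    // New tile selection from this commit. Note that the first two arms
    // coincide: block_n is 64 for every head_size below 128, else 32.
    constexpr int block_n_for(int head_size) {
        return head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);
    }

    int main() {
        const int max_seqlen_k = 4096;
        for (int head_size : {32, 64, 96, 128, 256}) {
            const int block_n = block_n_for(head_size);
            std::printf("head_size=%3d -> block_n=%2d, num_n_blocks=%3d\n",
                        head_size, block_n, ceil_div(max_seqlen_k, block_n));
        }
        return 0;
    }

Since both sub-128 arms yield 64, the ternary could be collapsed to head_size < 128 ? 64 : 32; what matters is that it stays in lockstep with kBlockN in run_mha_fwd_splitkv_dispatch, as the comment above the declaration demands.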

csrc/src/flash_fwd_kernel.h

Lines changed: 11 additions & 6 deletions
@@ -707,8 +707,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     const int n_block_min = n_split_idx * n_blocks_per_split;
     int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
     if (Is_causal) {
-        n_block_max = std::min(n_block_max,
-                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+        n_block_max = std::min(
+            n_block_max,
+            cute::ceil_div(
+                (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q,
+                kBlockN
+            )
+        );
     }
     if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
         // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
@@ -863,9 +868,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto thr_mma = tiled_mma.get_thread_slice(tidx);
     Tensor tSrQ = thr_mma.partition_fragment_A(sQ);                                         // (MMA, MMA_M, MMA_K)
     Tensor tSrK = thr_mma.partition_fragment_B(sK);                                         // (MMA, MMA_N, MMA_K)
+    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                              // (MMA, MMA_K, MMA_N)
     Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
     Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
-    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                              // (MMA, MMA_K, MMA_N)
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});   // (MMA, MMA_M, MMA_K)
 
     // Copy Atom retiling
@@ -875,15 +880,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
     auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
     auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
     auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
     Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
     auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
-    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
-    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
-    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
 
     // PREDICATES
     // Construct identity layout for sQ and sK
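The first hunk is a pure reflow of the causal clamp; the arithmetic is unchanged. In scalar form it computes the following (a sketch with plain ints standing in for the kernel's compile-time constants; causal_n_block_max is an illustrative name):

    #include <algorithm>

    // Scalar stand-in for cute::ceil_div.
    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    // With right-aligned causal masking, query row i attends to keys
    // j <= i + (seqlen_k - seqlen_q). The last row of query block m_block
    // therefore bounds the highest key block this block must visit.
    int causal_n_block_max(int n_block_max, int m_block, int kBlockM,
                           int kBlockN, int seqlen_q, int seqlen_k) {
        return std::min(
            n_block_max,
            ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN));
    }

The other two hunks only move declarations: tOrVt is hoisted up beside the other MMA operand fragments tSrQ and tSrK, and the V smem-copy views move up beside the K views, so operand fragments and tiled-copy views are each grouped together.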

csrc/src/flash_fwd_launch_template.h

Lines changed: 22 additions & 27 deletions
@@ -38,9 +38,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, b
     #endif
 }
 
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
+        FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
             auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
-            // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+            // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
             if (smem_size >= 48 * 1024) {
                 C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -83,7 +83,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             // int ctas_per_sm;
             // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
             //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-            // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+            // printf("run_flash_fwd: smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
             kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
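The first hunk drops the Append_KV parameter from the split-KV kernel definition itself; the matching dispatch change follows below. The printf edits are cosmetic: the commented-out debug lines gain a run_flash_fwd: prefix so their output can be told apart from the split-KV launcher's. Separately, both launchers guard cudaFuncSetAttribute behind smem_size >= 48 * 1024 because CUDA limits dynamic shared memory to 48 KB per block unless a kernel opts in explicitly. A minimal sketch of that pattern in isolation (demo_kernel and launch_demo are hypothetical; >48 KB requires compute capability 7.0+):

    #include <cuda_runtime.h>

    // Hypothetical kernel that only illustrates the dynamic-smem opt-in.
    __global__ void demo_kernel(float *out) {
        extern __shared__ float smem[];
        smem[threadIdx.x] = static_cast<float>(threadIdx.x);
        __syncthreads();
        out[threadIdx.x] = smem[threadIdx.x];
    }

    void launch_demo(float *out, cudaStream_t stream) {
        const int smem_size = 64 * 1024;  // 64 KB exceeds the 48 KB default
        if (smem_size >= 48 * 1024) {
            cudaFuncSetAttribute(demo_kernel,
                                 cudaFuncAttributeMaxDynamicSharedMemorySize,
                                 smem_size);
        }
        demo_kernel<<<1, 128, smem_size, stream>>>(out);
    }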
@@ -104,26 +104,23 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH(params.num_splits > 1, Split, [&] {
-                BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                        // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                        // If Is_local, set Is_causal to false
-                        auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
-                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
-                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                        // printf("Split = %d, Append_KV = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Append_KV), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
-                        if (smem_size >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                        }
-                        // int ctas_per_sm;
-                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
-                    });
+                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split>;
+                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split>;
+                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                    // printf("run_flash_splitkv_fwd: Split = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
+                    if (smem_size >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                    }
+                    // int ctas_per_sm;
+                    // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                    //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                    // printf("run_flash_splitkv_fwd: smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                 });
             });
         });
@@ -158,9 +155,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64;  // Fixed for all head dimensions
-    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
-    // and for headdim 192 with block size 64 x 128.
-    constexpr static int kBlockN = Headdim <= 64 ? 128 : (Headdim <= 128 ? 64 : 32);
+    constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
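The dispatch change mirrors the new kernel signature: the BOOL_SWITCH on params.knew_ptr != nullptr is gone, so Append_KV no longer exists as a template dimension, and the IsEvenMNConst condition loses its !Append_KV term. Each such switch bakes a runtime flag into a template parameter by instantiating both branches, so dropping one halves the number of flash_fwd_splitkv_kernel instantiations per (Kernel_traits, Is_causal) pair. A simplified sketch of the pattern (this macro is a stand-in, not the project's actual definition):

    #include <cstdio>

    // Simplified BOOL_SWITCH: turn a runtime bool into a compile-time
    // constant by instantiating the body for both values.
    #define BOOL_SWITCH(COND, CONST_NAME, ...)             \
        [&] {                                              \
            if (COND) {                                    \
                constexpr static bool CONST_NAME = true;   \
                return __VA_ARGS__();                      \
            } else {                                       \
                constexpr static bool CONST_NAME = false;  \
                return __VA_ARGS__();                      \
            }                                              \
        }()

    template <bool Is_causal, bool Split>
    void launch() { std::printf("Is_causal=%d Split=%d\n", Is_causal, Split); }

    int main() {
        bool is_causal = true;
        bool num_splits_gt_1 = false;
        // Two nested switches -> 2^2 = 4 instantiations of launch<>.
        BOOL_SWITCH(is_causal, Is_causal, [&] {
            BOOL_SWITCH(num_splits_gt_1, Split, [&] {
                launch<Is_causal, Split>();
            });
        });
        return 0;
    }

The last hunk keeps kBlockN in run_mha_fwd_splitkv_dispatch consistent with the new block_n heuristic in set_params_splitkv; the stale nvcc-segfault note only applied to the old 64 x 256 and 64 x 128 tiles, so it is removed with them.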
