@@ -38,9 +38,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, b
     #endif
 }
 
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
+        FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
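For context, DEFINE_FLASH_FORWARD_KERNEL is the macro that wraps these __global__ kernel definitions; a minimal sketch of how it expands, assuming the upstream flash-attention launch-template layout (the exact KERNEL_PARAM_MODIFIER handling may differ in this fork):

// Sketch of the kernel-definition macro (assumed upstream form).
// KERNEL_PARAM_MODIFIER is typically __grid_constant__ on recent CUDA toolkits and empty otherwise.
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...)   \
    template <typename Kernel_traits, __VA_ARGS__>     \
    __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)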
@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                     auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
                     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
-                    // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+                    // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
                     if (smem_size >= 48 * 1024) {
                         C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -83,7 +83,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                     // int ctas_per_sm;
                     // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                     //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                    // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                    // printf("run_flash_fwd: smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                     kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                     C10_CUDA_KERNEL_LAUNCH_CHECK();
                 });
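The smem_size >= 48 * 1024 branch exists because a kernel may request at most 48 KB of dynamic shared memory per block by default; larger requests must be opted into with cudaFuncSetAttribute before launch, which is what the code above does. A standalone sketch of the same pattern, using a hypothetical my_kernel and an illustrative 64 KB request (supported on GPUs that allow more than 48 KB per block):

#include <cuda_runtime.h>

// Hypothetical kernel using dynamically sized shared memory (size set at launch time).
__global__ void my_kernel(const float *in, float *out, int n) {
    extern __shared__ float smem[];   // dynamic shared memory
    int i = threadIdx.x;
    if (i < n) { smem[i] = in[i]; }
    __syncthreads();
    if (i < n) { out[i] = smem[i]; }
}

void launch_my_kernel(const float *in, float *out, int n, cudaStream_t stream) {
    int smem_size = 64 * 1024;        // above the 48 KB default cap
    if (smem_size >= 48 * 1024) {
        // Opt in to the larger dynamic shared memory carve-out for this kernel.
        cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    my_kernel<<<1, 256, smem_size, stream>>>(in, out, n);
}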
@@ -104,26 +104,23 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH(params.num_splits > 1, Split, [&] {
-                BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                        // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                        // If Is_local, set Is_causal to false
-                        auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
-                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
-                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                        // printf("Split = %d, Append_KV = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Append_KV), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
-                        if (smem_size >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                        }
-                        // int ctas_per_sm;
-                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
-                    });
+                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split>;
+                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split>;
+                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                    // printf("run_flash_splitkv_fwd: Split = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
+                    if (smem_size >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                    }
+                    // int ctas_per_sm;
+                    // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                    //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                    // printf("run_flash_splitkv_fwd: smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                 });
             });
         });
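The BOOL_SWITCH / EVENK_SWITCH / SOFTCAP_SWITCH macros are what turn these runtime flags into compile-time template parameters: each switch instantiates the enclosed lambda once with the flag as constexpr true and once as false, so dropping the Append_KV switch here halves the number of split-KV kernel instantiations. A sketch of the core pattern, assuming the upstream static_switch.h definition:

// Sketch of the static-switch idiom (assumed to match the upstream BOOL_SWITCH).
// The runtime condition selects one of two lambdas; inside each, the flag is a
// compile-time constant, so the kernel template in the body is instantiated for both values.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Usage, mirroring the launcher above (illustrative):
//   BOOL_SWITCH(params.num_splits > 1, Split, [&] {
//       auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, /*...,*/ Split>;
//       kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
//   });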
@@ -158,9 +155,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64;  // Fixed for all head dimensions
-    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
-    // and for headdim 192 with block size 64 x 128.
-    constexpr static int kBlockN = Headdim <= 64 ? 128 : (Headdim <= 128 ? 64 : 32);
+    constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
 
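For the new kBlockN expression, a quick evaluation at a few common head dimensions (not part of the patch, just spelling out the arithmetic):

// kBlockN = Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32)
//   Headdim  64  -> 64   (previously 128)
//   Headdim  96  -> 64   (unchanged)
//   Headdim 128  -> 32   (previously 64, since the old test was Headdim <= 128)
//   Headdim 256  -> 32   (unchanged)
// kBlockM stays fixed at 64, so each CTA still covers 64 query rows.
static_assert(( 64 <= 64 ? 64 : ( 64 < 128 ? 64 : 32)) == 64, "");
static_assert((128 <= 64 ? 64 : (128 < 128 ? 64 : 32)) == 32, "");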