@@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup(
647647}
648648
649649template <int  D> //  D == head size
650- #if  !defined(GGML_USE_HIP)
651650__launch_bounds__ (D, 1 )
652- #endif  //  !(defined(GGML_USE_HIP)
653651static __global__ void flash_attn_combine_results(
654652        const  float   * __restrict__  VKQ_parts,
655653        const  float2  * __restrict__  VKQ_meta,
@@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results(
692690    float  VKQ_numerator   = 0 .0f ;
693691    float  VKQ_denominator = 0 .0f ;
694692    for  (int  l = 0 ; l < parallel_blocks; ++l) {
695-         const  float  diff = meta[l].x  - kqmax;
696-         float  KQ_max_scale = expf (diff);
697-         const  uint32_t  ftz_mask = 0xFFFFFFFF  * (diff > SOFTMAX_FTZ_THRESHOLD);
698-         *((uint32_t  *) &KQ_max_scale) &= ftz_mask;
693+         const  float  KQ_max_scale = expf (meta[l].x  - kqmax);
699694
700695        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
701696        VKQ_denominator += KQ_max_scale * meta[l].y ;
@@ -836,11 +831,10 @@ void launch_fattn(
836831        CUDA_CHECK (cudaGetLastError ());
837832    }
838833
839-     int  parallel_blocks = 1 ;
840- 
841834    const  dim3  block_dim (warp_size, nwarps, 1 );
842835    int  max_blocks_per_sm = 1 ; //  Max. number of active blocks limited by occupancy.
843836    CUDA_CHECK (cudaOccupancyMaxActiveBlocksPerMultiprocessor (&max_blocks_per_sm, fattn_kernel, block_dim.x  * block_dim.y  * block_dim.z , nbytes_shared));
837+     int  parallel_blocks = max_blocks_per_sm;
844838
845839    dim3  blocks_num;
846840    if  (stream_k) {
@@ -862,9 +856,6 @@ void launch_fattn(
862856        GGML_ASSERT (K->ne [1 ] % KQ_row_granularity == 0 );
863857        const  int  ntiles_KQ = K->ne [1 ] / KQ_row_granularity; //  Max. number of parallel blocks limited by tensor size.
864858
865-         //  parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
866-         parallel_blocks = std::max ((nsm * max_blocks_per_sm) / ntiles_total, 1 );
867- 
868859        //  parallel_blocks must not be larger than what the tensor size allows:
869860        parallel_blocks = std::min (parallel_blocks, ntiles_KQ);
870861
0 commit comments