1010#include < cutlass/numeric_types.h>
1111#include < cutlass/numeric_conversion.h>
1212#include < cutlass/fast_math.h>
13+ #include < cub/block/block_merge_sort.cuh>
1314
1415#ifndef BLOCK_THREADS
1516#define BLOCK_THREADS 128 // Common CUDA thread block size (multiple of 32)
@@ -23,7 +24,37 @@ namespace FLASH_NAMESPACE {
2324
2425using namespace cute ;
2526
26- // Struct wrapper for dynamic mask application
// Value/column-index pair used for per-row top-k selection.
// A default-constructed pair is an "invalid" sentinel whose value of
// -INFINITY makes it sort after every real entry under a descending
// comparator, and whose col_index of -1 marks the slot as unused.
template <typename ValueType>
struct TopKPair {
    ValueType value = ValueType(-INFINITY);  // score for this column; -INFINITY = sentinel
    int col_index = -1;                      // source column index, or -1 when unused

    // Default constructor yields the invalid sentinel (via member initializers).
    __device__ __forceinline__ TopKPair() {}

    // Construct a pair holding a real (value, column) entry.
    __device__ __forceinline__ TopKPair(ValueType v, int idx) : value(v), col_index(idx) {}

    // True when the pair refers to a real column and carries a finite score.
    __device__ __forceinline__ bool is_valid() const {
        return (col_index < 0) ? false : isfinite(value);
    }
};
40+
// Comparison functor for cub::BlockMergeSort: orders TopKPair entries by
// value, greater values first, so after a block-wide sort the global top-k
// entries occupy the first keep_window_size positions. Invalid sentinel
// pairs carry value = -INFINITY and therefore sort after every real entry,
// so no extra validity check is needed here.
// NOTE(review): a NaN value compares false against everything, which would
// break the strict weak ordering merge sort requires — callers are presumed
// never to feed NaN scores; confirm upstream masking guarantees this.
template <typename ValueType>
struct DescendingComparator {
    __device__ __forceinline__ bool operator()(const TopKPair<ValueType>& a,
                                               const TopKPair<ValueType>& b) const {
        return a.value > b.value;  // descending order
    }
};
57+
2758template <bool Is_causal, int BlockThreads>
2859struct DynamicMask {
2960 const int max_seqlen_k, max_seqlen_q;
@@ -100,112 +131,92 @@ struct DynamicMask {
100131 return ;
101132 }
102133
103- // Apply top-k selection per row if needed
104- #pragma unroll
134+ // Declare shared memory for BlockMergeSort at block scope
135+ using BlockMergeSortT = cub::BlockMergeSort<TopKPair<ElementZeroHold>, BlockThreads, ITEMS_PER_THREAD>;
136+ __shared__ typename BlockMergeSortT::TempStorage temp_storage;
137+ // Process each row with TopK sorting
105138 for (int mi = 0 ; mi < size<0 , 1 >(zero_hold); ++mi) {
106139 const int row_idx_base = row_idx_offset + mi * warp_row_stride;
107- #pragma unroll
108140 for (int i = 0 ; i < size<0 , 0 >(zero_hold); ++i) {
109141 const int row_idx = row_idx_base + i * 8 ;
110- // Skip if out of bounds
111142 if (row_idx >= max_seqlen_q) continue ;
112-
113- // Temporarily mark all active elements as inactive for selection
114- #pragma unroll
115- for (int nj = 0 ; nj < size<1 , 1 >(zero_hold); ++nj) {
116- #pragma unroll
117- for (int j = 0 ; j < size<1 , 0 >(zero_hold); ++j) {
118- auto coord = make_coord (make_coord (i, mi), make_coord (j, nj));
119- if (active_indices (coord)) {
120- active_indices (coord) = false ;
121- }
122- }
123- }
124- __syncthreads ();
125143
126- // Shared memory for reduction
127- __shared__ float s_max_vals[BlockThreads];
128- __shared__ int s_max_indices_nj[BlockThreads];
129- __shared__ int s_max_indices_j[BlockThreads];
144+ // Step 1: Thread-local storage for collecting current row elements
145+ TopKPair<ElementZeroHold> thread_data[ITEMS_PER_THREAD];
130146
131- // Iteratively select top-k elements
132- for (int k = 0 ; k < keep_window_size; ++k) {
133- float thread_max = -FLT_MAX;
134- int thread_max_nj = -1 ;
135- int thread_max_j = -1 ;
136-
137- // Each thread finds its local maximum using the same loop structure
138- #pragma unroll
139- for (int nj = 0 ; nj < size<1 , 1 >(zero_hold); ++nj) {
140- const int col_idx_base = col_idx_offset + nj * 8 ;
141- #pragma unroll
142- for (int j = 0 ; j < size<1 , 0 >(zero_hold); ++j) {
143- const int col_idx = col_idx_base + j;
144- auto coord = make_coord (make_coord (i, mi), make_coord (j, nj));
145-
146- bool valid = (col_idx < max_seqlen_k) && !(Is_causal && col_idx > row_idx);
147- float val = static_cast <float >(zero_hold (coord));
148- if (valid && !active_indices (coord) && !isinf (val) && val > thread_max) {
149- thread_max = val;
150- thread_max_nj = nj;
151- thread_max_j = j;
152- }
153- }
154- }
155-
156- // Store thread-local maximum
157- s_max_vals[tid] = thread_max;
158- s_max_indices_nj[tid] = thread_max_nj;
159- s_max_indices_j[tid] = thread_max_j;
160- __syncthreads ();
147+ // Initialize all elements as invalid
148+ for (int item = 0 ; item < ITEMS_PER_THREAD; ++item) {
149+ thread_data[item] = TopKPair<ElementZeroHold>();
150+ }
151+
152+ // Collect data from current row
153+ for (int item = 0 ; item < ITEMS_PER_THREAD; ++item) {
154+ int global_idx = tid * ITEMS_PER_THREAD + item;
161155
162- // Parallel reduction to find global maximum
163- for (int stride = BlockThreads / 2 ; stride > 0 ; stride >>= 1 ) {
164- if (tid < stride) {
165- if (s_max_vals[tid] < s_max_vals[tid + stride]) {
166- s_max_vals[tid] = s_max_vals[tid + stride];
167- s_max_indices_nj[tid] = s_max_indices_nj[tid + stride];
168- s_max_indices_j[tid] = s_max_indices_j[tid + stride];
156+ if (global_idx < max_seqlen_k) {
157+ // Find element with column index = global_idx in current row
158+ for (int nj = 0 ; nj < size<1 , 1 >(zero_hold); ++nj) {
159+ const int col_idx_base = col_idx_offset + nj * 8 ;
160+ for (int j = 0 ; j < size<1 , 0 >(zero_hold); ++j) {
161+ const int col_idx = col_idx_base + j;
162+ if (col_idx == global_idx) {
163+ auto coord = make_coord (make_coord (i, mi), make_coord (j, nj));
164+
165+ // If active, collect its value and index
166+ if (active_indices (coord)) {
167+ ElementZeroHold val = zero_hold (coord);
168+ thread_data[item] = TopKPair<ElementZeroHold>(val, col_idx);
169+ }
170+ break ; // Found the element, no need to continue
171+ }
169172 }
170173 }
171- __syncthreads ();
172- }
173-
174- // Mark the selected index as active
175- if (tid == 0 && s_max_indices_nj[0 ] >= 0 && s_max_indices_j[0 ] >= 0 ) {
176- auto coord = make_coord (make_coord (i, mi), make_coord (s_max_indices_j[0 ], s_max_indices_nj[0 ]));
177- active_indices (coord) = true ;
178- }
179- __syncthreads ();
180-
181- // Early exit if no more valid elements
182- if (s_max_vals[0 ] == -FLT_MAX) {
183- break ;
184174 }
185175 }
186176
187- // Clear non-selected values using the same loop structure
188- #pragma unroll
177+ // Step 2: Block-wide collaborative sorting with explicit comparator
178+ DescendingComparator<ElementZeroHold> comp;
179+ BlockMergeSortT (temp_storage).Sort (thread_data, comp);
180+ __syncthreads (); // Ensure sorting is complete
181+
182+ // Step 3: Update active_indices - keep only topk
183+ // Traverse each coordinate and check if its col_idx is in topk
189184 for (int nj = 0 ; nj < size<1 , 1 >(zero_hold); ++nj) {
190- # pragma unroll
185+ const int col_idx_base = col_idx_offset + nj * 8 ;
191186 for (int j = 0 ; j < size<1 , 0 >(zero_hold); ++j) {
187+ const int col_idx = col_idx_base + j;
192188 auto coord = make_coord (make_coord (i, mi), make_coord (j, nj));
193- if (!active_indices (coord)) {
194- zero_hold (coord) = ElementZeroHold (-INFINITY);
189+
190+ // If current position is active, check if it's in topk
191+ if (active_indices (coord)) {
192+ // Check if this element is in thread's own topk data
193+ bool is_in_topk = false ;
194+
195+ for (int item = 0 ; item < ITEMS_PER_THREAD; ++item) {
196+ // Global position in sorted order
197+ int global_pos = tid * ITEMS_PER_THREAD + item;
198+
199+ // Only elements with global_pos < keep_window_size are topk
200+ if (global_pos < keep_window_size &&
201+ thread_data[item].col_index == col_idx) {
202+ is_in_topk = true ;
203+ break ;
204+ }
205+ }
206+
207+ // If not in topk, set as inactive
208+ if (!is_in_topk) {
209+ active_indices (coord) = false ;
210+ }
195211 }
196212 }
197213 }
198- __syncthreads ();
199-
214+ __syncthreads (); // Ensure row processing is complete
200215 }
201216 }
202217 }
203218
204-
205- template <
206- bool Causal_mask=false , bool Is_even_MN=true ,
207- typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2
208- >
219+ template <bool Causal_mask=false , bool Is_even_MN=true , typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
209220 __forceinline__ __device__ void apply_mask (
210221 Tensor<Engine0, Layout0> &tensor_, // acc_s (attention scores, 3D)
211222 Tensor<Engine1, Layout1> &tZeroHold, // Zero-hold states (3D)
0 commit comments