
Commit dc90353

Merge pull request #17 from Evanwu1125/main
Update mask.h
2 parents 19f160f + 7ce4a4c commit dc90353

File tree

1 file changed: +95 -84 lines changed


csrc/src/mask.h

Lines changed: 95 additions & 84 deletions
@@ -10,6 +10,7 @@
 #include <cutlass/numeric_types.h>
 #include <cutlass/numeric_conversion.h>
 #include <cutlass/fast_math.h>
+#include <cub/block/block_merge_sort.cuh>
 
 #ifndef BLOCK_THREADS
 #define BLOCK_THREADS 128 // Common CUDA thread block size (multiple of 32)
@@ -23,7 +24,37 @@ namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
-// Struct wrapper for dynamic mask application
+// Value-Index pair for top-k selection
+template<typename ValueType>
+struct TopKPair {
+    ValueType value;
+    int col_index;
+
+    __device__ __forceinline__ TopKPair() : value(ValueType(-INFINITY)), col_index(-1) {}
+    __device__ __forceinline__ TopKPair(ValueType v, int idx) : value(v), col_index(idx) {}
+
+    __device__ __forceinline__ bool is_valid() const {
+        return col_index >= 0 && isfinite(value);
+    }
+};
+
+// Comparison functor for descending sort (greater values first)
+template<typename ValueType>
+struct DescendingComparator {
+    __device__ __forceinline__ bool operator()(const TopKPair<ValueType>& a, const TopKPair<ValueType>& b) const {
+        // if (isfinite(a.value) && isfinite(b.value)) {
+        //     return a.value > b.value;
+        // } else if (isfinite(a.value)) {
+        //     return true;  // a is valid, b is not
+        // } else if (isfinite(b.value)) {
+        //     return false; // b is valid, a is not
+        // } else {
+        //     return a.col_index < b.col_index; // Compare indices if both are invalid
+        // }
+        return a.value > b.value; // Descending order
+    }
+};
+
 template <bool Is_causal, int BlockThreads>
 struct DynamicMask {
     const int max_seqlen_k, max_seqlen_q;
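
(Editorial aside, not part of the diff: the two structs added above feed cub::BlockMergeSort, and the pattern is easier to see in isolation. The sketch below is a self-contained CUDA program in which one block selects the top-k columns of a single row of scores. Pair, Descending, topk_row, the 128x4 tile shape, and the synthetic scores are illustrative assumptions for this note, not code from the repository.)

#include <cstdio>
#include <cmath>
#include <vector>
#include <cuda_runtime.h>
#include <cub/block/block_merge_sort.cuh>

constexpr int kBlockThreads   = 128; // one CTA
constexpr int kItemsPerThread = 4;   // 128 * 4 = 512 candidate columns per row

// Value/index pair, mirroring TopKPair in the patch.
struct Pair {
    float value;
    int   col_index;
};

// Descending comparator, mirroring DescendingComparator in the patch.
struct Descending {
    __device__ bool operator()(const Pair& a, const Pair& b) const {
        return a.value > b.value;
    }
};

// One block selects the top-k columns of one row of scores.
__global__ void topk_row(const float* scores, int* keep_flag, int n, int k) {
    using Sorter = cub::BlockMergeSort<Pair, kBlockThreads, kItemsPerThread>;
    __shared__ typename Sorter::TempStorage temp_storage;

    // Blocked arrangement: thread t owns columns [t * kItemsPerThread, t * kItemsPerThread + kItemsPerThread).
    Pair items[kItemsPerThread];
    for (int i = 0; i < kItemsPerThread; ++i) {
        int col = threadIdx.x * kItemsPerThread + i;
        bool in_range = (col < n);
        items[i].value     = in_range ? scores[col] : -INFINITY; // padding sorts last
        items[i].col_index = in_range ? col : -1;
    }

    // Collective descending sort across the whole block; afterwards the blocked
    // position p = threadIdx.x * kItemsPerThread + i holds the (p+1)-th largest value.
    Sorter(temp_storage).Sort(items, Descending{});

    // Positions 0..k-1 of the sorted order are exactly the top-k.
    // (A __syncthreads() is only required before temp_storage is reused,
    // which is why the patch syncs once per processed row.)
    for (int i = 0; i < kItemsPerThread; ++i) {
        int pos = threadIdx.x * kItemsPerThread + i;
        if (pos < k && items[i].col_index >= 0) {
            keep_flag[items[i].col_index] = 1;
        }
    }
}

int main() {
    const int n = 300, k = 8;
    std::vector<float> h_scores(n);
    for (int i = 0; i < n; ++i) h_scores[i] = float((i * 37) % 301); // arbitrary distinct scores

    float* d_scores; int* d_flag;
    cudaMalloc(&d_scores, n * sizeof(float));
    cudaMalloc(&d_flag, n * sizeof(int));
    cudaMemcpy(d_scores, h_scores.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(d_flag, 0, n * sizeof(int));

    topk_row<<<1, kBlockThreads>>>(d_scores, d_flag, n, k);

    std::vector<int> h_flag(n);
    cudaMemcpy(h_flag.data(), d_flag, n * sizeof(int), cudaMemcpyDeviceToHost);
    printf("kept columns:");
    for (int i = 0; i < n; ++i) if (h_flag[i]) printf(" %d", i);
    printf("\n");

    cudaFree(d_scores);
    cudaFree(d_flag);
    return 0;
}

BlockMergeSort ships with the CUB bundled in CUDA 11 and later; the sorted result stays in each thread's registers in blocked order, which is what makes the position test pos < k sufficient.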
@@ -100,112 +131,92 @@ struct DynamicMask {
             return;
         }
 
-        // Apply top-k selection per row if needed
-        #pragma unroll
+        // Declare shared memory for BlockMergeSort at block scope
+        using BlockMergeSortT = cub::BlockMergeSort<TopKPair<ElementZeroHold>, BlockThreads, ITEMS_PER_THREAD>;
+        __shared__ typename BlockMergeSortT::TempStorage temp_storage;
+        // Process each row with top-k sorting
         for (int mi = 0; mi < size<0, 1>(zero_hold); ++mi) {
             const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-            #pragma unroll
             for (int i = 0; i < size<0, 0>(zero_hold); ++i) {
                 const int row_idx = row_idx_base + i * 8;
-                // Skip if out of bounds
                 if (row_idx >= max_seqlen_q) continue;
-
-                // Temporarily mark all active elements as inactive for selection
-                #pragma unroll
-                for (int nj = 0; nj < size<1, 1>(zero_hold); ++nj) {
-                    #pragma unroll
-                    for (int j = 0; j < size<1, 0>(zero_hold); ++j) {
-                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                        if (active_indices(coord)) {
-                            active_indices(coord) = false;
-                        }
-                    }
-                }
-                __syncthreads();
 
-                // Shared memory for reduction
-                __shared__ float s_max_vals[BlockThreads];
-                __shared__ int s_max_indices_nj[BlockThreads];
-                __shared__ int s_max_indices_j[BlockThreads];
+                // Step 1: Thread-local storage for collecting the current row's elements
+                TopKPair<ElementZeroHold> thread_data[ITEMS_PER_THREAD];
 
-                // Iteratively select top-k elements
-                for (int k = 0; k < keep_window_size; ++k) {
-                    float thread_max = -FLT_MAX;
-                    int thread_max_nj = -1;
-                    int thread_max_j = -1;
-
-                    // Each thread finds its local maximum using the same loop structure
-                    #pragma unroll
-                    for (int nj = 0; nj < size<1, 1>(zero_hold); ++nj) {
-                        const int col_idx_base = col_idx_offset + nj * 8;
-                        #pragma unroll
-                        for (int j = 0; j < size<1, 0>(zero_hold); ++j) {
-                            const int col_idx = col_idx_base + j;
-                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-
-                            bool valid = (col_idx < max_seqlen_k) && !(Is_causal && col_idx > row_idx);
-                            float val = static_cast<float>(zero_hold(coord));
-                            if (valid && !active_indices(coord) && !isinf(val) && val > thread_max) {
-                                thread_max = val;
-                                thread_max_nj = nj;
-                                thread_max_j = j;
-                            }
-                        }
-                    }
-
-                    // Store thread-local maximum
-                    s_max_vals[tid] = thread_max;
-                    s_max_indices_nj[tid] = thread_max_nj;
-                    s_max_indices_j[tid] = thread_max_j;
-                    __syncthreads();
+                // Initialize all elements as invalid
+                for (int item = 0; item < ITEMS_PER_THREAD; ++item) {
+                    thread_data[item] = TopKPair<ElementZeroHold>();
+                }
+
+                // Collect data from the current row
+                for (int item = 0; item < ITEMS_PER_THREAD; ++item) {
+                    int global_idx = tid * ITEMS_PER_THREAD + item;
 
-                    // Parallel reduction to find global maximum
-                    for (int stride = BlockThreads / 2; stride > 0; stride >>= 1) {
-                        if (tid < stride) {
-                            if (s_max_vals[tid] < s_max_vals[tid + stride]) {
-                                s_max_vals[tid] = s_max_vals[tid + stride];
-                                s_max_indices_nj[tid] = s_max_indices_nj[tid + stride];
-                                s_max_indices_j[tid] = s_max_indices_j[tid + stride];
+                    if (global_idx < max_seqlen_k) {
+                        // Find the element with column index == global_idx in the current row
+                        for (int nj = 0; nj < size<1, 1>(zero_hold); ++nj) {
+                            const int col_idx_base = col_idx_offset + nj * 8;
+                            for (int j = 0; j < size<1, 0>(zero_hold); ++j) {
+                                const int col_idx = col_idx_base + j;
+                                if (col_idx == global_idx) {
+                                    auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+
+                                    // If active, collect its value and index
+                                    if (active_indices(coord)) {
+                                        ElementZeroHold val = zero_hold(coord);
+                                        thread_data[item] = TopKPair<ElementZeroHold>(val, col_idx);
+                                    }
+                                    break; // Found the element, no need to continue
+                                }
                             }
                         }
-                        __syncthreads();
-                    }
-
-                    // Mark the selected index as active
-                    if (tid == 0 && s_max_indices_nj[0] >= 0 && s_max_indices_j[0] >= 0) {
-                        auto coord = make_coord(make_coord(i, mi), make_coord(s_max_indices_j[0], s_max_indices_nj[0]));
-                        active_indices(coord) = true;
-                    }
-                    __syncthreads();
-
-                    // Early exit if no more valid elements
-                    if (s_max_vals[0] == -FLT_MAX) {
-                        break;
                     }
                 }
 
-                // Clear non-selected values using the same loop structure
-                #pragma unroll
+                // Step 2: Block-wide collaborative sorting with an explicit comparator
+                DescendingComparator<ElementZeroHold> comp;
+                BlockMergeSortT(temp_storage).Sort(thread_data, comp);
+                __syncthreads(); // Ensure sorting is complete
+
+                // Step 3: Update active_indices - keep only the top-k
+                // Traverse each coordinate and check whether its col_idx is in the top-k
                 for (int nj = 0; nj < size<1, 1>(zero_hold); ++nj) {
-                    #pragma unroll
+                    const int col_idx_base = col_idx_offset + nj * 8;
                     for (int j = 0; j < size<1, 0>(zero_hold); ++j) {
+                        const int col_idx = col_idx_base + j;
                         auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                        if (!active_indices(coord)) {
-                            zero_hold(coord) = ElementZeroHold(-INFINITY);
+
+                        // If the current position is active, check whether it is in the top-k
+                        if (active_indices(coord)) {
+                            // Check whether this element is in the thread's own top-k data
+                            bool is_in_topk = false;
+
+                            for (int item = 0; item < ITEMS_PER_THREAD; ++item) {
+                                // Global position in sorted order
+                                int global_pos = tid * ITEMS_PER_THREAD + item;
+
+                                // Only elements with global_pos < keep_window_size are in the top-k
+                                if (global_pos < keep_window_size &&
+                                    thread_data[item].col_index == col_idx) {
+                                    is_in_topk = true;
+                                    break;
+                                }
+                            }
+
+                            // If not in the top-k, mark it inactive
+                            if (!is_in_topk) {
+                                active_indices(coord) = false;
+                            }
                         }
                     }
                 }
-                __syncthreads();
-
+                __syncthreads(); // Ensure row processing is complete
             }
         }
     }
 
-
-template <
-    bool Causal_mask=false, bool Is_even_MN=true,
-    typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2
->
+template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
 __forceinline__ __device__ void apply_mask(
     Tensor<Engine0, Layout0> &tensor_, // acc_s (attention scores, 3D)
     Tensor<Engine1, Layout1> &tZeroHold, // Zero-hold states (3D)
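
(Editorial aside, not part of the diff: the design change here swaps keep_window_size rounds of block-wide max-reduction, each with several __syncthreads() barriers, for a single collective sort followed by a position test. The kept set can be cross-checked on the host against a plain reference; reference_topk below is a hypothetical helper written for this note, and agreement is exact only up to ties, since the stable block merge sort prefers the earlier-loaded column while std::partial_sort makes no ordering promise for equal values.)

#include <algorithm>
#include <numeric>
#include <vector>

// Reference top-k: indices of the k largest scores in one row (assumes k <= scores.size()).
std::vector<int> reference_topk(const std::vector<float>& scores, int k) {
    std::vector<int> idx(scores.size());
    std::iota(idx.begin(), idx.end(), 0); // 0, 1, ..., n-1
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return scores[a] > scores[b]; });
    idx.resize(k); // keep only the k best columns
    return idx;
}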

0 commit comments
