@@ -26,30 +26,7 @@ limitations under the License. */
2626namespace phi {
2727namespace sparse {
2828
29- #define PRIVATE_CASE_VISIT_ATTN_SOFTMAX (NAME, size, HINT, ...) \
30- case size: { \
31- constexpr int HINT = size; \
32- __VA_ARGS__ (); \
33- break ; \
34- }
35-
36- #define VISIT_ATTN_SFOTMAX (SIZE, NAME, ...) \
37- [&] { \
38- const auto & __size__ = SIZE; \
39- switch (__size__) { \
40- PRIVATE_CASE_VISIT_ATTN_SOFTMAX (NAME, 1 , KBufferSize, __VA_ARGS__) \
41- PRIVATE_CASE_VISIT_ATTN_SOFTMAX (NAME, 2 , KBufferSize, __VA_ARGS__) \
42- PRIVATE_CASE_VISIT_ATTN_SOFTMAX (NAME, 3 , KBufferSize, __VA_ARGS__) \
43- PRIVATE_CASE_VISIT_ATTN_SOFTMAX (NAME, 4 , KBufferSize, __VA_ARGS__) \
44- PRIVATE_CASE_VISIT_ATTN_SOFTMAX (NAME, 8 , KBufferSize, __VA_ARGS__) \
45- PRIVATE_CASE_VISIT_ATTN_SOFTMAX (NAME, 12 , KBufferSize, __VA_ARGS__) \
46- PRIVATE_CASE_VISIT_ATTN_SOFTMAX (NAME, 16 , KBufferSize, __VA_ARGS__) \
47- default : \
48- PD_THROW (" function " #NAME " is not implemented for columns>512 " ); \
49- } \
50- }()
51-
52- template <typename T, int BufferSize>
/* Row-wise masked softmax over a batched sparse (CSR) attention-score matrix.
 *
 * Expected launch layout: one warp per CSR row — blockDim.x == WARP_SIZE
 * (the warp strides the row's nonzeros), blockDim.y rows per block, and
 * grid.x sized so blockDim.y * gridDim.x >= total_row_num.
 *
 * x_crows / x_cols / x_values: CSR storage of the attention scores; each of
 *   the batch*num_heads matrices is M x M and owns batch_nnz nonzeros.
 * kp_mask:   optional key-padding mask, indexed as [batch][M]; a zero entry
 *            masks that column. Pass nullptr to disable.
 * attn_mask: optional attention mask, indexed as [M][M]; a zero entry masks
 *            that (row, col). Pass nullptr to disable.
 * out_values: softmax result, same sparsity pattern as x_values. It is also
 *   used as scratch between the three passes, so it must not alias x_values
 *   unless in-place operation is intended.
 *
 * NOTE(review): reconstructed from a commit-diff view. The two mask
 * parameters and the row-index prologue were hidden between hunks and were
 * inferred from the kernel launch site and in-body uses — verify against
 * the original file.
 */
template <typename T>
__global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
                                     const int64_t* x_cols,
                                     const T* x_values,
                                     const T* kp_mask,
                                     const T* attn_mask,
                                     T* out_values,
                                     int M,
                                     int total_row_num,
                                     int num_heads,
                                     int batch_nnz) {
  // out = exp(x - x_max) / sum(exp(x - x_max)), computed per CSR row.
  int row = blockIdx.x * blockDim.y + threadIdx.y;
  if (row >= total_row_num) return;

  int cur_batch = row / M;
  int cur_row = row % M;
  int crow_idx = cur_batch * (M + 1) + cur_row;
  int row_first = cur_batch * batch_nnz + static_cast<int>(x_crows[crow_idx]);
  int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
  if (row_nnz == 0) return;

  // Pass 1: stage the (masked) scores in out_values and track the row max.
  // Masked entries become -inf so they exponentiate to 0 in pass 2.
  T max_val = -std::numeric_limits<T>::infinity();
  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
    bool mask = false;
    int col_idx = static_cast<int>(x_cols[row_first + idx]);
    if (kp_mask != nullptr &&
        kp_mask[(cur_batch / num_heads) * M + col_idx] == 0) {
      mask = true;
    }
    if (attn_mask != nullptr && attn_mask[cur_row * M + col_idx] == 0) {
      mask = true;
    }

    if (!mask) {
      T val = x_values[row_first + idx];
      if (val > max_val) {
        max_val = val;
      }
      out_values[row_first + idx] = val;
    } else {
      // Note corner case: when all elements of the row are masked, result
      // may be wrong because of exp('-inf' - '-inf'), just ignore now.
      out_values[row_first + idx] = -std::numeric_limits<T>::infinity();
    }
  }
  // Full-warp reduction is safe: every lane of the row's warp reaches here
  // (the only early returns above are uniform across the warp).
  T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);

  // Pass 2: exponentiate in place and accumulate this lane's partial sum.
  T exp_sum = 0;
  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
    auto functor = phi::funcs::CudaExpFunctor<T>();
    T exp = functor(out_values[row_first + idx] - row_max_val);
    exp_sum += exp;
    out_values[row_first + idx] = exp;
  }
  T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);

  // Pass 3: normalize by the row's total.
  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
    out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum;
  }
}
12890
@@ -219,49 +181,36 @@ void FusedAttentionCsrKernel(
219181 " shape of 'attn_mask' must be [seq_len, seq_len]" ));
220182 }
221183
222- /* Step1: SDD Matmul, reuse */
184+ /* Step1: SDD Matmul, reuse matmul */
223185 SparseCsrTensor sdd_result;
224186 EmptyLikeCsrKernel<T, Context>(dev_ctx, sparse_mask, &sdd_result);
225187 auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
226188 sparse_blas.SDDMM (false ,
227189 true ,
228- static_cast <T>(1 ),
190+ static_cast <T>(1 / std::sqrt (N) ),
229191 query,
230192 key,
231193 static_cast <T>(0 ),
232194 &sdd_result);
233195
234- /* Step2: Softmax with kp_mask/attn_mask, manualy not reuse */
235196 EmptyLikeCsrKernel<T, Context>(dev_ctx, sdd_result, softmax);
236197
237- int buffer_size;
238- if (M < 128 ) {
239- buffer_size = (M + 32 - 1 ) / 32 ;
240- } else {
241- buffer_size = ((M + 128 - 1 ) / 128 ) * 4 ;
242- }
243-
244- dim3 grid ((total_row_num + 3 ) / 4 );
245- dim3 block (WARP_SIZE, 4 );
198+ dim3 grid ((total_row_num + 7 ) / 8 );
199+ dim3 block (WARP_SIZE, 8 );
246200
247201 int batch_nnz = sdd_result.nnz () / batch_num;
202+ AttnSoftmaxGpuKernel<T><<<grid, block, 0 , dev_ctx.stream()>>> (
203+ sdd_result.non_zero_crows ().data <int64_t >(),
204+ sdd_result.non_zero_cols ().data <int64_t >(),
205+ sdd_result.non_zero_elements ().data <T>(),
206+ kp_mask_ptr ? kp_mask_ptr->data <T>() : nullptr ,
207+ attn_mask_ptr ? attn_mask_ptr->data <T>() : nullptr ,
208+ softmax->mutable_non_zero_elements ()->data <T>(),
209+ M,
210+ total_row_num,
211+ q_dim[1 ],
212+ batch_nnz);
248213
249- VISIT_ATTN_SFOTMAX (buffer_size, " AttnSoftmaxGpuKernel" , [&] {
250- AttnSoftmaxGpuKernel<T, KBufferSize><<<grid, block, 0 , dev_ctx.stream()>>> (
251- sdd_result.non_zero_crows ().data <int64_t >(),
252- sdd_result.non_zero_cols ().data <int64_t >(),
253- sdd_result.non_zero_elements ().data <T>(),
254- kp_mask_ptr ? kp_mask_ptr->data <T>() : nullptr ,
255- attn_mask_ptr ? attn_mask_ptr->data <T>() : nullptr ,
256- softmax->mutable_non_zero_elements ()->data <T>(),
257- M,
258- total_row_num,
259- std::sqrt (N),
260- q_dim[1 ],
261- batch_nnz);
262- });
263-
264- /* Step3: DSD Matmul, reuse */
265214 softmax->set_dims (phi::make_ddim ({q_dim[0 ], q_dim[1 ], q_dim[2 ], q_dim[2 ]}));
266215 MatmulCsrDenseKernel<T, Context>(dev_ctx, *softmax, value, out);
267216#else
0 commit comments