@@ -41,7 +41,8 @@ namespace attention {
4141#define SHFL_SYNC (var, src_lane ) __shfl_sync(uint32_t (-1 ), var, src_lane)
4242
4343// Q*K^T operation.
44- template <int NUM_THREADS_PER_ROUNDS, int NUM_THREADS_PER_X, typename VecT, int N>
44+ template <int NUM_THREADS_PER_ROUNDS, int NUM_THREADS_PER_X, typename VecT,
45+ int N>
4546inline __device__ float qk_dot_ (const VecT (&q)[N], const VecT (&k)[N]) {
4647 using A_vec = typename common::FloatVecTypeTrait<VecT>::Type;
4748 // Compute the parallel products for Q*K^T (treat vector lanes separately).
@@ -57,12 +58,13 @@ inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) {
5758
5859 // Finalize the reduction across lanes.
5960 float qk = sum_vect (qk_vec);
60- #pragma unroll
61- for (int mask = (WARP_SIZE >> 1 ); mask >= NUM_THREADS_PER_ROUNDS; mask >>= 1 ) {
61+ #pragma unroll
62+ for (int mask = (WARP_SIZE >> 1 ); mask >= NUM_THREADS_PER_ROUNDS;
63+ mask >>= 1 ) {
6264 qk += SHFL_XOR_SYNC (qk, mask);
6365 }
6466
65- #pragma unroll
67+ #pragma unroll
6668 for (int mask = (NUM_THREADS_PER_X >> 1 ); mask > 0 ; mask >>= 1 ) {
6769 qk += SHFL_XOR_SYNC (qk, mask);
6870 }
@@ -86,7 +88,8 @@ inline __device__ float block_max(float* red_smem, float max) {
8688// for each warp, the 1st out of NUM_THREADS_PER_TOKEN thread already has the
8789// max value among every NUM_THREADS_PER_TOKEN threads.
8890#pragma unroll
89- for (int mask = (NUM_THREADS_PER_ROUNDS >> 1 ); mask >= NUM_THREADS_PER_X; mask >>= 1 ) {
91+ for (int mask = (NUM_THREADS_PER_ROUNDS >> 1 ); mask >= NUM_THREADS_PER_X;
92+ mask >>= 1 ) {
9093 max = fmaxf (max, SHFL_XOR_SYNC (max, mask));
9194 }
9295
0 commit comments