
Commit c32311b

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9f3cc50 commit c32311b

2 files changed: +10 −5 lines


extensions/csrc/kernel/cuda/attention/attention_utils.h
Lines changed: 8 additions & 5 deletions

@@ -41,7 +41,8 @@ namespace attention {
 #define SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
 
 // Q*K^T operation.
-template <int NUM_THREADS_PER_ROUNDS, int NUM_THREADS_PER_X, typename VecT, int N>
+template <int NUM_THREADS_PER_ROUNDS, int NUM_THREADS_PER_X, typename VecT,
+          int N>
 inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) {
   using A_vec = typename common::FloatVecTypeTrait<VecT>::Type;
   // Compute the parallel products for Q*K^T (treat vector lanes separately).
@@ -57,12 +58,13 @@ inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) {
 
   // Finalize the reduction across lanes.
   float qk = sum_vect(qk_vec);
-#pragma unroll
-  for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_ROUNDS; mask >>= 1) {
+#pragma unroll
+  for (int mask = (WARP_SIZE >> 1); mask >= NUM_THREADS_PER_ROUNDS;
+       mask >>= 1) {
     qk += SHFL_XOR_SYNC(qk, mask);
   }
 
-#pragma unroll
+#pragma unroll
   for (int mask = (NUM_THREADS_PER_X >> 1); mask > 0; mask >>= 1) {
     qk += SHFL_XOR_SYNC(qk, mask);
   }
@@ -86,7 +88,8 @@ inline __device__ float block_max(float* red_smem, float max) {
   // for each warp, the 1st out of NUM_THREADS_PER_TOKEN thread already has the
   // max value among every NUM_THREADS_PER_TOKEN threads.
 #pragma unroll
-  for (int mask = (NUM_THREADS_PER_ROUNDS >> 1); mask >= NUM_THREADS_PER_X; mask >>= 1) {
+  for (int mask = (NUM_THREADS_PER_ROUNDS >> 1); mask >= NUM_THREADS_PER_X;
+       mask >>= 1) {
     max = fmaxf(max, SHFL_XOR_SYNC(max, mask));
   }
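
Note: the loops reformatted above are XOR-shuffle butterfly reductions across warp lanes. A minimal standalone sketch of the pattern follows; it is not the repository's code. WARP_SIZE, warp_reduce_sum, and the SHFL_XOR_SYNC definition here are illustrative assumptions (the commit shows only the SHFL_SYNC macro, so SHFL_XOR_SYNC is assumed to wrap __shfl_xor_sync the same way):

// Sketch of a full-warp XOR-shuffle butterfly sum (assumption: 32-lane warp).
#define WARP_SIZE 32
#define SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)

__device__ float warp_reduce_sum(float val) {
#pragma unroll
  for (int mask = (WARP_SIZE >> 1); mask > 0; mask >>= 1) {
    // At stride `mask`, each lane adds the value from the lane whose index
    // differs in exactly that bit; after strides 16, 8, 4, 2, 1 every lane
    // holds the sum over all 32 lanes.
    val += SHFL_XOR_SYNC(val, mask);
  }
  return val;
}

The loops in qk_dot_ bound the stride (mask >= NUM_THREADS_PER_ROUNDS, or starting at NUM_THREADS_PER_X >> 1) rather than running the full 16..1 range, which restricts each reduction to a sub-group of lanes so a warp can reduce several partial results independently.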

tests/test_infer/test_ops/triton/kernel_utils.py
Lines changed: 2 additions & 0 deletions

@@ -193,6 +193,7 @@ def mock_alloc_block_table_and_kvcache_v3(
 
     return block_tables
 
+
 def mock_alloc_block_table_and_kvcache_vllm(
     k: torch.Tensor,
     v: torch.Tensor,
@@ -293,6 +294,7 @@ def generate_caches_and_block_tables_v2(
     )
     return k_cache, v_cache, block_tables
 
+
 def generate_caches_and_block_tables_v3(
     k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
 ) -> Tuple[torch.Tensor, ...]:
