perf: fix the iteration bound of SWA in FA2 prefill template (#714)
We forgot to divide the packed row index by group_size when computing
the sliding window iteration bound. This made the bound larger than its
actual value and slowed down execution.

Thanks to @Ying1123 for spotting this bug.
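
To make the size of the effect concrete, here is a minimal host-side sketch of the arithmetic only, not the kernel code. It assumes that `ceil_div` and `sub_if_greater_or_zero` behave as their names suggest and defines local stand-ins for them. The `WindowParams` struct, the two `window_iteration_*` functions, and every numeric value are hypothetical and chosen for illustration. `kv_tile` stands in for `16 * NUM_WARPS_KV * NUM_MMA_KV`.

```cpp
#include <cstdint>

// Local stand-ins for the helpers named in the diff; semantics are assumed
// from the names, not copied from the library.
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
constexpr uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t y) {
  return x > y ? x - y : 0U;
}

// Hypothetical container for the quantities that appear in the expression.
struct WindowParams {
  uint32_t kv_len;
  uint32_t qo_len;
  uint32_t window_left;
  uint32_t chunk_start;
  uint32_t num_rows_per_cta;  // counted in packed rows (query position x head in group)
  uint32_t group_size;
  uint32_t kv_tile;           // stands in for 16 * NUM_WARPS_KV * NUM_MMA_KV
};

// Bound before the fix: the packed row index is used directly.
constexpr uint32_t window_iteration_old(const WindowParams& p, uint32_t tile_idx) {
  return ceil_div(sub_if_greater_or_zero(p.kv_len + (tile_idx + 1) * p.num_rows_per_cta,
                                         p.qo_len + p.window_left + p.chunk_start),
                  p.kv_tile);
}

// Bound after the fix: the packed row index is divided by group_size first.
constexpr uint32_t window_iteration_new(const WindowParams& p, uint32_t tile_idx) {
  return ceil_div(
      sub_if_greater_or_zero(p.kv_len + (tile_idx + 1) * p.num_rows_per_cta / p.group_size,
                             p.qo_len + p.window_left + p.chunk_start),
      p.kv_tile);
}

// Illustrative values only: 4096 query and KV positions, window of 1024,
// 8 query heads per KV head, 128 packed rows per CTA, 64 KV entries per tile.
constexpr WindowParams kExample{4096, 4096, 1024, 0, 128, 8, 64};

// Last packed tile (index 255): the old bound is roughly 10x the new one.
static_assert(window_iteration_old(kExample, 255) == 496, "old bound");
static_assert(window_iteration_new(kExample, 255) == 48, "new bound");
```

With these sample values the old expression yields 496 for the last query tile while the corrected one yields 48. For small tile indices both are clamped to 0 by `sub_if_greater_or_zero`.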
yzh119 authored Jan 4, 2025
1 parent 0f80329 commit 989dbfa
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions include/flashinfer/attention/prefill.cuh
@@ -1226,7 +1226,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV
                16 * NUM_WARPS_KV * NUM_MMA_KV);

     const uint32_t window_iteration =
-        ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta,
+        ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta / group_size,
                                         qo_len + window_left + chunk_start),
                  (16 * NUM_WARPS_KV * NUM_MMA_KV));

@@ -1652,7 +1652,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag
                16 * NUM_WARPS_KV * NUM_MMA_KV);

     const uint32_t window_iteration =
-        ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta,
+        ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta / group_size,
                                         qo_len + window_left + chunk_start),
                  (16 * NUM_WARPS_KV * NUM_MMA_KV));

@@ -1980,7 +1980,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag
                16 * NUM_WARPS_KV * NUM_MMA_KV);

     const uint32_t window_iteration =
-        ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta,
+        ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta / group_size,
                                         qo_len + window_left + chunk_start),
                  (16 * NUM_WARPS_KV * NUM_MMA_KV));
