Skip to content

Fix Variable Sequence Length Support for Flash Attention Decode #362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: sycl-develop
Choose a base branch
from

Conversation

muhammad-tanvir-1211
Copy link
Collaborator

This PR fixes the variable sequence length support for Flash Attention Decode. It also fixes the causal masking on device code and matches the verification with the one for prefill along with the flops and gbps calculation.

@muhammad-tanvir-1211 muhammad-tanvir-1211 force-pushed the flash_fix_varlen_decode branch from 07bc319 to b5ff4b0 Compare May 9, 2025 10:39
if constexpr (!is_var_len) {
return params;
} else {
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = logical_problem_shape;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = logical_problem_shape;
auto [num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = select<1, 2, 3, 4, 5, 6, 7>logical_problem_shape;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logical problem shape and problem shape holding a lot of duplicate inputs unnecessarily which takes up register space. The logical problem shape only need to hold shape<int, int, int> for seq_len_qo, seq_len_kv, seq_len_kv_cache the remaining is already provided in the problem shape

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants