fix(sparse-attn): correct block_size computation in backward kernel #1

Open

Chuge0335 wants to merge 1 commit into main from fix-triton-backward
Conversation

@Chuge0335
Owner

fix(sparse-attn): correct block_size masking in backward kernel

The backward kernel used an incorrect block_size when applying the mask, so padding positions were treated as valid tokens. As a result, p = tl.math.exp2(qk - m) was computed on invalid entries, producing Inf and NaN values during gradient accumulation (especially in dQ).

This patch fixes the block_size computation for split blocks so that the mask correctly excludes padded regions in all cases.
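The failure mode can be sketched in NumPy (hypothetical tile sizes and values; the real kernel operates on Triton blocks, and `m` is the running max carried over from the forward pass):

```python
import numpy as np

BLOCK = 4    # kernel tile width (hypothetical)
seq_len = 6  # the last tile is a split block: 2 valid tokens + 2 padding slots

# Scores for the split tile. Padding slots hold arbitrary memory contents,
# simulated here with a huge garbage value.
qk = np.array([1.0, 2.0, 1e4, 1e4])
m = 2.0  # max over valid entries, as saved by the forward pass

def p_with_mask(block_size):
    """Mask columns >= block_size, then compute p = exp2(qk - m) as the kernel does."""
    masked = np.where(np.arange(BLOCK) < block_size, qk, -np.inf)
    return np.exp2(masked - m)

# Buggy block_size: full tile width, so padding passes the mask.
p_bad = p_with_mask(BLOCK)       # garbage columns overflow exp2 to inf
dq_bad = p_bad.sum()             # inf, which then poisons the dQ accumulator

# Fixed block_size for the split block: only valid tokens pass the mask.
p_good = p_with_mask(seq_len % BLOCK)  # exp2(-inf) == 0, so padding contributes nothing
dq_good = p_good.sum()                 # finite
```

With the corrected block_size, masked positions hit exp2(-inf) = 0 and drop out of the accumulation entirely, which is why the fix eliminates the Inf/NaN gradients.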
