Backward Kernel: Dynamic Mask Skip, Race Fixes, and 5–6× Speedup vs SDPA (Fixes #132) #145
Conversation
Reduces redundant mask copying and clearing operations by moving mask checks earlier in the loop and eliminating unnecessary register allocations for fully masked blocks. Consolidates async fence calls to reduce synchronization overhead and removes redundant clear operations on accumulator fragments when blocks are skipped due to masking.
Adds early mask checking to determine whether any threads are active before performing expensive computations. This optimization prevents unnecessary work when entire blocks are masked out, improving performance by:
- Moving mask evaluation earlier in the computation pipeline
- Conditionally executing bias loading and GEMM operations
- Tracking active state across iterations to avoid redundant work
- Reducing memory transfers and computation overhead for masked regions

The change maintains correctness while significantly reducing wasted cycles in the attention backward-pass kernels when dealing with padded sequences.
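A minimal CUDA sketch of this early-skip pattern (the function, buffer, and size names below are illustrative, not the actual identifiers in flash_bwd_kernel.h):

```cpp
#include <cuda_runtime.h>

// Illustrative early-skip pattern: each thread scans its slice of the mask
// tile in shared memory, then the block agrees on a single skip decision.
__device__ void process_key_block(const float* smem_mask,   // mask tile staged in SMEM
                                  int elems_per_thread,
                                  int tid) {
    bool any_active_local = false;
    for (int i = 0; i < elems_per_thread; ++i) {
        any_active_local |= (smem_mask[tid * elems_per_thread + i] != 0.0f);
    }

    // Block-wide OR reduction. __syncthreads_or is itself a full barrier, so the
    // decision is CTA-uniform and no extra __syncthreads() is needed here.
    bool any_active = __syncthreads_or(any_active_local);

    if (!any_active) {
        // Fully masked tile: skip bias loads, GEMMs, and accumulator clears.
        return;
    }

    // ... dS / dQ / dK / dV work for this tile would follow here ...
}
```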
Tracks whether bias parameter was provided during forward pass and only returns bias gradient during backward pass when bias was originally given. Prevents unnecessary computation and memory allocation for bias gradients when no bias is used in the attention mechanism.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…lash-dmattn into support-integration
Pull Request Overview
This PR implements a comprehensive optimization of the Flash Dynamic Mask Attention backward kernel, focusing on dynamic mask skip functionality and performance improvements. The changes eliminate race conditions, add early termination for fully masked blocks, and achieve 4.9×–6.2× speedups over PyTorch SDPA backward.
- Dynamic mask skip with block-level OR reduction to avoid computation on fully masked tiles
- Race condition fixes through proper shared memory aliasing and synchronization barrier optimization (see the sketch after this list)
- Performance optimization by removing redundant barriers and adding conditional execution paths
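As a rough illustration of the aliasing hazard behind the race fixes listed above (the kernel name and two-phase structure here are hypothetical, not taken from the PR): when one shared-memory region is reused for a second tensor, a barrier must separate the last read of the old contents from the first write of the new ones.

```cpp
#include <cuda_runtime.h>

// Hypothetical two-phase use of one shared-memory tile: first it holds S,
// then the same storage is reused (aliased) to stage dS.
__global__ void alias_reuse_example(float* out, int n) {
    extern __shared__ float smem_tile[];          // sized to blockDim.x floats at launch
    int tid = threadIdx.x;

    smem_tile[tid] = static_cast<float>(tid);     // pretend this is the S tile
    __syncthreads();                              // make S visible to all threads

    float s_val = smem_tile[(tid + 1) % blockDim.x];  // every thread reads S

    // Required barrier: without it, fast threads could overwrite the aliased
    // buffer with dS while slower threads are still reading S (a data race).
    __syncthreads();

    smem_tile[tid] = 2.0f * s_val;                // reuse the same storage for dS
    __syncthreads();                              // publish dS before it is consumed

    if (tid < n) out[tid] = smem_tile[tid];
}
```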
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| flash_dmattn/flash_dmattn_interface.py | Adds conditional dbias return logic to avoid unnecessary gradient computation when bias is None |
| csrc/src/flash_bwd_kernel.h | Implements dynamic mask skip, fixes race conditions, and optimizes synchronization in backward CUDA kernel |
```python
return_dbias = True
if bias is None:
    return_dbias = False
```
Copilot AI (Sep 3, 2025):
The logic can be simplified to a single line: return_dbias = bias is not None. This makes the intent clearer and reduces code verbosity.
Suggested change:

```diff
-return_dbias = True
-if bias is None:
-    return_dbias = False
+return_dbias = bias is not None
```
csrc/src/flash_bwd_kernel.h (outdated)
```cpp
Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
bool any_active_local = false;
bool any_active_local_next = false; // to be updated later for next iteration
#pragma unroll
for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
bool any_active = __syncthreads_or(any_active_local);
bool any_active_next = any_active; // to be updated later for next iteration
```
Copilot AI (Sep 3, 2025):
[nitpick] The variable any_active_next is initialized to any_active but then immediately reassigned later. Consider initializing it to false or adding a comment explaining why it needs this initial value.
Suggested change:

```diff
 Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
 bool any_active_local = false;
 bool any_active_local_next = false; // to be updated later for next iteration
 #pragma unroll
 for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
 bool any_active = __syncthreads_or(any_active_local);
-bool any_active_next = any_active; // to be updated later for next iteration
+bool any_active_next = false; // to be updated later for next iteration
```
```cpp
// Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
// Convert dS from fp32 to fp16
```
Copilot AI (Sep 3, 2025):
[nitpick] The commented-out line above this shows the original approach using dS.data(). Consider removing the commented line or adding a comment explaining why the direct acc_dp conversion approach was chosen for clarity.
Suggested change:

```diff
-// Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
-// Convert dS from fp32 to fp16
+// Previous approach: Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
+// We now convert acc_dp directly to the target type for efficiency and correctness,
+// avoiding potential issues with pointer aliasing and layout mismatches.
```
Initializes any_active_next to false instead of any_active to prevent potential issues with unintended carry-over of active state between iterations in the kernel loops. Changes affect both forward and backward kernel implementations to ensure consistent behavior across the codebase.
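A structural sketch of that carry-over fix, with hypothetical helper names (scan_local_mask stands in for the per-thread scan of the prefetched mask tile); the point is only that any_active_next is rebuilt from false each iteration rather than seeded with the previous block's decision:

```cpp
// Hypothetical per-thread scan of the next block's mask tile (declaration only).
__device__ bool scan_local_mask(int next_block);

// Illustrative loop over key blocks showing how the active flags are handed off.
__device__ void key_block_loop(int n_blocks, bool any_active_first) {
    bool any_active = any_active_first;          // decision for block 0, computed earlier
    for (int blk = 0; blk < n_blocks; ++blk) {
        // Rebuilt from false every iteration so no stale state carries over.
        bool any_active_next = false;
        if (blk + 1 < n_blocks) {
            bool local_next = scan_local_mask(blk + 1);
            any_active_next = __syncthreads_or(local_next);   // CTA-uniform decision
        }

        if (any_active) {
            // bias load, S / dP GEMMs, dS conversion, dQ / dK / dV accumulation ...
        }

        any_active = any_active_next;            // hand the decision to the next iteration
    }
}
```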
Description
This PR resolves issue #132 by completing the dynamic-mask aware backward kernel refactor for Flash Dynamic Mask Attention.
Key goals achieved:
- Skip fully masked key blocks early in the backward pass (no bias load, no GEMMs, no staging of dS).
- Remove redundant __syncthreads() while preserving required barriers around alias reuse and cooperative GEMMs.
- Keep the skip decisions any_active / any_active_next CTA-uniform (__syncthreads_or) so that shared loads / stores are never partially executed.

Type of Change
Related Issues
Changes Made
Code Changes
Backward CUDA kernel (compute_dq_dk_dv_1colblock):
- Early skip of fully masked key blocks via the per-iteration flags any_active / any_active_next.
- Converts dS from the acc_dp buffer directly (removed cross-scope dependency); a conversion sketch follows this list.
- Skip decisions made CTA-uniform with __syncthreads_or; the cp_async_wait<1> barrier in the V-in-regs path is kept.
- Bias loads and GEMMs execute only for active blocks (any_active).

Other:
- Minor fix around dS to avoid a compile error.
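A minimal sketch of the fp32-to-fp16 step referenced above, converting the accumulator fragment in registers; the fragment size and function name are illustrative, and the real kernel works through CUTE tensor views rather than raw pointers:

```cpp
#include <cuda_fp16.h>

// Convert the fp32 dS accumulator fragment (acc_dp) to fp16 directly,
// element by element in registers, before it is staged to shared memory.
__device__ void convert_acc_dp_to_half(const float* acc_dp, __half* dS_half, int frag_size) {
    for (int i = 0; i < frag_size; ++i) {
        // Reading acc_dp directly avoids re-deriving the data through another
        // tensor view, which is where the cross-scope dependency came from.
        dS_half[i] = __float2half(acc_dp[i]);
    }
}
```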
Testing
No Python API surface change. No dependency changes.
Testing
Performed:
- Runs under CUDA_LAUNCH_BLOCKING=1 and cuda-memcheck (spot tests)

Planned / Deferred:
Test Configuration
- Local source build of the extension (.cpython-312 build)

Performance Impact
Benchmark Protocol
Command:
Mode: Local window attention sweeping window size (W) while keeping Q=K=16384, B=1, Hq=2, Hkv=1, D=32, causal (C).
Results (Averaged over 3 runs)
Observations:
Before (Legacy)
No earlier native dynamic mask backward existed in-tree (baseline = PyTorch SDPA). If a “Before” column is wanted for internal review later, raw times from the pre-synchronization-slimming commit can be added.
Breaking Changes
None.
Checklist
CUDA-specific
Additional Notes
Future work (tracked or to be filed):
Safety rationale for removed barriers:
- Each removed __syncthreads() was either immediately followed by a __syncthreads_or, or only guarded accesses to per-thread exclusive SMEM slices after cp_async_wait.
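A small sketch of the first case, assuming nothing beyond standard CUDA semantics: __syncthreads_or already synchronizes the whole block before returning, so a __syncthreads() immediately before it adds no extra ordering.

```cpp
// Pattern before the slimming: the explicit barrier is redundant because
// __syncthreads_or is itself a block-wide synchronization point.
__device__ bool skip_decision_before(bool any_active_local) {
    __syncthreads();                              // removed in this PR
    return __syncthreads_or(any_active_local);
}

// Pattern after the slimming: identical ordering guarantees, one barrier fewer.
__device__ bool skip_decision_after(bool any_active_local) {
    return __syncthreads_or(any_active_local);
}
```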