Conversation

@LoserCheems
Collaborator

Description

This PR resolves issue #132 by completing the dynamic-mask-aware backward kernel refactor for Flash Dynamic Mask Attention.
Key goals achieved:

  1. Correctness: Eliminated stale shared-memory reuse races under aliased regions (sMask/sP and sBias/sdS) and removed scope-related undefined identifier errors (e.g., dS).
  2. Dynamic Skip (Early Exit): Integrated block-level mask OR reduction to fully skip zero-score tiles (avoids 5 GEMMs + softmax path when a tile is fully masked).
  3. Synchronization Slimming: Removed three redundant __syncthreads() while preserving required barriers around alias reuse and cooperative GEMMs.
  4. Predicate Uniformity: Ensured any_active / any_active_next are CTA-uniform (__syncthreads_or) so that shared loads / stores are never partially executed.
  5. Performance: Achieved 4.9×–6.2× speedups over PyTorch SDPA backward on RTX 4090 across window sizes (local attention window sweep).

Type of Change

  • Performance optimization
  • CUDA kernel improvement
  • Bug fix (race + scope)
  • New feature
  • Breaking change
  • Documentation update
  • Code refactoring (partial – localized kernel restructuring)

Related Issues

Resolves #132.

Changes Made

Code Changes

Backward CUDA kernel (compute_dq_dk_dv_1colblock):

  • Added mask-first load + CTA-wide OR reduction to decide early skip.
  • Replaced per-thread predicates with CTA-uniform any_active / any_active_next.
  • Moved dS write path to reuse acc_dp buffer directly (removed cross-scope dependency).
  • Removed 3 redundant barriers:
    • Prologue barrier before first __syncthreads_or
    • Post-cp_async_wait<1> barrier in V-in-regs path
    • Mask-next pre-OR barrier inside main loop
  • Kept essential barriers before GEMM reading fully-populated tiles or before alias reuse of shared memory.
  • Ensured safe alias transitions: (sdS ↔ sBias), (sP ↔ sMask), sdO reuse for dV GEMM.
  • Converted fragmented dS row/col view to direct accumulator reinterpretation to lower register live range.
  • Added conditional lazy loads (Bias only when any_active).

Other:

  • Cleaned variable scope for dS to avoid compile error.
  • Consistent cp.async sequence: copy → fence → wait → (optional) barrier → OR → conditional dependent copies (a sketch follows below).
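
A minimal, self-contained sketch of this sequence is shown below. It is illustrative only: the kernel name, tile shapes, and buffer names (mask_skip_sketch, smem_mask, kTileM, ...) are invented for the example, the cp.async stage is approximated with CUDA's __pipeline_* primitives, and the real code in csrc/src/flash_bwd_kernel.h expresses the same pattern with CUTE copy atoms and aliased shared-memory buffers rather than plain per-element loops.

```cuda
#include <cuda_pipeline.h>  // __pipeline_memcpy_async / __pipeline_commit / __pipeline_wait_prior

constexpr int kTileM = 64, kTileN = 64;     // one (query x key) tile of the mask
constexpr int kTileElems = kTileM * kTileN;

__global__ void mask_skip_sketch(const float* __restrict__ mask,  // [num_tiles, kTileElems]
                                 const float* __restrict__ bias,  // same layout, loaded lazily
                                 float* __restrict__ out,         // [num_tiles, kTileElems]
                                 int num_tiles) {
    __shared__ float smem_mask[kTileElems];
    __shared__ float smem_bias[kTileElems];

    for (int tile = blockIdx.x; tile < num_tiles; tile += gridDim.x) {
        const float* gmask = mask + (size_t)tile * kTileElems;

        // 1) copy: stage the mask tile into shared memory first (cp.async on SM80+)
        for (int i = threadIdx.x; i < kTileElems; i += blockDim.x)
            __pipeline_memcpy_async(&smem_mask[i], &gmask[i], sizeof(float));
        __pipeline_commit();        // 2) fence: close the async copy group
        __pipeline_wait_prior(0);   // 3) wait: this thread's copies have landed in smem

        // 4) CTA-wide OR: every thread contributes its slice; __syncthreads_or is itself
        //    a full barrier, so the result is uniform and no extra __syncthreads is needed.
        int local_active = 0;
        for (int i = threadIdx.x; i < kTileElems; i += blockDim.x)
            local_active |= (smem_mask[i] != 0.0f);
        const bool any_active = __syncthreads_or(local_active);

        if (!any_active) continue;  // uniform early exit: skip bias load, GEMMs, softmax path

        // 5) conditional dependent copy: bias is only staged for active tiles
        const float* gbias = bias + (size_t)tile * kTileElems;
        for (int i = threadIdx.x; i < kTileElems; i += blockDim.x)
            __pipeline_memcpy_async(&smem_bias[i], &gbias[i], sizeof(float));
        __pipeline_commit();
        __pipeline_wait_prior(0);
        __syncthreads();            // full barrier before the whole tile is consumed

        // Stand-in for the dq/dk/dv GEMMs and softmax-backward work on the active tile.
        for (int i = threadIdx.x; i < kTileElems; i += blockDim.x)
            out[(size_t)tile * kTileElems + i] = smem_mask[i] * smem_bias[i];
        __syncthreads();            // protect shared-memory reuse by the next iteration
    }
}
```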

Documentation

  • (Pending) Will add a short design note summarizing backward early-skip math & synchronization invariants.
  • (Optional) Update README performance matrix (not part of this PR yet).

Testing

No Python API surface change. No dependency changes.

Performed:

  • Functional equivalence vs PyTorch SDPA backward across:
    • Multiple window sizes (32 → full length)
    • Causal and non-causal (causal path sampled)
    • Head dimension D=32 (baseline tuned path)
  • Randomized sparsity patterns (window-based masking)
  • Repeated runs (warmup + timed) to stabilize measurement
  • Checked for NaN/Inf emergence (none observed)
  • Verified no race by running with CUDA_LAUNCH_BLOCKING=1 and cuda-memcheck (spot tests)

Planned / Deferred:

  • Multi-architecture validation (SM80, SM90)
  • Larger head dims (64, 128)
  • Gradient equivalence harness extension to atomic-accum (Seq_parallel path)

Test Configuration

  • OS: Ubuntu (Docker container)
  • Python: 3.12.x (inferred from .cpython-312 build)
  • PyTorch: 2.8.0a0+5228986c39.nv25.05
  • CUDA: Driver compatible with RTX 4090 (SM 89)
  • GPU: NVIDIA GeForce RTX 4090
  • Seed: 42
  • Runs: Warmup 2, Timed 3

Performance Impact

Benchmark Protocol

Command:

python benchmarks/backward_performance.py --test-type sdpa-vs-cuda

Mode: Local window attention sweeping window size (W) while keeping Q=K=16384, B=1, Hq=2, Hkv=1, D=32, causal (C).

Results (Averaged over 3 runs)

Config: B=1, Hq=2, Hkv=1, Q=16384, K=16384, D=32, causal (C)

| Window (W)   | SDPA-BWD (ms) | CUDA-BWD (ms) | Speedup |
|--------------|---------------|---------------|---------|
| 32           | 28.42         | 4.64          | 6.1×    |
| 64           | 28.27         | 4.57          | 6.2×    |
| 128          | 28.36         | 4.67          | 6.1×    |
| 256          | 28.52         | 4.69          | 6.1×    |
| 512          | 28.67         | 5.28          | 5.4×    |
| 1024         | 28.47         | 4.83          | 5.9×    |
| 2048         | 28.56         | 4.86          | 5.9×    |
| 4096         | 28.39         | 4.91          | 5.8×    |
| 8192         | 28.37         | 5.68          | 5.0×    |
| 16384 (full) | 28.50         | 5.87          | 4.9×    |
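
Speedup here is SDPA-BWD (ms) divided by CUDA-BWD (ms); for example, at W = 32 this is 28.42 / 4.64 ≈ 6.1×.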

Observations:

  • SDPA baseline remains nearly flat (compute-bound softmax + matmuls).
  • Our backward stays ~4.6–5.0 ms for small/mid windows; degradation for large windows is expected as mask density approaches dense case and skip opportunities vanish.
  • Speedup highest at small to medium window (sparser effective attention region).
  • Removing the redundant barriers shaved ~2–4% off the initial tuned version (micro-profiling; numbers not shown here, can be added if needed).

Before (Legacy)

No earlier native dynamic-mask backward existed in-tree (baseline = PyTorch SDPA). If you want a “Before” column for internal review later, we can add raw times from the pre-synchronization-slimming commit.

Breaking Changes

None.

  • API unchanged.
  • Memory footprint unchanged (no extra persistent buffers).
  • Determinism unaffected (barrier removals do not reorder floating ops across independent tiles).

Checklist

  • Code follows style guidelines
  • Self-reviewed
  • Critical sections commented (synchronization + alias rationale)
  • No new warnings in CUDA compilation (spot check)
  • Existing benchmarks pass
  • No perf regression (improvement verified)
  • Added doc note (will follow-up)
  • Multi-arch test (deferred)
  • Larger head dims validation (deferred)

CUDA-specific

  • Kernels compile without warnings (baseline environment)
  • Tested on SM 89 (RTX 4090)
  • Shared-memory aliasing guarded by barriers
  • Removed only provably redundant barriers
  • Memory usage profiling (deferred; expected neutral)
  • SM80 / SM90 validation (planned)

Additional Notes

Future work (tracked or to be filed):

  1. Double-buffering pipeline for backward (overlap next mask + dO load with current GEMMs).
  2. Bit-packed mask to reduce L2 bandwidth at mid sparsity.
  3. Instrumentation counters (tiles_skipped / tiles_total) for adaptive heuristics.
  4. Multi-head / high-D tuning (ensure register pressure stable when D=64/128).
  5. Extend Triton/Flex backward parity benchmarks once those variants support dynamic mask skip semantics.

Safety rationale for removed barriers:

  • Every removed __syncthreads() was either immediately followed by a __syncthreads_or (which itself performs a full CTA barrier) or only guarded accesses to per-thread-exclusive SMEM slices after cp_async_wait; see the sketch below.
  • All GEMM consumption points still have an immediately preceding full CTA barrier ensuring tile completeness.
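
To make the first point concrete, here is a tiny stand-alone sketch (names and shapes are illustrative, not the kernel's CUTE-based code): __syncthreads_or() already waits for every thread in the block before returning its OR-reduced predicate, so a __syncthreads() placed directly in front of it adds no additional guarantee.

```cuda
// Minimal sketch, not production code: a block-wide "is anything active?" predicate
// where a legacy barrier in front of __syncthreads_or() is provably redundant.
__global__ void barrier_vs_or_sketch(const int* __restrict__ flags, int* __restrict__ out) {
    // Each thread contributes whether its own element is active.
    int local_active = (flags[blockIdx.x * blockDim.x + threadIdx.x] != 0);

    // Legacy sequence (the kind of barrier removed in this PR):
    //   __syncthreads();                           // full CTA barrier
    //   int any = __syncthreads_or(local_active);  // full CTA barrier + OR reduction
    // The first barrier adds nothing: __syncthreads_or already synchronizes the block.

    // Slimmed sequence kept in the kernel:
    int any_active = __syncthreads_or(local_active);     // one call, same guarantees

    if (threadIdx.x == 0) out[blockIdx.x] = any_active;  // result is CTA-uniform
}
```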

Reduces redundant mask copying and clearing operations by moving mask checks earlier in the loop and eliminating unnecessary register allocations for fully masked blocks.

Consolidates async fence calls to reduce synchronization overhead and removes redundant clear operations on accumulator fragments when blocks are skipped due to masking.

Adds early mask checking to determine if any threads are active before performing expensive computations. This optimization prevents unnecessary work when entire blocks are masked out, improving performance by:

- Moving mask evaluation earlier in the computation pipeline
- Conditionally executing bias loading and gemm operations
- Tracking active state across iterations to avoid redundant work
- Reducing memory transfers and computation overhead for masked regions

The change maintains correctness while significantly reducing wasted cycles in the attention backward-pass kernels when dealing with padded sequences.

Tracks whether the bias parameter was provided during the forward pass and only returns a bias gradient during the backward pass when bias was originally given.

Prevents unnecessary computation and memory allocation for bias gradients when no bias is used in the attention mechanism.

LoserCheems and others added 2 commits September 3, 2025 11:45
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@LoserCheems LoserCheems requested a review from Copilot September 3, 2025 03:45
Contributor

Copilot AI left a comment

Pull Request Overview

This PR implements a comprehensive optimization of the Flash Dynamic Mask Attention backward kernel, focusing on dynamic mask skip functionality and performance improvements. The changes eliminate race conditions, add early termination for fully masked blocks, and achieve 4.9×–6.2× speedups over PyTorch SDPA backward.

  • Dynamic mask skip with block-level OR reduction to avoid computation on fully masked tiles
  • Race condition fixes through proper shared memory aliasing and synchronization barrier optimization
  • Performance optimization by removing redundant barriers and adding conditional execution paths

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

| File | Description |
|------|-------------|
| flash_dmattn/flash_dmattn_interface.py | Adds conditional dbias return logic to avoid unnecessary gradient computation when bias is None |
| csrc/src/flash_bwd_kernel.h | Implements dynamic mask skip, fixes race conditions, and optimizes synchronization in the backward CUDA kernel |


Comment on lines +423 to +425:

    return_dbias = True
    if bias is None:
        return_dbias = False

Copilot AI (Sep 3, 2025):

The logic can be simplified to a single line: return_dbias = bias is not None. This makes the intent clearer and reduces code verbosity.

Suggested change:

    - return_dbias = True
    - if bias is None:
    -     return_dbias = False
    + return_dbias = bias is not None

Comment on lines 565 to 571:

    Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
    bool any_active_local = false;
    bool any_active_local_next = false; // to be updated later for next iteration
    #pragma unroll
    for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
    bool any_active = __syncthreads_or(any_active_local);
    bool any_active_next = any_active; // to be updated later for next iteration

Copilot AI (Sep 3, 2025):

[nitpick] The variable any_active_next is initialized to any_active but then immediately reassigned later. Consider initializing it to false or adding a comment explaining why it needs this initial value.

Suggested change:

    - bool any_active_next = any_active; // to be updated later for next iteration
    + bool any_active_next = false; // to be updated later for next iteration

Comment on lines +773 to +774:

    // Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
    // Convert dS from fp32 to fp16

Copilot AI (Sep 3, 2025):

[nitpick] The commented-out line above this shows the original approach using dS.data(). Consider removing the commented line or adding a comment explaining why the direct acc_dp conversion approach was chosen for clarity.

Suggested change:

    - // Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
    - // Convert dS from fp32 to fp16
    + // Previous approach: Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
    + // We now convert acc_dp directly to the target type for efficiency and correctness,
    + // avoiding potential issues with pointer aliasing and layout mismatches.

Initializes any_active_next to false instead of any_active to prevent
potential issues with unintended carry-over of active state between
iterations in the kernel loops.

Changes affect both forward and backward kernel implementations to
ensure consistent behavior across the codebase.
@LoserCheems LoserCheems merged commit 801e816 into main Sep 3, 2025

Development

Successfully merging this pull request may close these issues.

[FEATURE] Reduce compute bubbles in backward skip path