
Conversation

@LoserCheems
Collaborator

Summary

  • Fixes NaN in bias.grad on SM80 when both attn_mask and attn_bias are enabled, reproduced with the sample in case.py.

Root Cause

  • On SM80 with large head dims, a runtime branch forced a "single-split" fallback based on the device's available shared memory (SMEM). This made kernel selection diverge between the split and non-split paths, yielding inconsistent shared-memory layouts and predicates for the mask+bias tiles in the last N-block. The corrupted attention scores then propagated to NaN in dbias. A sketch of the removed pattern follows.
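
For context, a minimal sketch of the kind of branch that was removed, assuming flash-attention-style parameter names; Flash_fwd_params, required_smem_bytes, and the head-dim threshold are illustrative, not the actual code in flash_fwd_launch_template.h:

```cpp
#include <cuda_runtime.h>

// Sketch of the removed pattern: the launcher probed device SMEM and
// silently overrode the split choice for large head dims on SM80-class
// GPUs, selecting a kernel whose shared-memory layout differed from the
// split-kv path used elsewhere.
void select_kernel(Flash_fwd_params &params, int device_id) {
    int max_smem_per_block = 0;
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id);
    if (params.d > 192 && max_smem_per_block < required_smem_bytes(params)) {
        params.num_splits = 1;  // forced fallback, bypassing the caller's choice
    }
}
```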

Changes

  • Drops the device SMEM query and the special-case fallback to a single split on limited-SMEM GPUs.
  • Unifies the forward path on the split-kv implementation and defers split selection to the caller, eliminating the risky runtime branch (see the sketch after this list).
  • Touched code paths:
    • Forward API and launch path: FLASH_NAMESPACE::mha_fwd, run/dispatch helpers in flash_fwd_launch_template.h
    • Split-KV forward kernel path (mask+bias handling and predicates): FLASH_NAMESPACE::compute_attn_splitkv
  • Commit message:
    • Removes SMEM-based split-kv restriction
    • Drops device SMEM query and the special case that forced a single split on limited-SMEM GPUs for large head dims.
    • Simplifies the forward path and defers split selection to the caller, reducing runtime branching.
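
A minimal sketch of the unified launch path after this change, assuming flash-attention-style helpers (Flash_fwd_params, run_mha_fwd_splitkv_dispatch); the exact signatures in this repo may differ:

```cpp
#include <cuda_runtime.h>

// Sketch only: the forward entry now always routes through the split-kv
// dispatcher. num_splits == 1 reproduces the non-split behavior, so no
// SMEM probe or special-case kernel selection is needed here.
template <typename T, int kHeadDim>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    // params.num_splits is caller-driven (default 1); the launcher no
    // longer overrides it based on device shared-memory limits.
    run_mha_fwd_splitkv_dispatch<T, kHeadDim>(params, stream);
}
```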

Reproduction

  • Env: SM80 GPU (Ampere), bf16/fp16.
  • Script: case.py
  • Shell:
    • python case.py
  • Before the fix: the "both" case (use_bias=True, use_mask=True) reports grad_bias_has_nan: True.
  • After the fix: out_has_nan and all grad_*_has_nan are False across:
    • bias_only, mask_only, both, neither.

Tests

  • Added/updated a unit test asserting no NaNs in outputs and grads for the four scenarios above with is_causal=True and deterministic=True, using _flash_dynamic_mask_attention_forward from flash_dynamic_mask_attention.py.
  • Stress-tested across representative head dims and varying key_length values to cover the last N-block predicate paths.
  • Validated numerical equivalence with reference attention for mask+bias where applicable.

Compatibility

  • No API signature changes.
  • Behavior note: split selection is now caller-driven. If code previously relied on the implicit SMEM-based fallback, ensure an explicit num_splits/keep_window_size policy is provided upstream. Default behavior remains functional with num_splits=1; performance tuning on large heads may require an explicit split (an illustrative caller-side policy is sketched below).
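
For upstream code that needs an explicit policy, a simple occupancy-style heuristic is enough; choose_num_splits below is hypothetical, not a repo API:

```cpp
#include <algorithm>

// Illustrative caller-side split policy: add splits only when batch * heads
// alone cannot fill the SMs, and never split finer than the number of KV
// blocks available.
int choose_num_splits(int batch, int num_heads, int num_kv_blocks, int num_sms) {
    int ctas_without_split = batch * num_heads;   // CTAs launched at num_splits == 1
    if (ctas_without_split >= num_sms) return 1;  // already enough parallelism
    int splits = (num_sms + ctas_without_split - 1) / ctas_without_split;
    return std::max(1, std::min(splits, num_kv_blocks));
}
```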

Checklist

  • [ ] Linked issue provided
  • Adds or updates tests
  • Updates docs if needed (note on split selection responsibility)
  • No perf regressions (removing branching reduces overhead; split choice remains tunable)

Contributor

Copilot AI left a comment


Pull Request Overview

This PR fixes a NaN bug in bias gradients when running on SM80 GPUs with both attention mask and bias enabled. The issue stemmed from a runtime branch that forced a "single-split" fallback based on device shared memory limitations, causing inconsistent kernel paths and corrupted attention scores.

  • Removes device SMEM query and special-case fallback logic for limited-SMEM GPUs
  • Unifies the forward path to use split-kv implementation consistently
  • Defers split selection responsibility to the caller to eliminate risky runtime branching


@LoserCheems LoserCheems merged commit af2ec35 into main Sep 22, 2025
3 of 4 checks passed
@LoserCheems LoserCheems deleted the fix-nan-in-sm80 branch November 13, 2025 04:41