LoserCheems (Collaborator) commented on Aug 23, 2025

Description

This PR implements issue #113 by adding an initial integration layer for Flash Dynamic Mask Attention (FDMA) modeled after how Hugging Face transformers integrates Flash Attention (FA).
It introduces:

  1. A high-level forward wrapper (flash_dynamic_mask_attention_forward) mimicking the attention.forward override pattern in transformers.
  2. Lazy import and environment capability detection utilities.
  3. A utilities module that handles:
    • Unpadding / repadding (varlen) flows adapted to dynamic mask semantics.
    • Bias slicing & masking logic consistent with partial / cached KV scenarios.
    • PEFT dtype reconciliation (LayerNorm → fp32 casting rollback).
    • Switching between padded / varlen fused kernels using internal _flash_* function handles (mirroring FA style).
This creates a clean adapter layer so downstream model code can select FDMA via config (_attn_implementation = "flash_dmattn") with minimal changes.
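
A minimal sketch of that opt-in path, assuming the wrapper is registered with the transformers attention interface. The registration call, the "flash_dmattn" key, and the checkpoint path are illustrative and depend on the installed transformers version (recent releases expose AttentionInterface.register for exactly this purpose).

from transformers import AttentionInterface, AutoModelForCausalLM

from flash_dmattn.integrations.flash_dynamic_mask_attention import (
    flash_dynamic_mask_attention_forward,
)

# Hypothetical registration step: expose the FDMA wrapper under the "flash_dmattn" key.
AttentionInterface.register("flash_dmattn", flash_dynamic_mask_attention_forward)

# Downstream model code then selects FDMA purely via configuration; this is
# equivalent to setting config._attn_implementation = "flash_dmattn".
model = AutoModelForCausalLM.from_pretrained(
    "path/to/checkpoint",                 # placeholder model id
    attn_implementation="flash_dmattn",
)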

Type of Change

  • New feature (non-breaking change which adds functionality)
  • Performance optimization (expected lower memory traffic + better kernel utilization vs eager SDPA)
  • Bug fix
  • Breaking change
  • Documentation update
  • CUDA kernel improvement
  • Code refactoring

Related Issues

Implements #113. Merging is also expected to close the linked issue "[BUG] varlen example mask and bias wrong shapes".

Changes Made

High-Level Summary

File purposes:

  • flash_dynamic_mask_attention.py: Public forward wrapper; validates kwargs, handles transposition and dtype normalization, and dispatches to the core FDMA logic.
  • import_utils.py: Adds _is_package_available and is_flash_dmattn_available() for lazy, robust availability checks (GPU + package + torch); see the sketch below.
  • modeling_flash_dynamic_mask_attention_utils.py: Core integration utilities; varlen unpadding/padding, bias slicing, dtype alignment, lazy loading of internal FDMA kernels, and the main _flash_dynamic_mask_attention_forward dispatcher.
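
A rough sketch of what the availability check in import_utils.py could look like, mirroring the pattern transformers uses; the function names come from the list above, but the bodies are assumptions rather than the code in this PR.

import importlib.metadata
import importlib.util

import torch


def _is_package_available(pkg_name: str) -> bool:
    # A package counts as available only if it is importable and ships
    # distribution metadata (guards against namespace placeholders).
    if importlib.util.find_spec(pkg_name) is None:
        return False
    try:
        importlib.metadata.version(pkg_name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


def is_flash_dmattn_available() -> bool:
    # FDMA needs a CUDA-capable torch build plus the flash_dmattn package itself.
    return torch.cuda.is_available() and _is_package_available("flash_dmattn")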

Key Implementation Points

  • Input Layout: Accepts non-transposed (batch, heads, seq, dim) from upstream; internally transposes to (batch, seq, heads, dim) for FDMA kernels—mirroring FA integration style.
  • Zero-Dimension Guard: Explicit error for any zero-sized dimension to avoid undefined CUDA behavior.
  • Target DType Recovery: Reverts fp32-casted inputs back to original (autocast or pre-quantization) dtype for performance and correctness.
  • Lazy Kernel Binding: _lazy_imports stashes _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn to globals only once; avoids import overhead when not selected.
  • Variable Length Flow:
    • _fdma_unpad_input / _fdma_pad_input replicate FA's varlen pipeline, extended to also combine an unused_mask for future extensibility (see the sketch after this list).
    • _upad_input fuses Q/K/V/Bias processing to reduce redundant indexing passes.
    • Bias pruning applies row/column validity masks to prevent reading garbage past real sequence lengths under caching.
  • Causality Handling: is_causal is force-disabled for single-token decode steps to avoid unnecessary masking cost.
  • PEFT Integration Hook: fdma_peft_integration_check (skeleton) to ensure compatibility with LoRA / PEFT workflows where dtypes may shift.
  • Attn Bias Path: Supports attention_bias (e.g., ALiBi / rotary pre-computed bias) shaped like (batch, n_kv_heads, q_len, k_len) and slices intelligently under varlen.
  • Safety: Drops is_causal from kwargs before forwarding to prevent duplication conflicts.
  • Return Convention: Mirrors transformers returning (attn_output, None) when output_attentions=False.
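
Since the varlen helpers replicate FA's pipeline, the core unpad step they build on looks roughly like the sketch below (per-sequence lengths, flattened token indices, cumulative offsets). The dynamic-mask specifics (combining the unused_mask and slicing the bias) are omitted, so treat this as a simplified assumption rather than the PR's exact code.

import torch
import torch.nn.functional as F


def _unpad_sketch(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    # hidden_states: (batch, seq_len, heads, head_dim)
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())

    # Cumulative sequence lengths prefixed with 0, as varlen kernels expect.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    # Gather only the real tokens into a packed (total_tokens, heads, head_dim) tensor.
    batch, seq_len = attention_mask.shape
    packed = hidden_states.reshape(batch * seq_len, *hidden_states.shape[2:])[indices]
    return packed, indices, cu_seqlens, max_seqlen_in_batch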

Deferred / TODOs (Not blocking)

  • Flesh out fdma_peft_integration_check logic (currently a skeleton; its body is omitted in this summary).
  • Implement the actual _lazy_imports internal binding (placeholders where proprietary or unrevealed kernel symbols will be patched in); one possible shape is sketched after this list.
  • Extend support for output_attentions=True via a fallback path (documented limitation).
  • Add coverage for head_mask (warns & falls back today).
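
One possible shape for the lazy binding mentioned above, as a minimal sketch only: the flash_dmattn symbol names imported here are assumptions, and the real helper also binds the pad/unpad helpers the same way. The point is the populate-once, module-global pattern.

# Module-level handles, populated on first use (mirrors the FA integration style).
_flash_fn = None
_flash_varlen_fn = None


def _lazy_imports() -> None:
    # Resolve kernel entry points once and stash them in module globals so the
    # import cost is only paid when FDMA is actually selected.
    global _flash_fn, _flash_varlen_fn
    if _flash_fn is not None:
        return
    # Symbol names are illustrative; the actual binding targets whatever the
    # flash_dmattn package exposes for the padded and varlen kernels.
    from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func

    _flash_fn = flash_dmattn_func
    _flash_varlen_fn = flash_dmattn_varlen_func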

Documentation

  • Inline docstrings explain tensor shapes pre/post unpadding.
  • Added logging warnings for unsupported features.

Testing

Implemented / Verified Manually

  • Shape flow: Confirmed Q/K/V transpose + unpadding preserve head & head_dim ordering.
  • Error path: Zero-dimension guard triggers as expected.
  • Bias slicing: Verified no out-of-bounds when kv_seq_len > mask_len (caching scenario).

Recommended Automated Tests (to add)

  1. Functional Equivalence:
    • Compare FDMA vs baseline SDPA for random inputs (within tolerance) across:
      • (batch, seq) = (1, 1), (2, 37), (4, 512)
      • Head dims {64, 128}
      • With & without attention_bias
  2. Varlen / Ragged:
    • Uneven sequence lengths (e.g., mask lengths: [17, 5, 31]).
  3. Causal vs Non-Causal:
    • Causal decode step (q_len=1) verifies the causality flag is suppressed.
  4. DType Path:
    • AMP autocast fp16/bf16 vs eager fp32 baseline.
  5. Gradient Consistency:
    • Backward pass relative error < 1e-3 vs SDPA for small seeds.
  6. Multi-GPU readiness (if applicable):
    • Ensure no device mismatch in created tensors (col_idx, masks).
  7. Failure Cases:
    • Pass zero-length batch or seq to ensure ValueError triggers.

Example (proposed) PyTest Skeleton

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fdma_equiv(dtype):
    ...
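
A slightly fuller version of the skeleton above, written against the public wrapper signature shown in the Usage Example further down. The dummy module namespace, the output layout, and the tolerances are assumptions to be aligned with the real implementation.

import types

import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="FDMA requires CUDA")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("batch,seq", [(1, 1), (2, 37), (4, 512)])
def test_fdma_matches_sdpa(dtype, batch, seq):
    fdma = pytest.importorskip("flash_dmattn.integrations.flash_dynamic_mask_attention")

    heads, head_dim = 8, 64
    q = torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=dtype)
    k, v = torch.randn_like(q), torch.randn_like(q)
    attn_mask = torch.ones(batch, seq, device="cuda", dtype=torch.long)

    # Baseline: PyTorch SDPA on identical inputs.
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Dummy attention module; which attributes the wrapper reads is an assumption.
    module = types.SimpleNamespace(training=False, is_causal=True)
    out, _ = fdma.flash_dynamic_mask_attention_forward(
        module=module,
        query=q, key=k, value=v,
        attention_mask=attn_mask,
        attention_bias=None,
        scaling=head_dim ** -0.5,
    )

    # Assumes FDMA returns (batch, seq, heads, head_dim); transpose the baseline to match.
    torch.testing.assert_close(out, ref.transpose(1, 2), rtol=2e-2, atol=2e-2)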

Manual Run Commands (Windows PowerShell)

# Install editable + deps
pip install -e .
# (Optional) set CUDA env if needed
python -m pytest tests/test_fdma_integration.py -v

Performance Impact

Expected improvements:

  • Reduced memory bandwidth via unpadding before kernel launch.
  • Fewer Q/K/V indexing passes by fusing in _upad_input.
  • Avoids constructing large causal masks explicitly.

Planned benchmark script alignment: the existing forward_performance.py could be extended with a --impl flash_dmattn flag.

Baseline expectation:

  • Parity or modest win over Flash Attention for irregular masks due to dynamic masking specialization.
  • Significant win over SDPA on long + sparse sequences.

(Concrete numbers to be added once benchmarks are executed.)

Breaking Changes

None. The integration is opt-in via configuration (_attn_implementation or selecting the forward override). No existing public API signatures altered.

Checklist

  • My code follows the project's style guidelines
  • I have performed a self-review
  • Added detailed docstrings / comments where logic is non-trivial
  • No new warnings introduced (to be validated in CI)
  • Added unit tests (pending)
  • Benchmarks collected (pending)
  • Lazy import pattern prevents overhead when unused
  • Documentation section in README referencing FDMA (follow-up)

CUDA-specific (indirect here; kernels assumed pre-existing):

  • Kernels compile without warnings (to confirm in CI)
  • Tested on SM 8.0+ (to run)
  • Memory usage profiled (future)
  • No leaks detected (future)

Additional Notes

  • Some sections show omitted lines because proprietary kernel hooks or still-in-progress functions were not expanded here; the final PR should include them fully resolved.
  • Consider adding a config gate, allow_fdma_fallback, to fall back to SDPA automatically when the environment is missing the kernels; a possible shape is sketched after this list.
  • Future: integrate attention score dropout (if required by training regimes) — currently assumed dropout=0/inference style.
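
A minimal sketch of the suggested allow_fdma_fallback gate; the config attribute, the import path, and the selection helper are all hypothetical.

import logging

from flash_dmattn.utils.import_utils import is_flash_dmattn_available  # module path is an assumption

logger = logging.getLogger(__name__)


def select_attention_impl(config) -> str:
    # Hypothetical gate: fall back to SDPA when FDMA kernels are not importable,
    # instead of failing at model construction time.
    if config._attn_implementation == "flash_dmattn" and not is_flash_dmattn_available():
        if getattr(config, "allow_fdma_fallback", False):
            logger.warning("flash_dmattn is unavailable; falling back to sdpa.")
            return "sdpa"
        raise ImportError("flash_dmattn kernels are not available in this environment.")
    return config._attn_implementation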

Usage Example

from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward

# Inside a custom attention module:
attn_output, _ = flash_dynamic_mask_attention_forward(
    module=self,
    query=q, key=k, value=v,
    attention_mask=attn_mask,          # (batch, 1, 1, seq) or (batch, seq)
    attention_bias=alibi_or_none,      # (batch, n_kv_heads, q_len, k_len)
    scaling=self.scaling
)

Risk Assessment

  • Low: Code is additive & gated.
  • Main risk: Silent mismatch if _lazy_imports fails to bind while the fallback path is not yet implemented; mitigated by explicit logging (ensure a logger.error in the final version).
  • Edge: Large ragged batches could stress unpadding indexing; add micro-bench to confirm no regression.

Follow-Up Tasks

  1. Add full unit + integration test suite.
  2. Provide benchmark diff vs SDPA & Flash Attention.
  3. Expand README with configuration instructions.
  4. Complete PEFT dtype harmonization logic.
  5. Expose environment diagnostic util (python -m flash_dmattn.env_check style).

Screenshots / Benchmarks

(To be added after running forward_performance.py with representative configs.)

Commits

Cleans up the test suite by removing demonstration scripts and test files
that were created to validate the varlen attention function bug fix.

Removes comprehensive test coverage for tensor shape validation,
memory efficiency improvements, and integration testing scenarios
that are no longer needed in the main codebase.

Introduces package availability checking with version detection and fallback mechanisms for special cases like torch dev versions.

Includes a dedicated function to verify flash attention availability by checking torch, CUDA, and package dependencies.

Implements comprehensive utilities for flash dynamic mask attention operations, including tensor padding/unpadding functions, input preprocessing, and attention computation workflows.

Provides FDMA-compatible functions for handling variable-length sequences with attention masks and supports both regular and variable-length flash attention variants.

Includes PEFT integration checks for dtype compatibility and a lazy import mechanism for flexible implementation selection.

Implements a forward function that integrates with the transformers library
for flash attention with dynamic masking capabilities.

Handles input validation, tensor transposition, dtype casting for PEFT
compatibility, and delegates to core attention implementation with
proper parameter mapping.

Provides warnings for unsupported features like output_attentions and
head_mask, directing users to eager attention mode when needed.
@LoserCheems LoserCheems requested review from Evanwu1125, SNHuan, Thanksyy, Copilot and wubingheng111 and removed request for Copilot August 23, 2025 10:08
@LoserCheems LoserCheems added the feature New feature request label Aug 23, 2025
@LoserCheems LoserCheems merged commit f567c1f into main Aug 23, 2025
@LoserCheems LoserCheems deleted the Fix-varlen-bug branch November 13, 2025 04:41

Successfully merging this pull request may close these issues.

[BUG] varlen example mask and bias wrong shapes
