Integrate Flash Dynamic Mask Attention (FDMA) Into Transformers-Style Attention Flow #118
## Description
This PR implements issue #113 by adding an initial integration layer for Flash Dynamic Mask Attention (FDMA), modeled after how Hugging Face `transformers` integrates Flash Attention (FA). It introduces:
- A top-level entry point (`flash_dynamic_mask_attention_forward`) mimicking the `attention.forward` override pattern in `transformers`.
- Lazily bound `_flash_*` function handles (mirroring FA style).

This creates a clean adapter layer so downstream model code can select FDMA via config (`_attn_implementation = "flash_dmattn"`) with minimal changes.
## Type of Change
## Related Issues

- Implements #113
## Changes Made
### High-Level Summary
- `_is_package_available` and `is_flash_dmattn_available()` for lazy, robust availability checks (GPU + package + torch); a minimal sketch follows this list.
- A `_flash_dynamic_mask_attention_forward` dispatcher.
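As a point of reference, here is a minimal sketch of what such an availability check could look like. The exact guard conditions (package discovery, CUDA presence) are assumptions drawn from the bullet above, not the final implementation.

```python
import importlib.util

import torch


def is_flash_dmattn_available() -> bool:
    """Sketch only: gate FDMA selection on package + GPU availability."""
    # The package must be discoverable without importing the kernels yet.
    if importlib.util.find_spec("flash_dmattn") is None:
        return False
    # Kernels are CUDA-only, so require at least one visible GPU.
    return torch.cuda.is_available()
```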
### Key Implementation Points

- `_lazy_imports` stashes `_flash_fn`, `_flash_varlen_fn`, `_pad_fn`, `_unpad_fn` into globals only once; this avoids import overhead when FDMA is not selected (see the sketch after this list).
- `_fdma_unpad_input` / `_fdma_pad_input` replicate FA's varlen pipeline, extended to combine an `unused_mask` (future extensibility).
- `_upad_input` fuses Q/K/V/Bias processing to reduce redundant indexing passes.
- `is_causal` is force-disabled for single-token decode steps to avoid unnecessary masking cost.
- `fdma_peft_integration_check` (skeleton) ensures compatibility with LoRA / PEFT workflows where dtypes may shift.
- An `attention_bias` (e.g., ALiBi / rotary pre-computed bias) shaped like `(batch, n_kv_heads, q_len, k_len)` is accepted and sliced intelligently under varlen.
- `is_causal` is popped from kwargs before forwarding to prevent duplication conflicts.
- Matches `transformers` by returning `(attn_output, None)` when `output_attentions=False`.
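A simplified sketch of two of the points above (one-time global binding and the single-token `is_causal` override). The imported symbol name `flash_dmattn_func` is an assumption; only the pattern is intended.

```python
_flash_fn = None  # bound once by _lazy_imports, then reused on every call


def _lazy_imports() -> None:
    """Bind kernel entry points to module globals exactly once (sketch)."""
    global _flash_fn
    if _flash_fn is None:
        # Assumed symbol name; the real binding also covers _flash_varlen_fn,
        # _pad_fn and _unpad_fn.
        from flash_dmattn import flash_dmattn_func
        _flash_fn = flash_dmattn_func


def _resolve_is_causal(is_causal: bool, query_length: int) -> bool:
    # A single-token decode step cannot attend to future positions anyway,
    # so causal masking is force-disabled to skip its cost.
    return False if query_length == 1 else is_causal
```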
### Deferred / TODOs (Not blocking)

- Complete `fdma_peft_integration_check` logic (currently skeletal in omitted lines).
- Fill in the `_lazy_imports` internal binding lines (placeholders where proprietary or unrevealed kernel symbols would be patched).
- Support `output_attentions=True` via a fallback path (documented limitation).
- Support `head_mask` (warns & falls back today).

## Documentation
## Testing
### Implemented / Verified Manually
- Handling of `kv_seq_len > mask_len` (caching scenario).

### Recommended Automated Tests (to add)
- `attention_bias` slicing under varlen (`col_idx`, masks).

### Example (proposed) PyTest Skeleton
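The skeleton itself did not survive extraction; the following is a hedged stand-in. The import paths and the call signature of `flash_dynamic_mask_attention_forward` are assumptions and would need to match the actual integration module.

```python
import pytest
import torch

# Assumed import locations; adjust to the actual integration module.
from flash_dmattn.integrations.flash_dynamic_mask_attention import (
    flash_dynamic_mask_attention_forward,
)
from flash_dmattn.utils.import_utils import is_flash_dmattn_available

requires_fdma = pytest.mark.skipif(
    not is_flash_dmattn_available(), reason="flash_dmattn kernels not available"
)


@requires_fdma
@pytest.mark.parametrize("q_len,kv_len", [(1, 128), (128, 128)])
def test_fdma_forward_shapes(q_len, kv_len):
    batch, heads, head_dim = 2, 4, 64
    device, dtype = "cuda", torch.float16
    q = torch.randn(batch, q_len, heads, head_dim, device=device, dtype=dtype)
    k = torch.randn(batch, kv_len, heads, head_dim, device=device, dtype=dtype)
    v = torch.randn_like(k)

    # Assumed transformers-style signature: (module, q, k, v, attention_mask, ...).
    out, weights = flash_dynamic_mask_attention_forward(None, q, k, v, attention_mask=None)
    assert out.shape == (batch, q_len, heads, head_dim)
    assert weights is None  # output_attentions=False path returns (attn_output, None)
```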
### Manual Run Commands (Windows PowerShell)
## Performance Impact
Expected improvements:
- Fewer redundant indexing passes via the fused `_upad_input`.

Planned benchmark script alignment: the existing `forward_performance.py` could be extended with a `--impl flash_dmattn` flag (a sketch follows at the end of this section).

Baseline expectation:
(Concrete numbers to be added once benchmarks are executed.)
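If the `--impl flash_dmattn` flag is added, the argument handling in `forward_performance.py` might look roughly like this; the choice list and defaults are illustrative, not the current script.

```python
import argparse


def build_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Attention forward benchmark")
    # Illustrative flag: pick which attention implementation to benchmark.
    parser.add_argument(
        "--impl",
        choices=("sdpa", "flash_attn", "flash_dmattn"),
        default="sdpa",
        help="Attention implementation to benchmark.",
    )
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--batch-size", type=int, default=1)
    return parser
```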
## Breaking Changes
None. The integration is opt-in via configuration (`_attn_implementation`) or by selecting the forward override; no existing public API signatures are altered.

## Checklist
CUDA-specific (indirect here; kernels assumed pre-existing):
## Additional Notes
- A future `allow_fdma_fallback` option could auto-fallback to SDPA when the environment is missing kernels.

## Usage Example
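The usage example block was lost in extraction; below is a minimal hedged stand-in, assuming the integration registers `"flash_dmattn"` with the standard `transformers` attention selection (the model id is an arbitrary placeholder).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any decoder model whose attention layers route through the adapter should work;
# "gpt2" is just a placeholder id.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.float16,
    attn_implementation="flash_dmattn",  # equivalent to config._attn_implementation
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Flash Dynamic Mask Attention says:", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```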
## Risk Assessment
- `_lazy_imports` fails to bind while the fallback path is not yet implemented; mitigated by explicit logging (a `logger.error` should be ensured in the final version).

## Follow-Up Tasks
- Add an environment check command (`python -m flash_dmattn.env_check` style), sketched below.
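For that check, a proposed (not existing) module could simply report the same signals the integration gates on:

```python
"""Sketch of a proposed flash_dmattn/env_check.py; not part of this PR."""
import importlib.util

import torch


def main() -> None:
    package_found = importlib.util.find_spec("flash_dmattn") is not None
    cuda_available = torch.cuda.is_available()
    print(f"flash_dmattn package found: {package_found}")
    print(f"CUDA available:             {cuda_available}")
    if cuda_available:
        print(f"GPU:                        {torch.cuda.get_device_name(0)}")


if __name__ == "__main__":
    main()
```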
## Screenshots / Benchmarks

(To be added after running `forward_performance.py` with representative configs.)