Consolidate mask and bias memory operations #129
Conversation
Replaces separate SmemLayoutAtomMask and SmemLayoutAtomBias with a single SmemLayoutAtomPS to reduce code duplication and improve memory layout consistency. Introduces a kPBlockN parameter with configurable block sizes (16, 32, or 64) and a dedicated swizzle pattern for better memory access patterns. Unifies the global memory copy operations for mask and bias into a single GmemTiledCopyMaskBias with 128-bit alignment and 8 values per read.
Consolidates separate mask and bias memory copy operations into a unified approach by replacing distinct layout types and copy operations with shared AtomPS layout and MaskBias copy operations. Reduces code duplication and improves memory access patterns by using the same layout configuration for both mask and bias tensors in shared memory operations.
Eliminates the kPBlockN constant and its static assertion since it was not being used effectively in the layout configuration. Simplifies the SmemLayoutAtomPS composition by directly using kBlockN instead of the intermediate kPBlockN parameter.
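For illustration, a minimal CuTe-style sketch of what the consolidated types could look like after this change, assuming 16-bit elements and flash-attention-style trait parameters; the wrapper struct, swizzle parameters, and thread counts are illustrative, not the repository's exact definitions:

```cpp
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

using namespace cute;

template <int kBlockM, int kBlockN, int kNThreads, typename Element = cutlass::half_t>
struct MaskBiasTraitsSketch {
    // One swizzled shared-memory atom used for both the mask and the bias tile,
    // composed directly on kBlockN (no intermediate kPBlockN parameter).
    using SmemLayoutAtomPS = decltype(composition(
        Swizzle<3, 3, 3>{},
        Layout<Shape<Int<8>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>{}));
    // Full tiled layout for a (kBlockM, kBlockN) mask/bias tile in shared memory.
    using SmemLayoutPS = decltype(tile_to_shape(
        SmemLayoutAtomPS{}, Shape<Int<kBlockM>, Int<kBlockN>>{}));

    // One global-memory tiled copy shared by mask and bias: 128-bit (16-byte)
    // transactions, i.e. 8 values per read for half/bfloat16 elements.
    static constexpr int kGmemElemsPerLoad = 16 / sizeof(Element);
    static constexpr int kGmemThreadsPerRow = kBlockN / kGmemElemsPerLoad;
    using GmemLayoutAtomPS = Layout<
        Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
        Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopyMaskBias = decltype(make_tiled_copy(
        Copy_Atom<UniversalCopy<cute::uint128_t>, Element>{},
        GmemLayoutAtomPS{},
        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // each thread copies 8 contiguous values
};
```

With kBlockN = 64 and half-precision elements this yields the 8-values-per-128-bit read mentioned above, and the mask and bias tiles reuse the same swizzled atom.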
Improves kernel selection logic by dynamically querying device shared memory limits instead of relying on hardcoded compute capability checks. Uses actual shared memory per block availability to determine optimal kernel configurations, enabling better performance across different GPU architectures without requiring architecture-specific branching logic. Simplifies the codebase by removing compute capability detection and associated conditional logic while maintaining performance optimization goals.
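As a rough sketch of that selection logic (the threshold and kernel-runner names below are placeholders, not the repository's actual values), the opt-in shared-memory limit can be queried through the CUDA runtime at launch time:

```cpp
#include <cuda_runtime.h>

// Choose a kernel configuration from the device's actual opt-in shared-memory limit
// rather than from its compute capability.
void run_mha_fwd_sketch(int device_id) {
    int max_smem_per_block = 0;
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id);

    constexpr int kSmemLargeTile = 112 * 1024;  // assumed requirement of the larger tile config
    if (max_smem_per_block >= kSmemLargeTile) {
        // Enough dynamic shared memory: launch the larger-tile kernel (placeholder).
        // run_flash_fwd<Kernel_traits_large>(params, stream);
    } else {
        // Fall back to a configuration that fits the available shared memory (placeholder).
        // run_flash_fwd<Kernel_traits_small>(params, stream);
    }
}
```

The same query works on any architecture, which is what removes the need for explicit compute-capability branching.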
Fixes dynamic mask preparation to properly handle invalid topk values by checking against minimum dtype values before scattering to attention mask. Expands benchmark test configurations to include comprehensive coverage across multiple head dimensions (32, 64, 96, 128, 256) and sequence lengths, providing more thorough validation of attention mechanisms. Re-enables previously disabled triton and flex attention test suites to ensure complete equivalence testing across all implementation variants.
Increases block size from 64 to 128 for head dimensions <= 32 to improve memory throughput and computational efficiency for smaller attention heads. The nested ternary operator now handles three cases (see the sketch below):
- Head dim <= 32: uses a 128 block size
- Head dim <= 64: uses a 64 block size
- Head dim >= 128: uses a 32 block size
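A minimal sketch of that selection (the function name is illustrative; in the launch template it is written as a nested ternary):

```cpp
// Block-size choice mirroring the three cases listed above.
constexpr int block_size_for_headdim(int headdim) {
    return headdim <= 32 ? 128   // small heads: larger blocks for better throughput
         : headdim <= 64 ? 64    // mid-size heads keep the previous 64 block size
         : 32;                   // larger head dimensions (128, 256) use 32
}
```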
Pull Request Overview
This PR consolidates mask and bias memory operations to unify their layouts and reduce code duplication. The changes improve memory access patterns through configurable block sizes and optimize kernel selection logic for better performance across GPU architectures.
- Consolidate mask and bias memory layouts into a unified "PS" (Pointwise-Sparse) layout system
- Replace architecture-based kernel selection with dynamic shared memory allocation checking
- Enhance test coverage with expanded configurations and re-enable previously disabled tests
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| csrc/src/kernel_traits.h | Unifies mask and bias layouts into SmemLayoutAtomPS, consolidates memory copy operations |
| csrc/src/flash_fwd_launch_template.h | Replaces compute capability checks with shared memory size checks for kernel selection |
| csrc/src/flash_fwd_kernel.h | Updates kernel code to use unified mask/bias copy operations and layout types |
| benchmarks/forward_equivalence.py | Expands test coverage with additional head dimensions and re-enables triton/flex tests |
Comments suppressed due to low confidence (4)
csrc/src/kernel_traits.h:221
- Using SmemLayoutAtomPS instead of SmemLayoutPS for mask tensor creation. This creates a layout mismatch, since SmemLayoutAtomPS is the atomic layout while SmemLayoutPS is the full tiled layout that should be used for tensor creation.

csrc/src/kernel_traits.h:225
- Using SmemLayoutAtomPS instead of SmemLayoutPS for bias tensor creation. This creates a layout mismatch, since SmemLayoutAtomPS is the atomic layout while SmemLayoutPS is the full tiled layout that should be used for tensor creation.
Context (test_configs):
(1, 2, 1, 4096, 4096, 96, False),
(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 128, 128, 128, True),
Copilot AI commented on Aug 26, 2025:
Duplicate test configuration. The same test case (1, 2, 1, 128, 128, 128, True) appears twice in the test_configs list, which is redundant and increases test execution time unnecessarily.
Suggested change: keep only one (1, 2, 1, 128, 128, 128, True), entry.
Context (test_configs, triton test function):
(1, 2, 1, 4096, 4096, 96, False),
(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 128, 128, 128, True),
Copilot AI commented on Aug 26, 2025:
Duplicate test configuration in triton test function. The same test case (1, 2, 1, 128, 128, 128, True) appears twice in the test_configs list, which is redundant and increases test execution time unnecessarily.
Suggested change: keep only one (1, 2, 1, 128, 128, 128, True), entry.
Context (test_configs, flex test function):
(1, 2, 1, 4096, 4096, 96, False),
(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 128, 128, 128, True),
Copilot AI commented on Aug 26, 2025:
Duplicate test configuration in flex test function. The same test case (1, 2, 1, 128, 128, 128, True) appears twice in the test_configs list, which is redundant and increases test execution time unnecessarily.
Suggested change: keep only one (1, 2, 1, 128, 128, 128, True), entry.
Context (test_configs, flex test function):
(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 128, 128, 128, False),
Copilot AI commented on Aug 26, 2025:
Duplicate test configuration in flex test function. The test case (1, 2, 1, 128, 128, 128, True) appears again, and there's also a duplicate (1, 2, 1, 128, 128, 128, False) that should be removed to avoid redundant testing.
Suggested change: comment out the duplicated entries:
# (1, 2, 1, 128, 128, 128, True), # Removed duplicate
# (1, 2, 1, 128, 128, 128, False), # Removed duplicate
Enhances the prepare_dynamic_mask function by capturing both values and indices from torch.topk operation, then filtering out invalid entries based on minimum dtype values. This prevents invalid indices from being included in the attention mask. Updates inline comments to standardize "INF" terminology and removes outdated debugging comments from test cases.
Improves the dynamic mask preparation by properly handling cases where topk values are invalid (equal to minimum dtype value). Previously, the mask would incorrectly include positions with invalid attention scores, potentially causing incorrect attention computations. Now validates topk values before setting mask positions, ensuring only valid attention scores are preserved in the final mask.
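A small libtorch sketch of that validity check (the repository's prepare_dynamic_mask is Python; the function name, float dtype, and mask shape here are illustrative assumptions):

```cpp
#include <torch/torch.h>
#include <limits>

// Build a keep-mask from the top-k attention scores, treating entries equal to the
// dtype minimum as invalid so they are never scattered into the mask.
torch::Tensor topk_keep_mask_sketch(const torch::Tensor& scores, int64_t keep_window) {
    const float min_val = std::numeric_limits<float>::lowest();   // "minimum dtype value" for float
    auto [topk_vals, topk_idx] = scores.topk(keep_window, /*dim=*/-1);
    auto valid = topk_vals.ne(min_val);                  // false where the score is only the filler value
    auto mask = torch::zeros_like(scores, torch::kBool); // start with everything masked out
    mask.scatter_(-1, topk_idx, valid);                  // invalid top-k positions stay false
    return mask;
}
```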
Unify mask and bias layouts and memory operations to reduce code duplication and improve memory access patterns. Introduce configurable block sizes and optimize kernel selection logic for better performance across GPU architectures. Enhance attention mask handling and expand test coverage for comprehensive validation. Adjust block sizes for improved efficiency with smaller head dimensions.