Conversation

@LoserCheems
Collaborator

@LoserCheems LoserCheems commented Aug 26, 2025

Fix: #121
Enhance masking logic for exact value matching, optimize memory layout for mask and bias tensors, and streamline shared memory operations to improve performance and readability. Address issues with sparsity detection and bias gradient pointer management for better computational efficiency. Cleanup includes removing unused code and adding type hints for clarity.

Updates mask condition logic to use exact equality comparison instead of less-than-or-equal, ensuring only zero values trigger masking behavior.

Removes large block of commented-out alternative implementation code to clean up the codebase and improve readability.
Changes comparison from > 0 to != 0.0f to properly detect all non-zero elements including negative values in sparse matrix operations.

Previously only positive values were considered active, which could lead to incorrect sparsity patterns when matrices contain negative elements.
Moves the bias gradient pointer advancement to the prologue section to ensure proper memory alignment and consistent pointer management throughout the computation loop.

Changes the clear out-of-bounds flag to true for improved memory safety when copying bias gradients to global memory.
Merges separate mask and bias tiled copy types into a unified implementation.

Increases alignment from 64 to 128 bits and vectorization from 4 to 8 values per operation, improving memory bandwidth utilization.
Unifies separate mask and bias global memory tiled copy objects into a single shared copy handler to reduce memory overhead and improve kernel efficiency.

Adds missing tensor partition for bias gradient computation to ensure proper memory layout handling during backward pass operations.
Removes redundant layout definitions and reuses existing layout structure to reduce code duplication and memory overhead.

Previously defined separate layout atoms and arrangements which duplicated the same configuration as the existing PdS layout.
Consolidates shared memory layout usage by replacing separate SmemLayoutMask and SmemLayoutBias with SmemLayoutPdS for both mask and bias tensors.

Removes redundant sdBias tensor and associated copy operations, streamlining memory management and reducing code duplication.

Reorganizes bias copying to occur after softcap application, improving computational flow and memory access patterns.
Eliminates kSmemdSSize variable and its usage in memory calculations to reduce shared memory footprint in backward kernel.

Comments out the unused variable definition and removes it from both kSmemSize and kSmemSize1colblock calculations, optimizing memory usage without affecting functionality.
Adjusts shared memory thresholds and kernel trait parameters across different head dimensions to improve performance on H100 and A100 GPUs.

Reduces memory requirements while maintaining or improving computational efficiency by fine-tuning block sizes, memory layout parameters, and GPU-specific optimizations.

Consolidates some GPU target categories where similar configurations work well across both H100 and A100 architectures.
Swaps memory allocation order between gradient and probability tensors to optimize memory usage pattern.

Changes sdS tensors to use bias data directly while moving sP tensor to use the previous sdS location, maintaining total memory footprint while improving data locality.
Removes commented-out debugging breakpoints that are no longer needed for development.

Adds type annotations to improve code clarity and enable better static analysis.

Fixes inconsistent return tuple lengths in backward functions to match expected parameter counts.
Contributor

Copilot AI left a comment


Pull Request Overview

This PR optimizes the FlashAttention backward kernel implementation by improving memory management and computational efficiency. The changes focus on consolidating shared memory layouts for mask and bias tensors, refining sparsity detection logic, and streamlining memory operations.

Key Changes:

  • Enhanced masking logic to use exact equality comparison (== 0.0f) instead of inequality (<= 0) for improved precision
  • Unified memory layout for mask and bias tensors to reduce shared memory usage and simplify memory management
  • Adjusted shared memory size calculations and kernel configurations to optimize for different GPU architectures

Reviewed Changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.

Summary per file:

flash_dmattn/flash_dmattn_interface.py: Added type hints and removed debug code, adjusted return tuple sizes
csrc/src/utils.h: Updated sparsity detection comparison from > 0 to != 0.0f
csrc/src/mask.h: Changed mask comparison from <= 0.0f to == 0.0f and removed commented code
csrc/src/kernel_traits.h: Unified mask and bias memory layouts, removed separate layouts and updated memory size calculations
csrc/src/flash_bwd_launch_template.h: Updated shared memory thresholds and kernel configurations for different GPU architectures
csrc/src/flash_bwd_kernel.h: Streamlined shared memory operations, unified copy operations for mask/bias, and optimized memory pointer management




- def maybe_contiguous(x):
+ def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:

Copilot AI Aug 26, 2025


The type hint references Optional, but the corresponding import is missing. Add from typing import Optional at the top of the file.

  for (int n = 0; n < size<2>(tCrM) && !local_any_active; ++n) {
      // Use direct comparison to avoid potential branching
-     local_any_active |= (tCrM(mma, m, n) > 0);
+     local_any_active |= (tCrM(mma, m, n) != 0.0f);

Copilot AI Aug 26, 2025


Using exact equality comparison with floating-point values can be problematic due to precision issues. Consider using a small epsilon for comparison: abs(tCrM(mma, m, n)) > epsilon where epsilon is a small threshold like 1e-9f.

Suggested change
- local_any_active |= (tCrM(mma, m, n) != 0.0f);
+ local_any_active |= (fabsf(tCrM(mma, m, n)) > FLASH_EPSILON);

Comment on lines 304 to +305
      // Use direct comparison to avoid potential branching
-     local_any_active |= (tCrM(mma, m, n) > 0);
+     local_any_active |= (tCrM(mma, m, n) != 0.0f);

Copilot AI Aug 26, 2025


Using exact equality comparison with floating-point values can be problematic due to precision issues. Consider using a small epsilon for comparison: abs(tCrM(mma, m, n)) > epsilon where epsilon is a small threshold like 1e-9f.

  auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
  // Apply scaling and bias or masking
- tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) <= 0.0f)
+ tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f)

Copilot AI Aug 26, 2025


Using exact equality comparison with floating-point values can be problematic due to precision issues. Consider using a small epsilon for comparison or ensuring the mask values are exactly 0.0f or 1.0f.

Suggested change
- tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f)
+ tensor(coord) = (col_idx >= col_idx_limit) || (fabsf(mask(coord)) < EPSILON)

  auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
  // Apply scaling and bias or masking
- tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) <= 0.0f)
+ tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f)

Copilot AI Aug 26, 2025


Using exact equality comparison with floating-point values can be problematic due to precision issues. Consider using a small epsilon for comparison or ensuring the mask values are exactly 0.0f or 1.0f.

Suggested change
- tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f)
+ tensor(coord) = (col_idx >= col_idx_limit) || (fabsf(mask(coord)) < EPSILON)

@LoserCheems LoserCheems merged commit 7f727ab into main Aug 26, 2025

Labels

bug Something isn't working


Development

Successfully merging this pull request may close these issues.

[BUG] NaN / Inf values appear only in dV during backward pass

6 participants