Conversation

@LoserCheems
Collaborator

Introduce options for accumulating bias gradients during backward attention, improving performance and configurability. Fix handling of broadcasted bias to ensure correct accumulation and numerical stability.

Introduces a toggle to optionally accumulate bias gradients during backward attention.
Enables skipping unnecessary dbias work when the bias gradient is not needed and gives kernels clearer control, aiding performance and configurability.
Adds an optional accumulation path for bias gradients that uses atomic updates when accumulation is enabled, avoiding overwrites when multiple tiles contribute to the same elements.

Keeps the existing fast write path when accumulation is disabled, respects sequence bounds, and correctly tracks the accumulation pointer across tile steps.

Improves correctness for split/streamed backward passes where bias gradients must be aggregated across blocks.
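
A minimal CUDA sketch of that control flow, assuming hypothetical names (store_dbias_tile, dbias_accum, dbias_out are illustrative, not the actual flash-dmattn symbols):

#include <cuda_fp16.h>

// Hypothetical helper: one element per thread, illustrating the two write paths.
__device__ void store_dbias_tile(
    float* dbias_accum,        // fp32 accumulator used when accum_dbias is set
    __half* dbias_out,         // regular dbias output used by the fast path
    const float* tile_dbias,   // this tile's dbias contributions
    int tile_size,
    bool accum_dbias
) {
    const int t = threadIdx.x;
    if (t >= tile_size) return;
    if (accum_dbias) {
        // Several tiles (e.g. every M-block of a broadcast bias row) may target
        // the same element, so add atomically instead of overwriting.
        atomicAdd(&dbias_accum[t], tile_dbias[t]);
    } else {
        // Exactly one tile owns each element: a plain store is enough.
        dbias_out[t] = __float2half(tile_dbias[t]);
    }
}

The real kernel operates on cute tensor tiles with explicit strides (see the hunk quoted in the review below), but the branch is the same: accumulate atomically when the flag is set, otherwise take the existing single-writer store.
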
Improves backward handling when bias is broadcast across sequence or batch by allocating correctly shaped scratch buffers and adjusting the reduction paths. Adds a kernel parameter to accumulate along the sequence dimension for seqlen_q=1 bias, and uses fp32 buffers for numerically stable accumulation.

Corrects the previous over-eager scratch allocation on batch-size mismatch to only trigger for shared (B=1) or head-grouped cases, aligning with broadcasting semantics (incl. MQA/GQA). Leaves the variable-length path unchanged (no accumulation).

Results in correct dbias reductions and gradients for broadcasted bias with better numerical stability.
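
On the host side, a hedged C++ sketch of the allocation and post-kernel handling; make_dbias_scratch and reduce_dbias_scratch are hypothetical helpers, and the exact reduction in flash_api.cpp may differ:

#include <torch/torch.h>

// Hypothetical helper: use an fp32 scratch buffer when bias is broadcast along
// seqlen_q, so many tiles can accumulate into one row without precision loss.
torch::Tensor make_dbias_scratch(
    int64_t batch_size, int64_t num_heads,
    int64_t seqlen_q, int64_t seqlen_k_rounded,
    int64_t seqlen_q_bias,
    const torch::TensorOptions& opts
) {
    return (seqlen_q_bias == 1)
        ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded},
                       opts.dtype(torch::kFloat))
        : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts);
}

// Hypothetical helper: after the backward kernel has accumulated, trim the
// rounded seqlen_k and cast back to the original bias dtype.
void reduce_dbias_scratch(torch::Tensor& dbias, const torch::Tensor& scratch,
                          int64_t seqlen_k) {
    dbias.copy_(scratch.slice(/*dim=*/3, /*start=*/0, /*end=*/seqlen_k)
                       .to(dbias.scalar_type()));
}
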
@LoserCheems
Collaborator Author

before: [cuda bwd] alloc_before=187.34 MB, alloc_after=189.31 MB, peak_alloc=448.38 MB, peak_reserved=662.00 MB

after: [cuda bwd] alloc_before=187.34 MB, alloc_after=189.31 MB, peak_alloc=320.43 MB, peak_reserved=662.00 MB

Copilot AI (Contributor) left a comment


Pull Request Overview

Introduce configurable accumulation for bias gradients in the backward attention path, improving correctness for broadcasted bias and enabling optional in-kernel accumulation.

  • Add accum_dbias flag to backward params/API (see the plumbing sketch after this list).
  • Implement optional atomic accumulation of dBias in the kernel for broadcasted-seqlen_q biases.
  • Adjust dbias buffer allocation and post-kernel reduction logic to handle broadcasted seqlen_q=1 with float accumulation.
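
A minimal sketch of the plumbing named in the first bullet. accum_dbias and dbias_row_stride appear in the diff; the other members shown are illustrative placeholders:

// Sketch only: the real Flash_bwd_params in csrc/flash_dmattn/src/flash.h
// carries many more members (pointers, strides, sizes).
struct Flash_bwd_params {
    void*   dbias_ptr;         // dbias output used by the fast write path (illustrative name)
    int64_t dbias_row_stride;  // row stride used when indexing dbias rows
    bool    accum_dbias;       // new flag: accumulate dbias across tiles instead of overwriting
};

// In mha_bwd, the flag mirrors the condition in the hunk quoted below:
// accumulate only when the bias is broadcast along seqlen_q.
// params.accum_dbias = has_bias && seqlen_q_bias != seqlen_q && seqlen_q_bias == 1;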

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File and description:
  • csrc/flash_dmattn/src/flash_bwd_kernel.h: Adds an optional atomic accumulation path for dBias and maintains pointer increments when traversing M blocks.
  • csrc/flash_dmattn/src/flash.h: Extends Flash_bwd_params with the accum_dbias flag.
  • csrc/flash_dmattn/flash_api.cpp: Plumbs accum_dbias through set_params_dgrad, updates the dbias allocation/reduction logic for broadcasted seqlen_q=1, and sets accum_dbias in mha_bwd.


Comment on lines 983 to 989
    ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1)  // MQA / GQA or dbias has different batch size or seqlen_q
        ? (seqlen_q_bias == 1)
            ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
            : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
        : dbias
    : torch::empty({0}, opts);
bool accum_dbias = has_bias && seqlen_q_bias != seqlen_q && seqlen_q_bias == 1;

Copilot AI Oct 16, 2025


The new condition only allocates an expanded buffer when batch_size_bias == 1 or seqlen_q_bias == 1 (broadcast), but drops the previous handling for general mismatches (batch_size_bias != batch_size or seqlen_q_bias != seqlen_q where the size is not 1). This can cause the kernel to write into a tensor with incompatible shape/strides when dbias dimensions mismatch but are not broadcastable (e.g., batch_size_bias=2 vs batch_size=4, seqlen_q_bias=64 vs seqlen_q=128). Restore the general mismatch allocation while keeping the float accumulation path for seqlen_q_bias == 1. For example, keep the outer mismatch check as (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q) and specialize the inner allocation to use float only when seqlen_q_bias == 1.
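
A hedged sketch of what the suggested allocation could look like; the lvalue name dbias_expanded is assumed, since the quoted hunk starts mid-expression:

at::Tensor dbias_expanded = has_bias
    ? (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)
        // Any mismatch (head-grouped, batch, or seqlen_q) gets a scratch buffer.
        ? (seqlen_q_bias == 1)
            // Only the seqlen_q broadcast case needs fp32 accumulation.
            ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
            : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
        : dbias
    : torch::empty({0}, opts);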

Comment on lines +863 to +881
#pragma unroll
for (int m = 0; m < size<1>(tBiassBias); ++m) {
    if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
        #pragma unroll
        for (int n = 0; n < size<2>(tBiassBias); ++n) {
            if (Is_even_MN || tBiaspBias(n)) {
                #pragma unroll
                for (int i = 0; i < size<0>(tBiassBias); ++i) {
                    const auto coord = tBiascBias(i, m, n);
                    const int row = get<0>(coord);
                    const int col = get<1>(coord);
                    if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
                        atomicAdd(
                            gdBias_accum_ptr + row * params.dbias_row_stride + col,
                            static_cast<ElementAccum>(tBiassBias(i, m, n))
                        );
                    }
                }
            }

Copilot AI Oct 16, 2025


[nitpick] This path performs an atomicAdd per element, which can cause significant contention when seqlen_q is large (many M-tiles accumulate into the same broadcasted bias row). Consider reducing within the threadblock first (e.g., per-(row,col) partial sums in shared memory or warp-level reductions) and issuing a single atomicAdd per (row,col) per block. This typically cuts the number of atomics by a factor of size<0>(tBiassBias) and improves throughput.
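
A standalone CUDA sketch of that pattern (illustrative only, not the flash-dmattn kernel; kCols, kValsPerThread, and the column mapping are assumptions): contributions are first reduced in shared memory, and each block then issues one global atomicAdd per column of the broadcast bias row.

#include <cuda_runtime.h>

constexpr int kCols = 128;         // assumed width of the broadcast dbias row
constexpr int kValsPerThread = 4;  // assumed contributions held by each thread

__global__ void accumulate_dbias_block_reduced(
    float* __restrict__ dbias_row,       // global fp32 accumulator, length kCols
    const float* __restrict__ partials,  // per-element partial contributions
    int num_partials
) {
    __shared__ float smem[kCols];

    // Zero the per-block accumulator.
    for (int c = threadIdx.x; c < kCols; c += blockDim.x) smem[c] = 0.f;
    __syncthreads();

    // Fold each thread's contributions into shared memory first; shared-memory
    // atomics are much cheaper than global ones under heavy contention.
    for (int v = 0; v < kValsPerThread; ++v) {
        const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * kValsPerThread + v;
        if (idx < num_partials) {
            atomicAdd(&smem[idx % kCols], partials[idx]);
        }
    }
    __syncthreads();

    // One global atomicAdd per column per block, instead of one per element.
    for (int c = threadIdx.x; c < kCols; c += blockDim.x) {
        atomicAdd(&dbias_row[c], smem[c]);
    }
}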

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@LoserCheems merged commit 78bb93d into main on Oct 16, 2025.
@LoserCheems deleted the fix-189 branch on October 27, 2025 at 08:56.