Enhance bias gradient accumulation in backward pass #193
Conversation
Introduces a toggle to optionally accumulate bias gradients during backward attention. This lets kernels skip unnecessary dbias work when the bias gradient is not needed and gives clearer control over the accumulation path, aiding performance and configurability.
Adds an optional accumulation path for bias gradients using atomic updates when accumulation is enabled, avoiding overwrites when multiple tiles contribute. Keeps the existing fast write path when accumulation is disabled, respects sequence bounds, and correctly tracks the accumulation pointer across tile steps. Improves correctness for split/streamed backward passes where bias gradients must be aggregated across blocks.
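For illustration, here is a minimal standalone CUDA sketch of the accumulate-vs-overwrite toggle described above; the kernel name, flat indexing, and launch shape are assumptions for this sketch, not the PR's kernel.

```cpp
// Illustrative only: shows the toggle between the fast overwrite path and
// atomic accumulation when multiple tiles contribute to the same dBias row.
#include <cuda_runtime.h>

__global__ void write_or_accum_dbias(
    float* __restrict__ dbias,          // dBias output buffer (or fp32 scratch)
    const float* __restrict__ partial,  // this tile's partial gradients
    int rows, int cols, int row_stride,
    bool accum_dbias) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= rows || col >= cols) return;  // respect sequence bounds

    float v = partial[row * cols + col];
    if (accum_dbias) {
        // Several tiles may target the same broadcasted bias row, so
        // accumulate atomically to avoid overwriting earlier contributions.
        atomicAdd(dbias + row * row_stride + col, v);
    } else {
        // Fast path: this tile owns its rows exclusively, so a plain store is safe.
        dbias[row * row_stride + col] = v;
    }
}
```

In the actual kernel the choice is driven by the new accum_dbias parameter, and the accumulation pointer is advanced as the backward pass steps across M-block tiles.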
Improves backward handling when bias is broadcast across sequence or batch by allocating correctly shaped scratch buffers and adjusting reduction paths. Adds a kernel parameter to accumulate along sequence for S=1 bias, and uses fp32 buffers for numerically stable accumulation. Corrects the previous over-eager scratch allocation on batch-size mismatch to only trigger for shared (B=1) or head-grouped cases, aligning with broadcasting semantics (incl. MQA/GQA). Leaves the variable-length path unchanged (no accumulation). Results in correct dbias reductions and gradients for broadcasted bias with better numerical stability.
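A rough host-side sketch of the kind of post-kernel reduction described above, assuming an fp32 scratch buffer shaped [B, H, 1, seqlen_k_rounded]; the function name, argument list, and exact broadcast handling are illustrative, not the PR's code.

```cpp
#include <torch/torch.h>

// Reduce the fp32 scratch back to the user-visible dbias shape: sum over any
// broadcast dimensions, drop the seqlen_k padding, and cast back to the bias dtype.
torch::Tensor reduce_dbias_scratch(
    const torch::Tensor& dbias_expanded,  // fp32, [B, H, 1, seqlen_k_rounded]
    int64_t batch_size_bias,              // 1 when the bias is shared across the batch
    int64_t num_heads_bias,               // fewer heads than H for MQA/GQA-style sharing
    int64_t seqlen_k,
    torch::ScalarType bias_dtype) {
    auto out = dbias_expanded;
    if (batch_size_bias == 1) {
        // Bias shared across the batch: fold the batch dimension.
        out = out.sum(/*dim=*/0, /*keepdim=*/true);
    }
    const int64_t num_heads = out.size(1);
    if (num_heads_bias != num_heads) {
        // Head-grouped bias (MQA/GQA): sum each group of query heads.
        out = out.view({out.size(0), num_heads_bias, num_heads / num_heads_bias,
                        out.size(2), out.size(3)})
                 .sum(/*dim=*/2);
    }
    using torch::indexing::Slice;
    return out.index({Slice(), Slice(), Slice(), Slice(0, seqlen_k)}).to(bias_dtype);
}
```

Accumulating in fp32 and casting back only after the reduction is what provides the numerical stability mentioned above.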
before: [cuda bwd] alloc_before=187.34 MB, alloc_after=189.31 MB, peak_alloc=448.38 MB, peak_reserved=662.00 MB
after: [cuda bwd] alloc_before=187.34 MB, alloc_after=189.31 MB, peak_alloc=320.43 MB, peak_reserved=662.00 MB
Pull Request Overview
Introduce configurable accumulation for bias gradients in the backward attention path, improving correctness for broadcasted bias and enabling optional in-kernel accumulation.
- Add accum_dbias flag to backward params/API.
- Implement optional atomic accumulation of dBias in the kernel for broadcasted-seqlen_q biases.
- Adjust dbias buffer allocation and post-kernel reduction logic to handle broadcasted seqlen_q=1 with float accumulation.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| csrc/flash_dmattn/src/flash_bwd_kernel.h | Adds optional atomic accumulation path for dBias and maintains pointer increments when traversing M blocks. |
| csrc/flash_dmattn/src/flash.h | Extends Flash_bwd_params with accum_dbias flag (see the sketch after this table). |
| csrc/flash_dmattn/flash_api.cpp | Plumbs accum_dbias through set_params_dgrad, updates dbias allocation/reduction logic for broadcasted seqlen_q=1, and sets accum_dbias in mha_bwd. |
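A minimal sketch of the new field, assuming the existing Flash_bwd_params struct; only accum_dbias (and dbias_row_stride, which appears in the kernel hunk below) come from this PR, the other members shown are placeholders.

```cpp
struct Flash_bwd_params /* : public Flash_fwd_params */ {
    // ... existing backward parameters (dq/dk/dv pointers, strides, ...) ...
    void* dbias_ptr;            // dBias output buffer, or fp32 scratch when accumulating
    int64_t dbias_row_stride;   // row stride used when indexing dBias
    bool accum_dbias;           // true: kernel atomically accumulates into the buffer;
                                // false: keep the existing fast overwrite path
};
```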
csrc/flash_dmattn/flash_api.cpp (outdated diff)
```cpp
    ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1)  // MQA / GQA or dbias has different batch size or seqlen_q
        ? (seqlen_q_bias == 1)
            ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
            : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
        : dbias
    : torch::empty({0}, opts);
bool accum_dbias = has_bias && seqlen_q_bias != seqlen_q && seqlen_q_bias == 1;
```
Copilot AI · Oct 16, 2025
The new condition only allocates an expanded buffer when batch_size_bias == 1 or seqlen_q_bias == 1 (broadcast), but drops the previous handling for general mismatches (batch_size_bias != batch_size or seqlen_q_bias != seqlen_q where the size is not 1). This can cause the kernel to write into a tensor with incompatible shape/strides when dbias dimensions mismatch but are not broadcastable (e.g., batch_size_bias=2 vs batch_size=4, seqlen_q_bias=64 vs seqlen_q=128). Restore the general mismatch allocation while keeping the float accumulation path for seqlen_q_bias == 1. For example, keep the outer mismatch check as (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q) and specialize the inner allocation to use float only when seqlen_q_bias == 1.
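For concreteness, one way the suggested fix could look, reusing the variable names from the diff above; the surrounding has_bias guard and the dbias_expanded name are reconstructed from context, so treat this as a sketch rather than the final patch.

```cpp
at::Tensor dbias_expanded = has_bias
    ? (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)
        // Any shape mismatch still gets a correctly sized scratch buffer; fp32
        // is used only for the seqlen_q_bias == 1 broadcast, which the kernel
        // accumulates along the sequence.
        ? (seqlen_q_bias == 1
               ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
               : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts))
        : dbias
    : torch::empty({0}, opts);
```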
```cpp
#pragma unroll
for (int m = 0; m < size<1>(tBiassBias); ++m) {
    if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
        #pragma unroll
        for (int n = 0; n < size<2>(tBiassBias); ++n) {
            if (Is_even_MN || tBiaspBias(n)) {
                #pragma unroll
                for (int i = 0; i < size<0>(tBiassBias); ++i) {
                    const auto coord = tBiascBias(i, m, n);
                    const int row = get<0>(coord);
                    const int col = get<1>(coord);
                    if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
                        atomicAdd(
                            gdBias_accum_ptr + row * params.dbias_row_stride + col,
                            static_cast<ElementAccum>(tBiassBias(i, m, n))
                        );
                    }
                }
            }
```
Copilot AI · Oct 16, 2025
[nitpick] This path performs an atomicAdd per element, which can cause significant contention when seqlen_q is large (many M-tiles accumulate into the same broadcasted bias row). Consider reducing within the threadblock first (e.g., per-(row,col) partial sums in shared memory or warp-level reductions) and issuing a single atomicAdd per (row,col) per block. This typically cuts the number of atomics by a factor of size<0>(tBiassBias) and improves throughput.
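An illustrative device-side sketch of the suggested reduction, assuming a full warp whose lanes all hold partials for the same (row, col); the real kernel's lane-to-coordinate mapping would decide whether a warp shuffle or a shared-memory reduction is the better fit.

```cpp
#include <cuda_runtime.h>

// Sum partials across the warp first, then issue one atomicAdd per warp
// instead of one per element.
__device__ void accum_dbias_warp_reduced(
    float* dbias_accum, int row, int col, int row_stride, float partial) {
    const unsigned full_mask = 0xffffffffu;  // assumes all 32 lanes are active
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        partial += __shfl_down_sync(full_mask, partial, offset);
    }
    if ((threadIdx.x & 31) == 0) {
        // Lane 0 now holds the warp-wide sum: a single atomic per (row, col).
        atomicAdd(dbias_accum + row * row_stride + col, partial);
    }
}
```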
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Introduce options for accumulating bias gradients during backward attention, improving performance and configurability. Fix handling of broadcasted bias to ensure correct accumulation and numerical stability.