[Example] Introduce split+sum template, and optimize atomic_add performance for bwd examples
#940
Conversation
cc @Rachmanino

👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Adds an atomic-add backward kernel and retains a split backward variant across multiple FlashAttention examples, exposing a new use_atomic flag wired through autograd, main, and the CLI. Refactors TileLang region helpers into utils with broader Buffer/BufferLoad/BufferRegion support and improves GEMM ICHECK diagnostics.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant CLI as CLI
    participant Main as main(...)
    participant AttnFwd as _attention.forward
    participant Autograd as _attention.backward
    participant Atomic as flashattn_bwd_atomic_add
    participant Split as flashattn_bwd_split
    CLI->>Main: parse args (--use_atomic / --use_split)
    Main->>AttnFwd: attention(Q,K,V, causal, groups, use_atomic)
    AttnFwd-->>Main: O (output) and ctx.use_atomic stored
    Note over Autograd: Backward invoked by autograd
    Autograd->>Autograd: if ctx.use_atomic == True
    Autograd-->>Atomic: call atomic-add kernel
    Atomic-->>Autograd: return dq, dk, dv
    Autograd->>Autograd: else
    Autograd-->>Split: call split kernel
    Split-->>Autograd: return dq, dk, dv
    Autograd-->>Main: gradients (dq, dk, dv)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
🧹 Nitpick comments (1)
tilelang/utils/language.py (1)
135-140: Use TypeError and format the message cleanly.

Switching to TypeError (and formatting the message instead of passing a tuple) aligns with the intent of signaling an invalid index type and satisfies TRY004/TRY003.

```diff
-else:
-    raise ValueError("Unsupported type: ", type(indice))
+else:
+    raise TypeError(f"Unsupported index type: {type(indice)}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- examples/flash_attention/example_gqa_bwd.py (9 hunks)
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (9 hunks)
- examples/flash_attention/example_mha_bwd.py (9 hunks)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (9 hunks)
- src/op/gemm.cc (2 hunks)
- tilelang/language/atomic.py (2 hunks)
- tilelang/language/copy.py (1 hunks)
- tilelang/language/customize.py (1 hunks)
- tilelang/language/utils.py (1 hunks)
- tilelang/utils/language.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (8)
tilelang/language/atomic.py (3)
- tilelang/language/utils.py (4)
  - buffer_to_tile_region (30-42)
  - buffer_region_to_tile_region (73-90)
  - buffer_load_to_tile_region (45-70)
  - region (8-27)
- tilelang/utils/language.py (1)
  - get_buffer_region_from_load (124-141)
- tilelang/language/frame.py (2)
  - has_let_value (189-198)
  - get_let_value (201-210)

tilelang/language/utils.py (1)
- tilelang/language/tir/op.py (1)
  - call_intrin (119-144)

examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (1)
- examples/flash_attention/example_gqa_bwd.py (7)
  - flashattn_bwd_atomic_add (150-245)
  - flash_bwd (171-243)
  - flash_bwd (274-349)
  - make_dq_layout (116-119)
  - flashattn_bwd_split (251-351)
  - forward (358-368)
  - main (466-524)

examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1)
- examples/flash_attention/example_gqa_bwd.py (6)
  - flashattn_bwd_atomic_add (150-245)
  - flash_bwd (171-243)
  - flash_bwd (274-349)
  - make_dq_layout (116-119)
  - forward (358-368)
  - flashattn_fwd (12-80)

examples/flash_attention/example_mha_bwd.py (4)
- examples/flash_attention/example_gqa_bwd.py (5)
  - flashattn_bwd_atomic_add (150-245)
  - flash_bwd (171-243)
  - flash_bwd (274-349)
  - make_dq_layout (116-119)
  - flashattn_bwd_split (251-351)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (5)
  - flashattn_bwd_atomic_add (149-249)
  - flash_bwd (165-247)
  - flash_bwd (271-356)
  - make_dq_layout (115-118)
  - flashattn_bwd_split (255-358)
- tilelang/language/kernel.py (2)
  - threads (195-199)
  - Kernel (209-258)
- tilelang/language/atomic.py (1)
  - atomic_add (116-228)

tilelang/language/copy.py (1)
- tilelang/language/utils.py (3)
  - buffer_to_tile_region (30-42)
  - buffer_region_to_tile_region (73-90)
  - buffer_load_to_tile_region (45-70)

tilelang/utils/language.py (1)
- tilelang/language/ast/ir.py (1)
  - Range (1716-1728)

examples/flash_attention/example_gqa_bwd.py (1)
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (9)
  - flashattn_bwd_atomic_add (150-256)
  - flash_bwd (171-254)
  - flash_bwd (285-371)
  - make_dq_layout (116-119)
  - flashattn_bwd_split (262-373)
  - forward (380-390)
  - flashattn_bwd_preprocess (87-113)
  - flashattn_bwd_postprocess (126-144)
  - main (488-546)
🪛 Ruff (0.13.3)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
515-515: Ambiguous variable name: O
(E741)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
452-452: Ambiguous variable name: O
(E741)
examples/flash_attention/example_mha_bwd.py
431-431: Ambiguous variable name: O
(E741)
tilelang/utils/language.py
140-140: Prefer TypeError exception for invalid type
(TRY004)
140-140: Avoid specifying long messages outside the exception class
(TRY003)
examples/flash_attention/example_gqa_bwd.py
493-493: Ambiguous variable name: O
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: format-check
🔇 Additional comments (2)
src/op/gemm.cc (2)
288-290: LGTM! Enhanced diagnostic output.

The error message now includes the actual values of m_warp, n_warp, and num_warps, which will significantly aid debugging when this assertion fails.

374-377: LGTM! Additional validation strengthens robustness.

This new assertion provides an extra safety check for the non-WGMMA code path, complementing the similar check in the WGMMA branch. The enhanced error message will help diagnose issues if the warp partition computation logic produces invalid results.
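The pattern being praised here, embedding the offending operand values in the check's failure message, can be sketched in Python (a hypothetical helper for illustration; the actual change is a C++ ICHECK in src/op/gemm.cc):

```python
def check_warp_partition(m_warp: int, n_warp: int, num_warps: int) -> None:
    """Validate a warp partition, reporting the actual values on failure."""
    if m_warp * n_warp != num_warps:
        # Embedding the operands makes a bad configuration self-explanatory.
        raise AssertionError(
            f"warp partition invalid: m_warp={m_warp} * n_warp={n_warp} "
            f"!= num_warps={num_warps}")

check_warp_partition(2, 4, 8)  # valid partition: passes silently
```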
```python
for i, j in T.Parallel(block_M, dim_qk):
    T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
```
Guard dK atomics on tail blocks
When seq_len isn’t an exact multiple of block_M, the tail block will execute this loop with by * block_M + i >= seq_len, so we end up atomically writing past the end of dK. Please gate the atomic add the same way we already guard the dQ updates.
```diff
-for i, j in T.Parallel(block_M, dim_qk):
-    T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
+for i, j in T.Parallel(block_M, dim_qk):
+    if by * block_M + i < seq_len:
+        T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
for i, j in T.Parallel(block_M, dim_qk):
    if by * block_M + i < seq_len:
        T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
```
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around lines 253
to 255, the atomic_add to dK can write past the end when seq_len is not a
multiple of block_M; wrap the T.atomic_add with the same guard used for dQ
(i.e., check that by * block_M + i < seq_len before performing the atomic_add)
so tail-block iterations that exceed seq_len are skipped.
```python
        T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)

    return flash_bwd
```
Bounds-check dK atomic adds
For non-multiple seq_len, by * block_M + i can run past the valid range and we’ll issue atomics outside dK. Please mirror the guard used for dQ so the tail block lands safely.
```diff
-for i, j in T.Parallel(block_M, dim_qk):
-    T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
+for i, j in T.Parallel(block_M, dim_qk):
+    if by * block_M + i < seq_len:
+        T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd.py around lines 243 to 245, the
T.atomic_add into dK can write past the buffer when seq_len is not a multiple of
block_M; mirror the guard used for dQ by bounding the by*block_M:(by+1)*block_M
range so it never exceeds seq_len. Specifically, compute the valid tail length
(e.g., end = min((by+1)*block_M, seq_len)) or apply a per-element
mask/conditional so atomic_add only accumulates into dK[bz, by*block_M:end,
bx//groups, :] (or only perform atomic_add for indices i where by*block_M + i <
seq_len), ensuring the tail block is safely truncated before performing the
atomic adds.
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tilelang/language/utils.py (1 hunks)
- tilelang/utils/language.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/utils.py (1)
tilelang/language/tir/op.py (1)
call_intrin(119-144)
🪛 Ruff (0.13.3)
tilelang/utils/language.py
142-142: Prefer TypeError exception for invalid type
(TRY004)
142-142: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: format-check
🔇 Additional comments (4)
tilelang/language/utils.py (4)
1-6: LGTM! The imports are appropriate and all imported types are used in the new region utility functions.

8-27: LGTM! The function correctly maps access types to numeric codes and constructs the region intrinsic call. The docstring clearly documents the behavior, including the KeyError for invalid access types.

30-42: LGTM! The function correctly converts a Buffer to a tile region descriptor by creating a region that spans the entire buffer.

72-89: LGTM! The function correctly converts a BufferRegion to a tile region descriptor by extracting mins and extents, with appropriate validation.
```python
# (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
# f"region will be expanded in the last 2 dimensions")
```
🛠️ Refactor suggestion | 🟠 Major
Remove commented-out code.
The commented-out code should either be removed or converted to proper logging if the information is needed for debugging.
Apply this diff:
```diff
-if len(indices) > len(extents):
-    # (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
-    # f"region will be expanded in the last 2 dimensions")
-    new_extents = []
+if len(indices) > len(extents):
+    new_extents = []
```

📝 Committable suggestion
```python
if len(indices) > len(extents):
    new_extents = []
```
🤖 Prompt for AI Agents
In tilelang/language/utils.py around lines 59-60, there is commented-out code
that should be removed or turned into proper logging; delete the two commented
lines or replace them with a call to the project's logger (e.g., logger.debug or
logger.warning) that formats the same message with variables (indices, extents,
load) so runtime diagnostics are preserved, ensuring you import/use the module
logger consistently with the codebase.
```python
for extent in extents:
    new_extents.append(extent)
extents = new_extents
print("after extents", extents)
```
🛠️ Refactor suggestion | 🟠 Major
Remove debug print statement.
Debug print statements should not be left in production code. Remove this line or convert it to proper logging if the information is useful for debugging.
Apply this diff:
```diff
-print("after extents", extents)
 assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
```

📝 Committable suggestion
```python
assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
```
🤖 Prompt for AI Agents
In tilelang/language/utils.py around line 67, there's a leftover debug print
call 'print("after extents", extents)'; remove this debug print or replace it
with a proper logger call (e.g., use the module logger.debug with a clear
message and the extents variable) depending on whether the info is needed for
runtime debugging, ensuring imports/logger are present if converting to logging.
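Both this comment and the previous one ask for the same fix: route diagnostics through a module-level logger instead of bare prints or commented-out strings. A hedged sketch of that pattern follows; the padding semantics of the helper are assumed for illustration, not taken from the actual code:

```python
import logging

logger = logging.getLogger(__name__)  # module-level logger, off by default

def expand_extents(indices, extents):
    """Hypothetical helper: pad extents with leading 1s when a buffer
    load carries more indices than the region has extents."""
    if len(indices) > len(extents):
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug("mismatch for buffer load: indices=%s extents=%s",
                     indices, extents)
        extents = [1] * (len(indices) - len(extents)) + list(extents)
    assert len(indices) == len(extents), \
        f"indices = {indices}, extents = {extents}"
    return extents

print(expand_extents([0, 0, 0], [4, 8]))  # [1, 4, 8]
```

Unlike a print, the debug record can be switched on per module (e.g. `logging.basicConfig(level=logging.DEBUG)`) without touching the code.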
```diff
 found_ramp: bool = False
 for indice in indices:
-    if not isinstance(indice, tir.Ramp):
-        return None
-    regions.append(ir.Range.from_min_extent(indice.base, indice.lanes))
-return tir.BufferRegion(buffer, regions)
+    if isinstance(indice, tir.Ramp):
+        regions.append(ir.Range.from_min_extent(indice.base, indice.lanes))
+        found_ramp = True
+    elif isinstance(indice, tir.PrimExpr):
+        regions.append(ir.Range.from_min_extent(indice, 1))
+    else:
+        raise ValueError("Unsupported type: ", type(indice))
+if found_ramp:
+    return tir.BufferRegion(buffer, regions)
+else:
+    return None
```
🛠️ Refactor suggestion | 🟠 Major
Apply static analysis recommendations for error handling.
The logic enhancement to handle PrimExpr indices is correct. However, static analysis correctly identifies two improvements:
- Use TypeError instead of ValueError for invalid type errors
- Define exception messages outside the raise statement for better maintainability
Apply this diff:
```diff
 buffer, indices = buffer_load.buffer, buffer_load.indices
 regions = []
 found_ramp: bool = False
 for indice in indices:
     if isinstance(indice, tir.Ramp):
         regions.append(ir.Range.from_min_extent(indice.base, indice.lanes))
         found_ramp = True
     elif isinstance(indice, tir.PrimExpr):
         regions.append(ir.Range.from_min_extent(indice, 1))
     else:
-        raise ValueError("Unsupported type: ", type(indice))
+        raise TypeError(f"Unsupported index type: {type(indice)}")
 if found_ramp:
     return tir.BufferRegion(buffer, regions)
 else:
     return None
```

📝 Committable suggestion
```python
buffer, indices = buffer_load.buffer, buffer_load.indices
regions = []
found_ramp: bool = False
for indice in indices:
    if isinstance(indice, tir.Ramp):
        regions.append(ir.Range.from_min_extent(indice.base, indice.lanes))
        found_ramp = True
    elif isinstance(indice, tir.PrimExpr):
        regions.append(ir.Range.from_min_extent(indice, 1))
    else:
        raise TypeError(f"Unsupported index type: {type(indice)}")
if found_ramp:
    return tir.BufferRegion(buffer, regions)
else:
    return None
```
🧰 Tools
🪛 Ruff (0.13.3)
142-142: Prefer TypeError exception for invalid type
(TRY004)
142-142: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In tilelang/utils/language.py around lines 134 to 146, replace the current raise
ValueError(...) with a TypeError and move the exception message into a separate
variable before raising to improve maintainability; construct a clear message
string (e.g., f"Unsupported index type: {type(indice)}") assigned to a variable
and then raise TypeError(message) so the code uses the correct exception type
and keeps the message definied separately from the raise expression.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- examples/warp_specialize/example_warp_specialize_flashmla.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/warp_specialize/example_warp_specialize_flashmla.py (1)
examples/deepseek_mla/example_mla_decode.py (1)
- main (283-305)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: format-check
```diff
-def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
+def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
```
🧩 Analysis chain
Inconsistency: CLI defaults don't match main() defaults.
The main() function signature now defaults to heads=64 and kv_ctx=1024, but the CLI argument parser (lines 405, 407) still defaults to heads=128 and kv_ctx=8192. This creates unpredictable behavior depending on whether the script is invoked from the command line or called programmatically.
Apply this diff to align the CLI defaults with the updated main() signature:
```diff
-parser.add_argument('--heads', type=int, default=128, help='q heads number')
+parser.add_argument('--heads', type=int, default=64, help='q heads number')
-parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
+parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length')
```

Additionally, verify that similar files in this PR (e.g., example_mla_decode.py) have consistent defaults:
🏁 Script executed:
```shell
#!/bin/bash
# Description: Check for inconsistent defaults across flash attention examples
# Search for main function signatures and CLI parser defaults
echo "=== main() function signatures ==="
rg -n "def main\(.*heads.*kv_ctx" examples/
echo -e "\n=== CLI parser defaults for heads ==="
rg -n "add_argument.*--heads.*default=" examples/
echo -e "\n=== CLI parser defaults for kv_ctx ==="
rg -n "add_argument.*--kv_ctx.*default=" examples/
```

Length of output: 5628
Align CLI argument defaults with updated main() signature.
The parser defaults still use the old values (128, 8192), overriding the function signature (64, 1024). Update them as follows:
examples/warp_specialize/example_warp_specialize_flashmla.py
```diff
-parser.add_argument('--heads', type=int, default=128, help='q heads number')
+parser.add_argument('--heads', type=int, default=64, help='q heads number')
-parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
+parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length')
```

📝 Committable suggestion
```python
def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
    ...

parser.add_argument('--heads', type=int, default=64, help='q heads number')
parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length')
```
🤖 Prompt for AI Agents
In examples/warp_specialize/example_warp_specialize_flashmla.py around line 385,
the CLI parser still defaults --heads and --kv-ctx to the old values (128 and
8192) which override the new main() signature; update the parser defaults to
match the function signature by setting heads default to 64 and kv_ctx default
to 1024 (also ensure any related defaults like kv_heads, dim, pe_dim, batch
match the main() defaults if present).
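Beyond patching the two defaults, drift like this can be prevented structurally. One illustrative pattern (an assumption about how one might refactor, not what the example currently does) derives the parser defaults from the main() signature so there is a single source of truth:

```python
import argparse
import inspect

def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
    # Placeholder body; the real example launches the FlashMLA kernel.
    return dict(batch=batch, heads=heads, kv_heads=kv_heads,
                kv_ctx=kv_ctx, dim=dim, pe_dim=pe_dim)

parser = argparse.ArgumentParser()
# Build one --flag per parameter, reusing the signature's default value.
for name, param in inspect.signature(main).parameters.items():
    parser.add_argument(f"--{name}", type=int, default=param.default)

args = parser.parse_args([])   # no CLI overrides: signature defaults win
print(args.heads, args.kv_ctx)  # 64 1024
```

With this pattern, changing the defaults in main() automatically changes the CLI defaults, so the mismatch flagged above cannot recur.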
As the title says: in the previous implementation, the example used T.atomic_add with the bfloat16 dtype, which is very slow on CUDA due to its emulated implementation. We improved this by switching to float32 atomics and introducing a split+sum template, addressing issue #917.

Summary by CodeRabbit
New Features
Refactor
Chores
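The split+sum template mentioned in the description can be modeled in plain Python: each split accumulates into its own private partial (so no contended, slow bfloat16 atomics are needed), and a separate reduction pass sums the partials. This is a conceptual sketch, not the TileLang kernel:

```python
def split_sum(contributions, num_splits):
    """Accumulate contributions in num_splits independent partials,
    then reduce the partials in a second pass."""
    # Phase 1: each split accumulates its share into a private partial
    # (stands in for per-split float32 accumulation, no cross-split
    # contention).
    partials = [0.0] * num_splits
    for s in range(num_splits):
        for v in contributions[s::num_splits]:
            partials[s] += v
    # Phase 2: one reduction over the partials (the "sum" kernel).
    return sum(partials)

vals = [0.5] * 16
print(split_sum(vals, 4))  # 8.0
```

The trade-off is extra memory for the partials in exchange for replacing serialized atomic updates with independent accumulation plus one cheap reduction.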