
Conversation

@LeiWang1999 (Member) commented Oct 5, 2025

As the title says:
In the previous implementation, the example used T.atomic_add with the bfloat16 dtype, which is very slow on CUDA because bfloat16 atomics fall back to a simulated implementation. We improved this by switching to float32 atomics and introducing a split-sum template, addressing issue #917.
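The split-sum idea can be sketched outside TileLang in plain Python (the names below are illustrative, not the kernel's API): each split accumulates into its own private float32-style partial buffer, and one reduction at the end replaces many per-element atomic adds.

```python
def split_sum_accumulate(contribs, num_splits):
    """Sum the rows of `contribs` without atomics: each split owns a
    private partial buffer, and a single deterministic reduction over
    the partials replaces per-element atomic adds."""
    n = len(contribs[0])
    chunk = (len(contribs) + num_splits - 1) // num_splits
    partials = [[0.0] * n for _ in range(num_splits)]
    for s in range(num_splits):
        for row in contribs[s * chunk:(s + 1) * chunk]:
            for j in range(n):
                partials[s][j] += row[j]  # private buffer: no contention
    # Split-sum: one final reduction pass over the partial buffers.
    return [sum(p[j] for p in partials) for j in range(n)]
```

On a GPU each "split" would be a thread block and the final reduction a separate kernel or epilogue, but the accumulation structure is the same.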

▶ python examples/flash_attention/example_gqa_bwd.py --use_atomic False                  
All checks passed.✅
torch: 2.62 ms
torch: 170.29 TFlops
tilelang: 2.15 ms
tilelang: 208.09 TFlops

▶ python examples/flash_attention/example_gqa_bwd.py --use_split       
2025-10-05 18:15:33  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flash_bwd` with `out_idx=None`
2025-10-05 18:15:57  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flash_bwd`
All checks passed.✅
torch: 2.62 ms
torch: 170.57 TFlops
tilelang: 2.06 ms
tilelang: 216.78 TFlops

Summary by CodeRabbit

  • New Features

    • Added an atomic-add backward kernel for attention plus a runtime toggle (use_atomic) to choose atomic or split paths; atomic is default. CLI flags --use_atomic / --use_split added.
  • Refactor

    • Unified tile-region utilities; copy() accepts BufferRegion; removed legacy region helper wrappers and consolidated region handling across Buffer/BufferLoad/BufferRegion.
    • Broader index support when deriving buffer regions from loads; stricter error reporting on unsupported index types.
  • Chores

    • Improved GEMM diagnostic messages.
    • Adjusted example defaults for heads and kv_ctx.

@LeiWang1999 (Member, Author) commented:

cc @Rachmanino

github-actions bot commented Oct 5, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

coderabbitai bot commented Oct 5, 2025

Walkthrough

Adds an atomic-add backward kernel and retains a split backward variant across multiple FlashAttention examples, exposing a new use_atomic flag wired through autograd, main, and CLI. Refactors TileLang region helpers into utils with broader Buffer/BufferLoad/BufferRegion support and improves GEMM ICHECK diagnostics.

Changes

Cohort / File(s) Summary of changes
FlashAttention examples: atomic + split backward paths & CLI
examples/flash_attention/example_gqa_bwd.py, examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py, examples/flash_attention/example_mha_bwd.py, examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
Added flashattn_bwd_atomic_add(...) (atomic-add backward kernel) and retained/adjusted split backward variants; extended _attention.forward(..., use_atomic=True) to store ctx.use_atomic and select kernels in backward; propagated use_atomic through main, attention calls, and CLI via --use_atomic/--use_split; adjusted thread/block defaults, dtype conversions, and post-processing.
TileLang region helpers refactor
tilelang/language/atomic.py, tilelang/language/copy.py, tilelang/language/customize.py, tilelang/language/utils.py, tilelang/utils/language.py
Centralized region/convert helpers in tilelang/language/utils.py; updated imports and type usage to tvm.tir types; atomic.py and copy.py now use shared utilities and support Buffer/BufferRegion/BufferLoad via tile-region conversions; removed old helper functions from customize.py; get_buffer_region_from_load now accepts PrimExpr and raises on unsupported index types.
GEMM diagnostics
src/op/gemm.cc
Enriched ICHECK messages in GemmWarpPolicyNode::ComputeWarpPartition to include m_warp, n_warp, and num_warps, and added an extra post-branch ICHECK to assert m_warp * n_warp == num_warps.
Example defaults
examples/warp_specialize/example_warp_specialize_flashmla.py
Lowered default heads from 128→64 and kv_ctx from 8192→1024 in main signature.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant CLI as CLI
  participant Main as main(...)
  participant AttnFwd as _attention.forward
  participant Autograd as _attention.backward
  participant Atomic as flashattn_bwd_atomic_add
  participant Split as flashattn_bwd_split

  CLI->>Main: parse args (--use_atomic / --use_split)
  Main->>AttnFwd: attention(Q,K,V, causal, groups, use_atomic)
  AttnFwd-->>Main: O (output) and ctx.use_atomic stored

  Note over Autograd: Backward invoked by autograd
  Autograd->>Autograd: if ctx.use_atomic == True
  Autograd-->>Atomic: call atomic-add kernel
  Atomic-->>Autograd: return dq, dk, dv
  Autograd->>Autograd: else
  Autograd-->>Split: call split kernel
  Split-->>Autograd: return dq, dk, dv
  Autograd-->>Main: gradients (dq, dk, dv)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • tzj-fxz

Poem

I twitch my ears at flags that flip,
Atomic hops or split a skip.
Regions moved and checks made clear,
Warps announce their numbers near.
Gradients bound, I race back track—🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title clearly summarizes the key changes by indicating the introduction of a split+sum template and the optimization of atomic_add performance in the backward examples, which directly reflects the primary focus of this pull request.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
tilelang/utils/language.py (1)

135-140: Use TypeError and format the message cleanly.

Switching to TypeError (and formatting the message instead of passing a tuple) aligns with the intent of signaling an invalid index type and satisfies TRY004/TRY003.

-        else:
-            raise ValueError("Unsupported type: ", type(indice))
+        else:
+            raise TypeError(f"Unsupported index type: {type(indice)}")
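The recommended pattern can be exercised standalone; in this sketch `int` and `range` are hypothetical stand-ins for `tir.PrimExpr` and `tir.Ramp`, not the real TIR classes:

```python
def index_to_range(indice):
    """Map an index to a (min, extent) pair. Raise TypeError with a
    formatted message for an unsupported *type* (satisfying the
    TRY004/TRY003 lints), rather than ValueError with a tuple."""
    if isinstance(indice, range):   # vectorized index (tir.Ramp analogue)
        return (indice.start, len(indice))
    if isinstance(indice, int):     # scalar index (tir.PrimExpr analogue)
        return (indice, 1)
    raise TypeError(f"Unsupported index type: {type(indice)}")
```

`raise ValueError("Unsupported type: ", type(indice))` would stringify as a tuple and use the wrong exception class for a type error; the formatted TypeError reads cleanly in tracebacks.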
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b31de0c and 9a4a359.

📒 Files selected for processing (10)
  • examples/flash_attention/example_gqa_bwd.py (9 hunks)
  • examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (9 hunks)
  • examples/flash_attention/example_mha_bwd.py (9 hunks)
  • examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (9 hunks)
  • src/op/gemm.cc (2 hunks)
  • tilelang/language/atomic.py (2 hunks)
  • tilelang/language/copy.py (1 hunks)
  • tilelang/language/customize.py (1 hunks)
  • tilelang/language/utils.py (1 hunks)
  • tilelang/utils/language.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (8)
tilelang/language/atomic.py (3)
tilelang/language/utils.py (4)
  • buffer_to_tile_region (30-42)
  • buffer_region_to_tile_region (73-90)
  • buffer_load_to_tile_region (45-70)
  • region (8-27)
tilelang/utils/language.py (1)
  • get_buffer_region_from_load (124-141)
tilelang/language/frame.py (2)
  • has_let_value (189-198)
  • get_let_value (201-210)
tilelang/language/utils.py (1)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (1)
examples/flash_attention/example_gqa_bwd.py (7)
  • flashattn_bwd_atomic_add (150-245)
  • flash_bwd (171-243)
  • flash_bwd (274-349)
  • make_dq_layout (116-119)
  • flashattn_bwd_split (251-351)
  • forward (358-368)
  • main (466-524)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1)
examples/flash_attention/example_gqa_bwd.py (6)
  • flashattn_bwd_atomic_add (150-245)
  • flash_bwd (171-243)
  • flash_bwd (274-349)
  • make_dq_layout (116-119)
  • forward (358-368)
  • flashattn_fwd (12-80)
examples/flash_attention/example_mha_bwd.py (4)
examples/flash_attention/example_gqa_bwd.py (5)
  • flashattn_bwd_atomic_add (150-245)
  • flash_bwd (171-243)
  • flash_bwd (274-349)
  • make_dq_layout (116-119)
  • flashattn_bwd_split (251-351)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (5)
  • flashattn_bwd_atomic_add (149-249)
  • flash_bwd (165-247)
  • flash_bwd (271-356)
  • make_dq_layout (115-118)
  • flashattn_bwd_split (255-358)
tilelang/language/kernel.py (2)
  • threads (195-199)
  • Kernel (209-258)
tilelang/language/atomic.py (1)
  • atomic_add (116-228)
tilelang/language/copy.py (1)
tilelang/language/utils.py (3)
  • buffer_to_tile_region (30-42)
  • buffer_region_to_tile_region (73-90)
  • buffer_load_to_tile_region (45-70)
tilelang/utils/language.py (1)
tilelang/language/ast/ir.py (1)
  • Range (1716-1728)
examples/flash_attention/example_gqa_bwd.py (1)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (9)
  • flashattn_bwd_atomic_add (150-256)
  • flash_bwd (171-254)
  • flash_bwd (285-371)
  • make_dq_layout (116-119)
  • flashattn_bwd_split (262-373)
  • forward (380-390)
  • flashattn_bwd_preprocess (87-113)
  • flashattn_bwd_postprocess (126-144)
  • main (488-546)
🪛 Ruff (0.13.3)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py

515-515: Ambiguous variable name: O

(E741)

examples/flash_attention/example_mha_bwd_wgmma_pipelined.py

452-452: Ambiguous variable name: O

(E741)

examples/flash_attention/example_mha_bwd.py

431-431: Ambiguous variable name: O

(E741)

tilelang/utils/language.py

140-140: Prefer TypeError exception for invalid type

(TRY004)


140-140: Avoid specifying long messages outside the exception class

(TRY003)

examples/flash_attention/example_gqa_bwd.py

493-493: Ambiguous variable name: O

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check
🔇 Additional comments (2)
src/op/gemm.cc (2)

288-290: LGTM! Enhanced diagnostic output.

The error message now includes the actual values of m_warp, n_warp, and num_warps, which will significantly aid debugging when this assertion fails.


374-377: LGTM! Additional validation strengthens robustness.

This new assertion provides an extra safety check for the non-WGMMA code path, complementing the similar check in the WGMMA branch. The enhanced error message will help diagnose issues if the warp partition computation logic produces invalid results.
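To illustrate why reporting all operands matters, here is a toy, hypothetical partitioner in Python (not the algorithm in gemm.cc): on failure, the assertion message names m_warp, n_warp, and num_warps, in the spirit of the enriched ICHECK.

```python
def compute_warp_partition(num_warps: int):
    """Toy split of num_warps into (m_warp, n_warp) by repeated halving.
    Hypothetical sketch only; the real logic lives in
    GemmWarpPolicyNode::ComputeWarpPartition in src/op/gemm.cc."""
    m_warp, n_warp = num_warps, 1
    while m_warp % 2 == 0 and m_warp > n_warp * 2:
        m_warp //= 2
        n_warp *= 2
    # Post-branch invariant with every operand in the message, so a
    # failure is immediately diagnosable (mirrors the enriched ICHECK).
    assert m_warp * n_warp == num_warps, (
        f"warp partition failed: m_warp={m_warp}, n_warp={n_warp}, "
        f"num_warps={num_warps}")
    return m_warp, n_warp
```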

Comment on lines +253 to +255
        for i, j in T.Parallel(block_M, dim_qk):
            T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])


⚠️ Potential issue | 🔴 Critical

Guard dK atomics on tail blocks

When seq_len isn’t an exact multiple of block_M, the tail block will execute this loop with by * block_M + i >= seq_len, so we end up atomically writing past the end of dK. Please gate the atomic add the same way we already guard the dQ updates.

-        for i, j in T.Parallel(block_M, dim_qk):
-            T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
+        for i, j in T.Parallel(block_M, dim_qk):
+            if by * block_M + i < seq_len:
+                T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py around lines 253
to 255, the atomic_add to dK can write past the end when seq_len is not a
multiple of block_M; wrap the T.atomic_add with the same guard used for dQ
(i.e., check that by * block_M + i < seq_len before performing the atomic_add)
so tail-block iterations that exceed seq_len are skipped.

Comment on lines +243 to +245
        T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)

    return flash_bwd

⚠️ Potential issue | 🔴 Critical

Bounds-check dK atomic adds

For non-multiple seq_len, by * block_M + i can run past the valid range and we’ll issue atomics outside dK. Please mirror the guard used for dQ so the tail block lands safely.

-        for i, j in T.Parallel(block_M, dim_qk):
-            T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
+        for i, j in T.Parallel(block_M, dim_qk):
+            if by * block_M + i < seq_len:
+                T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd.py around lines 243 to 245, the
T.atomic_add into dK can write past the buffer when seq_len is not a multiple of
block_M; mirror the guard used for dQ by bounding the by*block_M:(by+1)*block_M
range so it never exceeds seq_len. Specifically, compute the valid tail length
(e.g., end = min((by+1)*block_M, seq_len)) or apply a per-element
mask/conditional so atomic_add only accumulates into dK[bz, by*block_M:end,
bx//groups, :] (or only perform atomic_add for indices i where by*block_M + i <
seq_len), ensuring the tail block is safely truncated before performing the
atomic adds.

coderabbitai bot left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9a4a359 and c2ec489.

📒 Files selected for processing (2)
  • tilelang/language/utils.py (1 hunks)
  • tilelang/utils/language.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/utils.py (1)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
🪛 Ruff (0.13.3)
tilelang/utils/language.py

142-142: Prefer TypeError exception for invalid type

(TRY004)


142-142: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check
🔇 Additional comments (4)
tilelang/language/utils.py (4)

1-6: LGTM!

The imports are appropriate and all imported types are used in the new region utility functions.


8-27: LGTM!

The function correctly maps access types to numeric codes and constructs the region intrinsic call. The docstring clearly documents the behavior including the KeyError for invalid access types.


30-42: LGTM!

The function correctly converts a Buffer to a tile region descriptor by creating a region that spans the entire buffer.


72-89: LGTM!

The function correctly converts a BufferRegion to a tile region descriptor by extracting mins and extents, with appropriate validation.

Comment on lines +59 to +60
        # (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
        # f"region will be expanded in the last 2 dimensions")

🛠️ Refactor suggestion | 🟠 Major

Remove commented-out code.

The commented-out code should either be removed or converted to proper logging if the information is needed for debugging.

Apply this diff:

-    if len(indices) > len(extents):
-        # (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
-        # f"region will be expanded in the last 2 dimensions")
-        new_extents = []
+    if len(indices) > len(extents):
+        new_extents = []
🤖 Prompt for AI Agents
In tilelang/language/utils.py around lines 59-60, there is commented-out code
that should be removed or turned into proper logging; delete the two commented
lines or replace them with a call to the project's logger (e.g., logger.debug or
logger.warning) that formats the same message with variables (indices, extents,
load) so runtime diagnostics are preserved, ensuring you import/use the module
logger consistently with the codebase.

        for extent in extents:
            new_extents.append(extent)
        extents = new_extents
    print("after extents", extents)

🛠️ Refactor suggestion | 🟠 Major

Remove debug print statement.

Debug print statements should not be left in production code. Remove this line or convert it to proper logging if the information is useful for debugging.

Apply this diff:

-    print("after extents", extents)
     assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
🤖 Prompt for AI Agents
In tilelang/language/utils.py around line 67, there's a leftover debug print
call 'print("after extents", extents)'; remove this debug print or replace it
with a proper logger call (e.g., use the module logger.debug with a clear
message and the extents variable) depending on whether the info is needed for
runtime debugging, ensuring imports/logger are present if converting to logging.

Comment on lines +134 to +146
+    found_ramp: bool = False
     for indice in indices:
-        if not isinstance(indice, tir.Ramp):
-            return None
-        regions.append(ir.Range.from_min_extent(indice.base, indice.lanes))
-    return tir.BufferRegion(buffer, regions)
+        if isinstance(indice, tir.Ramp):
+            regions.append(ir.Range.from_min_extent(indice.base, indice.lanes))
+            found_ramp = True
+        elif isinstance(indice, tir.PrimExpr):
+            regions.append(ir.Range.from_min_extent(indice, 1))
+        else:
+            raise ValueError("Unsupported type: ", type(indice))
+    if found_ramp:
+        return tir.BufferRegion(buffer, regions)
+    else:
+        return None

🛠️ Refactor suggestion | 🟠 Major

Apply static analysis recommendations for error handling.

The logic enhancement to handle PrimExpr indices is correct. However, static analysis correctly identifies two improvements:

  1. Use TypeError instead of ValueError for invalid type errors
  2. Define exception messages outside the raise statement for better maintainability

Apply this diff:

     buffer, indices = buffer_load.buffer, buffer_load.indices
     regions = []
     found_ramp: bool = False
     for indice in indices:
         if isinstance(indice, tir.Ramp):
             regions.append(ir.Range.from_min_extent(indice.base, indice.lanes))
             found_ramp = True
         elif isinstance(indice, tir.PrimExpr):
             regions.append(ir.Range.from_min_extent(indice, 1))
         else:
-            raise ValueError("Unsupported type: ", type(indice))
+            raise TypeError(f"Unsupported index type: {type(indice)}")
     if found_ramp:
         return tir.BufferRegion(buffer, regions)
     else:
         return None
🧰 Tools
🪛 Ruff (0.13.3)

142-142: Prefer TypeError exception for invalid type

(TRY004)


142-142: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/utils/language.py around lines 134 to 146, replace the current raise
ValueError(...) with a TypeError and move the exception message into a separate
variable before raising to improve maintainability; construct a clear message
string (e.g., f"Unsupported index type: {type(indice)}") assigned to a variable
and then raise TypeError(message) so the code uses the correct exception type
and keeps the message definied separately from the raise expression.

coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c2ec489 and 11a46c8.

📒 Files selected for processing (1)
  • examples/warp_specialize/example_warp_specialize_flashmla.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/warp_specialize/example_warp_specialize_flashmla.py (1)
examples/deepseek_mla/example_mla_decode.py (1)
  • main (283-305)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check



-def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
+def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Inconsistency: CLI defaults don't match main() defaults.

The main() function signature now defaults to heads=64 and kv_ctx=1024, but the CLI argument parser (lines 405, 407) still defaults to heads=128 and kv_ctx=8192. This creates unpredictable behavior depending on whether the script is invoked from the command line or called programmatically.

Apply this diff to align the CLI defaults with the updated main() signature:

-    parser.add_argument('--heads', type=int, default=128, help='q heads number')
+    parser.add_argument('--heads', type=int, default=64, help='q heads number')
-    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
+    parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length')

Additionally, verify that similar files in this PR (e.g., example_mla_decode.py) have consistent defaults:


🏁 Script executed:

#!/bin/bash
# Description: Check for inconsistent defaults across flash attention examples

# Search for main function signatures and CLI parser defaults
echo "=== main() function signatures ==="
rg -n "def main\(.*heads.*kv_ctx" examples/

echo -e "\n=== CLI parser defaults for heads ==="
rg -n "add_argument.*--heads.*default=" examples/

echo -e "\n=== CLI parser defaults for kv_ctx ==="
rg -n "add_argument.*--kv_ctx.*default=" examples/

Length of output: 5628


Align CLI argument defaults with updated main() signature.

The parser defaults still use the old values (128, 8192), overriding the function signature (64, 1024). Update them as follows:

 examples/warp_specialize/example_warp_specialize_flashmla.py
-    parser.add_argument('--heads', type=int, default=128, help='q heads number')
+    parser.add_argument('--heads', type=int, default=64, help='q heads number')
-    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
+    parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length')
🤖 Prompt for AI Agents
In examples/warp_specialize/example_warp_specialize_flashmla.py around line 385,
the CLI parser still defaults --heads and --kv-ctx to the old values (128 and
8192) which override the new main() signature; update the parser defaults to
match the function signature by setting heads default to 64 and kv_ctx default
to 1024 (also ensure any related defaults like kv_heads, dim, pe_dim, batch
match the main() defaults if present).
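One way to prevent this class of drift entirely (a sketch of a pattern, not this example's code) is to derive the parser defaults from the function signature itself, so there is a single source of truth:

```python
import argparse
import inspect

def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
    return heads, kv_ctx

def build_parser(fn):
    # Defaults come straight from fn's signature, so the CLI can never
    # disagree with the function it drives.
    parser = argparse.ArgumentParser()
    for name, param in inspect.signature(fn).parameters.items():
        parser.add_argument(f"--{name}", type=type(param.default),
                            default=param.default)
    return parser

args = build_parser(main).parse_args([])  # picks up heads=64, kv_ctx=1024
```

Updating the `main()` signature then updates the CLI automatically; explicit `add_argument` defaults like the ones flagged above become unnecessary.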
