[Feat] Support warp reduce #1316
Walkthrough

Adds five warp-level reduction intrinsics (warp_reduce_sum, warp_reduce_max, warp_reduce_min, warp_reduce_bitand, warp_reduce_bitor) across C++ op registration, CUDA codegen, CUDA device templates, and the Python tilelang API.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python API
    participant TL as TL Intrinsic (Op)
    participant CG as CodeGenCUDA
    participant Dev as CUDA device template
    participant GPU as GPU warp lanes
    Note over Py,TL: high-level flow for a warp reduction intrinsic
    Py->>TL: construct CallNode for tl.warp_reduce_*
    TL->>CG: lowering of CallNode
    CG->>Dev: emit call to tl::warp_reduce_* in generated CUDA
    Dev->>GPU: perform xor-shuffle reduction across lanes
    GPU-->>Dev: reduced lane result
    Dev-->>CG: emitted expression/value
    CG-->>TL: lowered result
    TL-->>Py: value used in Python-level expression
```
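The "xor-shuffle reduction across lanes" step in the diagram can be illustrated with a pure-Python simulation of the butterfly pattern. This is an illustrative sketch, not TileLang code: the helper name and lane model are made up for demonstration, but the xor-offset schedule (16, 8, 4, 2, 1) matches the CUDA helper reviewed below.

```python
def simulate_warp_reduce(values, op):
    """Simulate a 32-lane warp reducing via xor-shuffle exchanges.

    After each round, lane i has combined its value with lane i ^ offset;
    after all five rounds, every lane holds the full 32-lane reduction.
    """
    assert len(values) == 32
    lanes = list(values)
    for offset in (16, 8, 4, 2, 1):
        # All lanes exchange simultaneously, mirroring __shfl_xor_sync.
        lanes = [op(lanes[i], lanes[i ^ offset]) for i in range(32)]
    return lanes

# Every lane ends up with the same reduced value.
result = simulate_warp_reduce(list(range(32)), lambda a, b: a + b)
assert all(v == sum(range(32)) for v in result)
```

A useful property this makes visible: unlike a tree reduction that leaves the result only in lane 0, the xor butterfly leaves the reduced value in all 32 lanes, which is exactly what the tests later in this review check.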
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20–30 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
🧹 Nitpick comments (2)
src/target/codegen_cuda.cc (1)

2612-2621: CUDA mapping for tl.warp_reduce_* intrinsics looks consistent

The new cases correctly lower the TIR intrinsics to tl::warp_reduce_* device helpers and match the one-argument registration in builtin.cc. If you want additional safety, you could ICHECK_EQ(op->args.size(), 1U); in each branch, but it’s not strictly necessary given set_num_inputs(1).

tilelang/language/__init__.py (1)

56-73: Warp-reduce helpers are correctly re-exported

The additional imports from .reduce cleanly expose the new warp_reduce_* helpers at the tilelang.language level and match the implementations in reduce.py. Ruff’s RUF100 about # noqa: F401 is just a config mismatch; if you care about it, you could switch these to bare # noqa (or enable F401 in Ruff), but keeping them as-is is consistent with the rest of this module.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- src/op/builtin.cc (1 hunks)
- src/op/builtin.h (1 hunks)
- src/target/codegen_cuda.cc (1 hunks)
- src/tl_templates/cuda/reduce.h (1 hunks)
- tilelang/language/__init__.py (1 hunks)
- tilelang/language/reduce.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
src/target/codegen_cuda.cc (1)
tilelang/language/reduce.py (5)
warp_reduce_sum (330-347), warp_reduce_max (350-367), warp_reduce_min (370-387), warp_reduce_bitand (390-407), warp_reduce_bitor (410-427)
src/tl_templates/cuda/reduce.h (1)
tilelang/language/reduce.py (5)
warp_reduce_sum (330-347), warp_reduce_max (350-367), warp_reduce_min (370-387), warp_reduce_bitand (390-407), warp_reduce_bitor (410-427)
src/op/builtin.h (1)
tilelang/language/reduce.py (5)
warp_reduce_sum (330-347), warp_reduce_max (350-367), warp_reduce_min (370-387), warp_reduce_bitand (390-407), warp_reduce_bitor (410-427)
tilelang/language/__init__.py (1)
tilelang/language/reduce.py (5)
warp_reduce_sum (330-347), warp_reduce_max (350-367), warp_reduce_min (370-387), warp_reduce_bitand (390-407), warp_reduce_bitor (410-427)
tilelang/language/reduce.py (1)
tilelang/language/tir/op.py (1)
call_intrin(120-145)
src/op/builtin.cc (1)
tilelang/language/reduce.py (5)
warp_reduce_sum (330-347), warp_reduce_max (350-367), warp_reduce_min (370-387), warp_reduce_bitand (390-407), warp_reduce_bitor (410-427)
🪛 Ruff (0.14.5)
tilelang/language/__init__.py
68-68: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
69-69: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
70-70: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
71-71: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
72-72: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (3)
src/op/builtin.cc (1)

344-368: Builtin registration for tl.warp_reduce_* is coherent

The five new intrinsics are registered consistently with the existing TL ops: correct naming, one input each, and kOpaque effect kind, which is a reasonable choice for warp-synchronous operations.

src/op/builtin.h (1)

574-598: Header declarations for warp_reduce_* match registrations and usage

The five new TVM_DLL declarations are correctly named, documented, and aligned with the corresponding definitions in builtin.cc and the uses in CUDA codegen and Python bindings.

src/tl_templates/cuda/reduce.h (1)

254-288: Warp-level reduction helpers are well-integrated

The generic warp_reduce and the five specialized warp_reduce_* wrappers reuse the existing reducer functors and follow the same shuffle/mask conventions as the rest of this header. This gives a clear, composable warp-reduction primitive for the CUDA backend.
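The "generic reduce plus thin per-operator wrappers" pattern described above can be sketched in Python. This is an analogue for illustration only: the function names mirror the CUDA helpers, but the bodies below are plain sequential folds, not the actual shuffle-based device code.

```python
import operator
from functools import reduce

def _generic_reduce(values, op):
    # Stands in for the templated warp_reduce(value, op); a sequential
    # fold gives the same result as the butterfly for these operators.
    return reduce(op, values)

# Five thin wrappers, each binding one reduction operator,
# mirroring the structure of the tl::warp_reduce_* helpers.
def warp_reduce_sum(values):    return _generic_reduce(values, operator.add)
def warp_reduce_max(values):    return _generic_reduce(values, max)
def warp_reduce_min(values):    return _generic_reduce(values, min)
def warp_reduce_bitand(values): return _generic_reduce(values, operator.and_)
def warp_reduce_bitor(values):  return _generic_reduce(values, operator.or_)
```

The design choice this mirrors: specializations stay one-liners because all the shuffle/mask mechanics live in the single generic template, so adding a sixth reduction would only require one more functor binding.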
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/tl_templates/cuda/reduce.h (1)
253-262: Use tl::shfl_xor_sync for consistency with the rest of the file.

The warp_reduce implementation uses __shfl_xor_sync directly, but the rest of this file consistently uses tl::shfl_xor_sync (lines 75, 95, 207, 217, 233, 243). The butterfly reduction pattern is correct, but maintaining consistency with the namespace prefix improves code uniformity.

Apply this diff:

```diff
 template <typename T, typename ReduceOp>
 TL_DEVICE T warp_reduce(T value, ReduceOp op) {
   constexpr uint32_t mask = 0xffffffff;
-  value = op(value, __shfl_xor_sync(mask, value, 16));
-  value = op(value, __shfl_xor_sync(mask, value, 8));
-  value = op(value, __shfl_xor_sync(mask, value, 4));
-  value = op(value, __shfl_xor_sync(mask, value, 2));
-  value = op(value, __shfl_xor_sync(mask, value, 1));
+  value = op(value, tl::shfl_xor_sync(mask, value, 16));
+  value = op(value, tl::shfl_xor_sync(mask, value, 8));
+  value = op(value, tl::shfl_xor_sync(mask, value, 4));
+  value = op(value, tl::shfl_xor_sync(mask, value, 2));
+  value = op(value, tl::shfl_xor_sync(mask, value, 1));
   return value;
 }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/op/builtin.cc (1 hunks)
- src/op/builtin.h (1 hunks)
- src/tl_templates/cuda/reduce.h (1 hunks)
- tilelang/language/reduce.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/op/builtin.h
🧰 Additional context used
🧬 Code graph analysis (3)
src/tl_templates/cuda/reduce.h (1)
tilelang/language/reduce.py (5)
warp_reduce_sum (330-343), warp_reduce_max (346-359), warp_reduce_min (362-375), warp_reduce_bitand (378-391), warp_reduce_bitor (394-407)
tilelang/language/reduce.py (1)
tilelang/language/tir/op.py (1)
call_intrin(120-145)
src/op/builtin.cc (1)
tilelang/language/reduce.py (5)
warp_reduce_sum (330-343), warp_reduce_max (346-359), warp_reduce_min (362-375), warp_reduce_bitand (378-391), warp_reduce_bitor (394-407)
🔇 Additional comments (3)
src/tl_templates/cuda/reduce.h (1)
264-282: LGTM!

The five warp reduction wrappers are well-structured and correctly delegate to the generic warp_reduce template with the appropriate reduction operators.

src/op/builtin.cc (1)

344-367: LGTM!

The five warp reduction intrinsics are correctly registered following the established pattern. Using CallEffectKind::kOpaque is appropriate since these operations involve warp-level thread synchronization via shuffle instructions.

tilelang/language/reduce.py (1)

330-407: LGTM!

All five warp reduction functions are consistently implemented with clear docstrings. The parameter naming is correct throughout (the previous docstring inconsistency has been addressed). Each function properly delegates to the corresponding registered TL intrinsic.
Actionable comments posted: 0
🧹 Nitpick comments (4)
testing/python/language/test_tilelang_language_warp_reduce.py (4)
1-32: Kernel factory correctly wires warp reductions; confirm closure handling and warp-size assumption

The kernel setup looks good: each thread loads x[tx], applies the selected T.warp_reduce_* op, and writes the warp-wide result back, with an upfront assert guarding unsupported reduce_op values. Two subtle points to double-check:

- Using the Python variable reduce_op inside @T.prim_func assumes the TileLang/TIR front-end correctly captures closure constants and resolves the if at compile time; if closures aren’t supported here, this could fail during script parsing or try to generate an invalid string comparison on device.
- The kernel is hard-coded to 32 threads and a length-32 tensor, effectively assuming a 32-lane warp; that’s fine for current CUDA targets, but you may need to revisit this if you want the same test to cover backends with different warp widths.

35-41: Sum test is sound; consider determinism and extra dtype coverage

This is a clean end-to-end check for warp_reduce_sum, and torch.testing.assert_close is appropriate for floating-point sums. If you want to tighten things up, you could seed the RNG (e.g., via torch.manual_seed or any existing tilelang.testing utility) for deterministic inputs, and optionally add a second case for another supported dtype (e.g., float16 or int32) once the intrinsic is confirmed to support it.

43-49: Remove debug print from test_warp_reduce_max

Printing kernel.get_kernel_source() on every test run will clutter CI logs and slow down large suites. Unless you explicitly rely on this output in automation, it’s better to drop or gate it behind a debug flag.

You can simplify the test with:

```diff
- print(kernel.get_kernel_source())
```

60-68: Bitwise-AND test is correct; consider making the reference more explicit

The sequential reduction with & over the CUDA int32 tensor is a valid reference and matches the intended warp behavior. If you ever want to decouple the reference from device semantics, you could compute it on CPU via .item() and plain Python ints, but for bitwise ops the current approach is already precise.
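A CPU-side reference of the kind suggested above might look like the following. The values and the expected result are made up for illustration; the point is that folding & over plain Python ints is exact, so the reference is fully decoupled from device semantics.

```python
from functools import reduce
import operator

# Hypothetical host-side reference for a bitwise-AND warp reduction:
# fold & over plain Python ints instead of reading back device values.
vals = [0b1111, 0b1101, 0b0111, 0b1110]
expected_and = reduce(operator.and_, vals)

# Only bit 2 survives all four masks.
assert expected_and == 0b0100
```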
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/language/test_tilelang_language_warp_reduce.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_warp_reduce.py (2)
tilelang/language/allocate.py (1)
alloc_local (45-56)

tilelang/language/reduce.py (5)

warp_reduce_sum (330-343), warp_reduce_max (346-359), warp_reduce_min (362-375), warp_reduce_bitand (378-391), warp_reduce_bitor (394-407)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (3)
testing/python/language/test_tilelang_language_warp_reduce.py (3)
52-57: Min test mirrors sum/max and looks correct

The structure matches the sum/max tests and should effectively validate warp_reduce_min (including that all lanes see the same reduced value). No issues from a correctness perspective.

71-79: Bitwise-OR test matches the AND test pattern and looks good

Same pattern as the bitwise-AND test; it should reliably catch regressions in warp_reduce_bitor and ensure all lanes receive the same reduced value. No changes needed.

82-83: Main guard integration with tilelang.testing

The if __name__ == "__main__": tilelang.testing.main() guard is a nice touch for running this file directly, and shouldn’t interfere with normal test discovery.
As titled.
Summary by CodeRabbit
New Features
Tests