
Conversation

Collaborator

@Rachmanino Rachmanino commented Nov 22, 2025

As titled.

Summary by CodeRabbit

  • New Features

    • Added five warp-level reduction ops: warp_reduce_sum, warp_reduce_max, warp_reduce_min, warp_reduce_bitand, warp_reduce_bitor.
    • Exposed these ops in the public Python API for use in kernels.
    • CUDA code generation now emits warp-level reduction calls for these ops.
  • Tests

    • New unit tests validating sum, max, min, bitwise-AND, and bitwise-OR warp reductions on CUDA.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Contributor

coderabbitai bot commented Nov 22, 2025

Walkthrough

Adds five warp-level reduction intrinsics (warp_reduce_sum, warp_reduce_max, warp_reduce_min, warp_reduce_bitand, warp_reduce_bitor) across C++ op registration, CUDA codegen, CUDA device templates, and the Python tilelang API.

Changes

  • Op declaration & registration (src/op/builtin.h, src/op/builtin.cc): Added five new public Op accessors and registration entries: tl.warp_reduce_sum, tl.warp_reduce_max, tl.warp_reduce_min, tl.warp_reduce_bitand, tl.warp_reduce_bitor. Each is defined with a single input and marked opaque via TCallEffectKind.
  • CUDA codegen (src/target/codegen_cuda.cc): Extended CallNode handling to emit calls to tl::warp_reduce_* functions for the new TL intrinsics; falls back to the base CodeGenC when unmatched.
  • CUDA device templates (src/tl_templates/cuda/reduce.h): Added a generic warp_reduce<T, ReduceOp> using xor-based shuffles and five specialized wrappers: warp_reduce_sum, warp_reduce_max, warp_reduce_min, warp_reduce_bitand, warp_reduce_bitor (TL_DEVICE).
  • Python API (tilelang/language/reduce.py, tilelang/language/__init__.py): Added Python helpers warp_reduce_sum, warp_reduce_max, warp_reduce_min, warp_reduce_bitand, warp_reduce_bitor (accepting tir.PrimExpr) and exported them from the package.
  • Tests (testing/python/language/test_tilelang_language_warp_reduce.py): New tests exercising all five warp reductions via a Torch CUDA kernel, including correctness checks against reference reductions.

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python API
    participant TL as TL Intrinsic (Op)
    participant CG as CodeGenCUDA
    participant Dev as CUDA device template
    participant GPU as GPU warp lanes

    Note over Py,TL: high-level flow for a warp reduction intrinsic
    Py->>TL: construct CallNode for tl.warp_reduce_*
    TL->>CG: lowering of CallNode
    CG->>Dev: emit call to tl::warp_reduce_* in generated CUDA
    Dev->>GPU: perform xor-shuffle reduction across lanes
    GPU-->>Dev: reduced lane result
    Dev-->>CG: emitted expression/value
    CG-->>TL: lowered result
    TL-->>Py: value used in Python-level expression
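
To make the flow above concrete, here is a minimal device-side sketch of the kind of code the last two hops produce and consume. The kernel, buffer names, and include are illustrative assumptions; only the tl::warp_reduce_sum helper itself comes from this PR.

```cuda
// Illustrative sketch only: roughly what lowered device code looks like once
// CodeGenCUDA emits a tl::warp_reduce_* call. Kernel and names are assumed.
#include "reduce.h"  // assumed: the TL CUDA reduce templates are on the include path

__global__ void warp_sum_example(const float *__restrict__ x,
                                 float *__restrict__ y) {
  float v = x[threadIdx.x];    // one element per lane of a 32-thread warp
  v = tl::warp_reduce_sum(v);  // xor-shuffle butterfly across all 32 lanes
  y[threadIdx.x] = v;          // every lane now holds the same warp-wide sum
}
```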

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20–30 minutes

  • Pay attention to:
    • src/tl_templates/cuda/reduce.h — correctness of shuffle offsets, initial value/op identity, and type handling.
    • src/target/codegen_cuda.cc — matching intrinsic names/arity and generated call syntax.
    • src/op/builtin.* — consistent registration metadata and effect kind.
    • Tests — verify kernel correctness and edge cases for bitwise ops.

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐰
I hop through lanes of CUDA light,
XOR whispers through the night,
Sum and max, min and bit,
I nudge each lane to share a fit,
Rabbit cheers — reductions bright! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 46.67%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: The title '[Feat] Support warp reduce' directly and clearly describes the main change: adding support for warp reduction operations across five new reduction functions (sum, max, min, bitand, bitor).


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
src/target/codegen_cuda.cc (1)

2612-2621: CUDA mapping for tl.warp_reduce_* intrinsics looks consistent

The new cases correctly lower the TIR intrinsics to tl::warp_reduce_* device helpers and match the one-argument registration in builtin.cc. If you want additional safety, you could add an ICHECK_EQ(op->args.size(), 1U); check in each branch, but it is not strictly necessary given set_num_inputs(1).
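
For illustration, the suggested guard could sit at the top of each branch roughly as follows. This is a fragment, not the actual codegen_cuda.cc source; the Op accessor name and stream handling follow the usual TVM CodeGenC pattern and are assumptions.

```cpp
// Hypothetical shape of one branch with the optional arity check; not copied
// from the PR. Assumes the tl::warp_reduce_sum() Op accessor from builtin.h.
} else if (op->op.same_as(tl::warp_reduce_sum())) {
  ICHECK_EQ(op->args.size(), 1U);  // intrinsic is registered with one input
  os << "tl::warp_reduce_sum(";
  this->PrintExpr(op->args[0], os);
  os << ")";
}
```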

tilelang/language/__init__.py (1)

56-73: Warp-reduce helpers are correctly re-exported

The additional imports from .reduce cleanly expose the new warp_reduce_* helpers at the tilelang.language level and match the implementations in reduce.py. Ruff’s RUF100 about # noqa: F401 is just a config mismatch; if you care about it, you could switch these to bare # noqa (or enable F401 in Ruff), but keeping them as-is is consistent with the rest of this module.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 470eb74 and 86658ec.

📒 Files selected for processing (6)
  • src/op/builtin.cc (1 hunks)
  • src/op/builtin.h (1 hunks)
  • src/target/codegen_cuda.cc (1 hunks)
  • src/tl_templates/cuda/reduce.h (1 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/reduce.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
src/target/codegen_cuda.cc (1)
tilelang/language/reduce.py (5)
  • warp_reduce_sum (330-347)
  • warp_reduce_max (350-367)
  • warp_reduce_min (370-387)
  • warp_reduce_bitand (390-407)
  • warp_reduce_bitor (410-427)
src/tl_templates/cuda/reduce.h (1)
tilelang/language/reduce.py (5)
  • warp_reduce_sum (330-347)
  • warp_reduce_max (350-367)
  • warp_reduce_min (370-387)
  • warp_reduce_bitand (390-407)
  • warp_reduce_bitor (410-427)
src/op/builtin.h (1)
tilelang/language/reduce.py (5)
  • warp_reduce_sum (330-347)
  • warp_reduce_max (350-367)
  • warp_reduce_min (370-387)
  • warp_reduce_bitand (390-407)
  • warp_reduce_bitor (410-427)
tilelang/language/__init__.py (1)
tilelang/language/reduce.py (5)
  • warp_reduce_sum (330-347)
  • warp_reduce_max (350-367)
  • warp_reduce_min (370-387)
  • warp_reduce_bitand (390-407)
  • warp_reduce_bitor (410-427)
tilelang/language/reduce.py (1)
tilelang/language/tir/op.py (1)
  • call_intrin (120-145)
src/op/builtin.cc (1)
tilelang/language/reduce.py (5)
  • warp_reduce_sum (330-347)
  • warp_reduce_max (350-367)
  • warp_reduce_min (370-387)
  • warp_reduce_bitand (390-407)
  • warp_reduce_bitor (410-427)
🪛 Ruff (0.14.5)
tilelang/language/__init__.py

68-68: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


69-69: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


70-70: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


71-71: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


72-72: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (3)
src/op/builtin.cc (1)

344-368: Builtin registration for tl.warp_reduce_* is coherent

The five new intrinsics are registered consistently with the existing TL ops: correct naming, one input each, and kOpaque effect kind, which is a reasonable choice for warp-synchronous operations.
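
For readers unfamiliar with the pattern, a one-input opaque registration of this shape looks roughly as follows. This is a sketch using plain TVM registration calls; the actual builtin.cc may route this through a TileLang-specific macro with the same effect.

```cpp
// Sketch of a one-input, opaque TIR op registration in TVM style; the real
// file may use a project macro that expands to an equivalent registration.
TVM_REGISTER_OP("tl.warp_reduce_sum")
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
```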

src/op/builtin.h (1)

574-598: Header declarations for warp_reduce_* match registrations and usage

The five new TVM_DLL declarations are correctly named, documented, and aligned with the corresponding definitions in builtin.cc and the uses in CUDA codegen and Python bindings.
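
As a reference point, such accessor declarations typically take the form below (sketch only; namespaces and per-op doc comments are abbreviated, not copied from builtin.h):

```cpp
// Sketch of the accessor style for TL builtin ops; the real header documents
// each op and declares all five reductions.
TVM_DLL const Op& warp_reduce_sum();
TVM_DLL const Op& warp_reduce_max();
TVM_DLL const Op& warp_reduce_min();
```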

src/tl_templates/cuda/reduce.h (1)

254-288: Warp-level reduction helpers are well-integrated

The generic warp_reduce and the five specialized warp_reduce_* wrappers reuse the existing reducer functors and follow the same shuffle/mask conventions as the rest of this header. This gives a clear, composable warp-reduction primitive for the CUDA backend.
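
To illustrate the delegation pattern: the actual wrappers reuse reducer functors already defined earlier in this header, whose names are not reproduced here; the lambdas below are stand-ins for those functors.

```cpp
// Illustration only: each wrapper forwards to the generic warp_reduce with a
// binary op. reduce.h uses its existing reducer functors rather than lambdas.
template <typename T>
TL_DEVICE T warp_reduce_sum_sketch(T value) {
  return warp_reduce(value, [](T a, T b) { return a + b; });
}

template <typename T>
TL_DEVICE T warp_reduce_max_sketch(T value) {
  return warp_reduce(value, [](T a, T b) { return a > b ? a : b; });
}
```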

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
src/tl_templates/cuda/reduce.h (1)

253-262: Use tl::shfl_xor_sync for consistency with the rest of the file.

The warp_reduce implementation uses __shfl_xor_sync directly, but the rest of this file consistently uses tl::shfl_xor_sync (lines 75, 95, 207, 217, 233, 243). The butterfly reduction pattern is correct, but maintaining consistency with the namespace prefix improves code uniformity.

Apply this diff:

 template <typename T, typename ReduceOp>
 TL_DEVICE T warp_reduce(T value, ReduceOp op) {
   constexpr uint32_t mask = 0xffffffff;
-  value = op(value, __shfl_xor_sync(mask, value, 16));
-  value = op(value, __shfl_xor_sync(mask, value, 8));
-  value = op(value, __shfl_xor_sync(mask, value, 4));
-  value = op(value, __shfl_xor_sync(mask, value, 2));
-  value = op(value, __shfl_xor_sync(mask, value, 1));
+  value = op(value, tl::shfl_xor_sync(mask, value, 16));
+  value = op(value, tl::shfl_xor_sync(mask, value, 8));
+  value = op(value, tl::shfl_xor_sync(mask, value, 4));
+  value = op(value, tl::shfl_xor_sync(mask, value, 2));
+  value = op(value, tl::shfl_xor_sync(mask, value, 1));
   return value;
 }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 86658ec and bed2858.

📒 Files selected for processing (4)
  • src/op/builtin.cc (1 hunks)
  • src/op/builtin.h (1 hunks)
  • src/tl_templates/cuda/reduce.h (1 hunks)
  • tilelang/language/reduce.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/op/builtin.h
🧰 Additional context used
🧬 Code graph analysis (3)
src/tl_templates/cuda/reduce.h (1)
tilelang/language/reduce.py (5)
  • warp_reduce_sum (330-343)
  • warp_reduce_max (346-359)
  • warp_reduce_min (362-375)
  • warp_reduce_bitand (378-391)
  • warp_reduce_bitor (394-407)
tilelang/language/reduce.py (1)
tilelang/language/tir/op.py (1)
  • call_intrin (120-145)
src/op/builtin.cc (1)
tilelang/language/reduce.py (5)
  • warp_reduce_sum (330-343)
  • warp_reduce_max (346-359)
  • warp_reduce_min (362-375)
  • warp_reduce_bitand (378-391)
  • warp_reduce_bitor (394-407)
🔇 Additional comments (3)
src/tl_templates/cuda/reduce.h (1)

264-282: LGTM!

The five warp reduction wrappers are well-structured and correctly delegate to the generic warp_reduce template with the appropriate reduction operators.

src/op/builtin.cc (1)

344-367: LGTM!

The five warp reduction intrinsics are correctly registered following the established pattern. Using CallEffectKind::kOpaque is appropriate since these operations involve warp-level thread synchronization via shuffle instructions.

tilelang/language/reduce.py (1)

330-407: LGTM!

All five warp reduction functions are consistently implemented with clear docstrings. The parameter naming is correct throughout (the previous docstring inconsistency has been addressed). Each function properly delegates to the corresponding registered TL intrinsic.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (4)
testing/python/language/test_tilelang_language_warp_reduce.py (4)

1-32: Kernel factory correctly wires warp reductions; confirm closure handling and warp-size assumption

The kernel setup looks good: each thread loads x[tx], applies the selected T.warp_reduce_* op, and writes the warp-wide result back, with an upfront assert guarding unsupported reduce_op values. Two subtle points to double-check:

  1. Using the Python variable reduce_op inside @T.prim_func assumes the TileLang/TIR front-end correctly captures closure constants and resolves the if at compile time; if closures aren’t supported here, this could fail during script parsing or try to generate an invalid string comparison on device.
  2. The kernel is hard-coded to 32 threads and a length-32 tensor, effectively assuming a 32-lane warp; that’s fine for current CUDA targets, but you may need to revisit this if you want the same test to cover backends with different warp widths.

35-41: Sum test is sound; consider determinism and extra dtype coverage

This is a clean end-to-end check for warp_reduce_sum, and torch.testing.assert_close is appropriate for floating-point sums. If you want to tighten things up, you could seed the RNG (e.g., via torch.manual_seed or any existing tilelang.testing utility) for deterministic inputs, and optionally add a second case for another supported dtype (e.g., float16 or int32) once the intrinsic is confirmed to support it.


43-49: Remove debug print from test_warp_reduce_max

Printing kernel.get_kernel_source() on every test run will clutter CI logs and slow down large suites. Unless you explicitly rely on this output in automation, it’s better to drop or gate it behind a debug flag.

You can simplify the test with:

-    print(kernel.get_kernel_source())

60-68: Bitwise-AND test is correct; consider making the reference more explicit

The sequential reduction with & over the CUDA int32 tensor is a valid reference and matches the intended warp behavior. If you ever want to decouple the reference from device semantics, you could compute it on CPU via .item() and plain Python ints, but for bitwise ops the current approach is already precise.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bed2858 and 5dc3cf3.

📒 Files selected for processing (1)
  • testing/python/language/test_tilelang_language_warp_reduce.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_warp_reduce.py (2)
tilelang/language/allocate.py (1)
  • alloc_local (45-56)
tilelang/language/reduce.py (5)
  • warp_reduce_sum (330-343)
  • warp_reduce_max (346-359)
  • warp_reduce_min (362-375)
  • warp_reduce_bitand (378-391)
  • warp_reduce_bitor (394-407)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (3)
testing/python/language/test_tilelang_language_warp_reduce.py (3)

52-57: Min test mirrors sum/max and looks correct

The structure matches the sum/max tests and should effectively validate warp_reduce_min (including that all lanes see the same reduced value). No issues from a correctness perspective.


71-79: Bitwise-OR test matches the AND test pattern and looks good

Same pattern as the bitwise-AND test; it should reliably catch regressions in warp_reduce_bitor and ensure all lanes receive the same reduced value. No changes needed.


82-83: Main guard integration with tilelang.testing

The if __name__ == "__main__": tilelang.testing.main() guard is a nice touch for running this file directly, and shouldn’t interfere with normal test discovery.

@LeiWang1999 LeiWang1999 merged commit caa6dd3 into tile-ai:main Nov 24, 2025
13 of 18 checks passed
