[Language] Support atomic add with ret #870
Conversation
- Introduced atomic functions including AtomicMax, AtomicMin, AtomicAdd, and their return variants for various data types.
- Implemented support for half, bfloat16, and float types with appropriate memory ordering.
- Moved atomic-related utilities from common.h to the new atomic.h file for better organization.
- Added Python bindings for atomic operations in tilelang, including atomic_max, atomic_min, atomic_add, and their vectorized counterparts.
- Updated customize.py to utilize the new atomic functions, enhancing modularity and maintainability.

- Reformatted atomic operation implementations in atomic.h for better code clarity.
- Adjusted function signatures in tilelang's atomic.py to enhance readability by aligning parameters.
- Cleaned up unnecessary whitespace and comments in customize.py to streamline the codebase.
Caution: Review failed. The pull request is closed.

Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a standalone CUDA atomic header and removes atomic helpers from common.h; introduces a TileLang Python atomic module (re-exported from customize.py); broadens local buffer scope detection; adds a pass flag to opt out of thread-storage synchronization; relaxes AtomicAdd argument/validation paths; adds atomic tests; updates example defaults and a submodule pointer.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant TL_PY as tilelang.language.atomic
    participant TilePath as Tile-Region Path
    participant Extern as Extern Ops (C++)
    User->>TL_PY: atomic_add(dst, value, memory_order?, return_prev?)
    alt extents unknown / extern path
        TL_PY->>Extern: AtomicAdd / AtomicAddRet(dst, value, mo_id)
        Extern-->>TL_PY: (prev/new) value
    else extents known (tile-region path)
        TL_PY->>TilePath: Convert Buffer/Region → tile-region add
        Note right of TilePath: return_prev unsupported
        TilePath-->>TL_PY: side-effect only
    end
    TL_PY-->>User: Result or expression
```
```mermaid
sequenceDiagram
    autonumber
    participant Kernel as CUDA Kernel
    participant CUDA_ATOM as src/tl_templates/cuda/atomic.h
    participant HW as CUDA Hardware
    Kernel->>CUDA_ATOM: AtomicAdd<T1,T2>(addr, val, memory_order)
    alt half/bfloat16/vectorized (arch-guarded)
        CUDA_ATOM->>CUDA_ATOM: choose specialized path (reinterpret/intrinsic)
    else general types
        CUDA_ATOM->>CUDA_ATOM: use cuda::atomic_ref with memory_order
    end
    CUDA_ATOM->>HW: issue atomic operation
    HW-->>CUDA_ATOM: optional old value
    CUDA_ATOM-->>Kernel: void or previous value
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the atomic operation capabilities within the TileLang framework, primarily by adding support for atomic operations that return the previous value. This functionality is implemented in new CUDA device functions and exposed through a refactored and centralized Python API. The changes improve the expressiveness and utility of atomic operations for complex parallel programming patterns, while also enhancing code organization and ensuring correct memory access handling for local variables.
Code Review
This pull request introduces support for atomic operations that return the previous value, such as atomic_add_ret. The changes are well-structured, with the C++ CUDA implementations nicely refactored into a new atomic.h header and the Python bindings into a new atomic.py module.
The core logic for the new ...Ret functions in both C++ and Python seems correct. However, I've found a couple of critical issues in the Python wrappers for vectorized atomics (atomic_addx2 and atomic_addx4) concerning incorrect return type handling when return_prev is set to True. These could lead to incorrect behavior or runtime errors. Please see my detailed comments.
```python
        >>> atomic_addx2(global_grads[i, j:j+2], grads[i, j:j+2])
    """
    func_name = "AtomicAddx2Ret" if return_prev else "AtomicAddx2"
    return_type = dst.dtype if return_prev else "handle"
```
The return_type is incorrect when return_prev=True. dst.dtype provides the scalar element type (e.g., "float16"), but the underlying AtomicAddx2Ret C++ function returns a vector type (e.g., half2). The return_type should be a vector type string, which can be constructed by appending "x2" to the scalar type.
```diff
-    return_type = dst.dtype if return_prev else "handle"
+    return_type = dst.dtype + "x2" if return_prev else "handle"
```
```python
        >>> atomic_addx4(rgba_dst, rgba_add)  # Atomic blend of all 4 channels
    """
    func_name = "AtomicAddx4Ret" if return_prev else "AtomicAddx4"
    return_type = "float4" if "float" in str(dst.dtype).lower() else "handle"
```
The logic for determining `return_type` is incorrect for two reasons:

- It does not depend on `return_prev`. When `return_prev=False`, `AtomicAddx4` is called, which returns `void`, but `return_type` could still be incorrectly set to `"float4"`. It should be `"handle"`.
- When `return_prev=True` for a non-float `dst.dtype`, it returns a `"handle"` instead of the previous value. Since `atomic_addx4` is only implemented for `float` in C++, this case should probably raise an error instead of silently returning a wrong type.

The suggestion below fixes the first issue. For the second issue, you might consider raising a TypeError.
| return_type = "float4" if "float" in str(dst.dtype).lower() else "handle" | |
| return_type = "float4" if return_prev and "float" in str(dst.dtype).lower() else "handle" |
Actionable comments posted: 8
🧹 Nitpick comments (4)
tilelang/language/customize.py (1)
6-6: Drop unused noqa and explicitly re-export via `__all__`

Ruff flags the `# noqa: F401` as unused. Make the re-exports explicit and remove the noqa to keep linters happy.

Apply:

```diff
-from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store  # noqa: F401
+from .atomic import (
+    atomic_max,
+    atomic_min,
+    atomic_add,
+    atomic_addx2,
+    atomic_addx4,
+    atomic_load,
+    atomic_store,
+)
+
+__all__ = [
+    "atomic_max",
+    "atomic_min",
+    "atomic_add",
+    "atomic_addx2",
+    "atomic_addx4",
+    "atomic_load",
+    "atomic_store",
+]
```

tilelang/language/atomic.py (1)
216-225: Behavior note: return_prev unsupported for tile-region path

Throwing NotImplementedError is fine for now. Consider documenting at call sites to avoid surprises.
- Add a short note in the docstring of atomic_add near the tile-region description indicating return_prev is only supported for scalar/addressed path.
src/tl_templates/cuda/atomic.h (1)
33-41: bfloat16/half intrinsics headers
`__float2half` and `__float2bfloat16` require proper headers. Relying on cuda_runtime.h transitively may be brittle.

Include explicit headers (behind the RTC guard as needed):

```diff
 #ifndef __CUDACC_RTC__
 #include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#if defined(__CUDA_ARCH__) || defined(__CUDA_ARCH_LIST__)
+#include <cuda_bf16.h>
+#endif
 #endif
```

Please confirm CI covers both NVCC and NVRTC builds across supported CUDA versions.
src/transform/legalize_safe_memory_access.cc (1)
238-240: Treating "local.var" as local — LGTM; audit other scope branches

The IsLocalBuffer change is correct. Repo search shows many direct comparisons to "local" or "local.fragment" that won't match "local.var" — update those checks or use IsLocalBuffer where appropriate.
Files to audit (non‑exhaustive): src/transform/loop_partition.cc, src/transform/layout_inference.cc, src/transform/common/loop_parallel_transform_utils.h, src/op/parallel.cc, src/op/reduce.cc, src/op/gemm.cc, src/op/gemm_sp.cc, src/op/fill.cc, src/op/copy.cc, src/op/atomic_add.cc
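For illustration only, this is roughly the kind of shared scope check the comment has in mind; the helper name and placement below are assumptions, not the actual IsLocalBuffer implementation:

```cpp
#include <string>

// Sketch of a scope predicate that accepts every "local" flavour, so callers
// don't hand-compare against "local"/"local.fragment" and silently miss
// "local.var".
static bool IsLocalLikeScope(const std::string &scope) {
  return scope == "local" || scope == "local.fragment" || scope == "local.var";
}
```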
📜 Review details
📒 Files selected for processing (5)
- src/tl_templates/cuda/atomic.h (1 hunks)
- src/tl_templates/cuda/common.h (1 hunks)
- src/transform/legalize_safe_memory_access.cc (1 hunks)
- tilelang/language/atomic.py (1 hunks)
- tilelang/language/customize.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/legalize_safe_memory_access.cc (1)
- src/transform/storage_rewrite.cc (2): scope (674-678), scope (674-674)

tilelang/language/atomic.py (3)
- tilelang/language/tir/op.py (3): call_extern (172-194), address_of (463-479), call_intrin (119-144)
- tilelang/language/frame.py (2): has_let_value (189-198), get_let_value (201-210)
- tilelang/language/customize.py (4): region (9-28), buffer_to_tile_region (31-43), buffer_region_to_tile_region (71-93), buffer_load_to_tile_region (46-68)

tilelang/language/customize.py (2)
- src/op/atomic_add.h (1): tvm (12-84)
- tilelang/language/atomic.py (7): atomic_max (20-64), atomic_min (67-113), atomic_add (116-225), atomic_addx2 (228-263), atomic_addx4 (266-301), atomic_load (304-341), atomic_store (344-395)
🪛 Ruff (0.13.1)
tilelang/language/customize.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: format-check
- GitHub Check: format-check
🔇 Additional comments (3)
tilelang/language/customize.py (1)
4-6: Re-export move LGTM

Switching to re-exporting from tilelang.language.atomic keeps the public API stable and simplifies this module.
src/tl_templates/cuda/atomic.h (1)
23-27: Incorrect — `__CUDA_ARCH_LIST__` is an nvcc predefined macro

nvcc defines `__CUDA_ARCH_LIST__` (a comma-separated list of target `__CUDA_ARCH__` values); the repo already uses it across src/tl_templates/cuda/*. Keep the checks as-is — only add a fallback to `__CUDA_ARCH__` or provide a compatibility define if you need non-nvcc compiler support.
Likely an incorrect or invalid review comment.
tilelang/language/atomic.py (1)
10-17: Memory-order mapping is correct — no change required

cuda::memory_order uses the C/C++ memory-order enum with integer values 0..5 for relaxed→seq_cst, matching _MEMORY_ORDER_ID_MAP; src/tl_templates/cuda/atomic.h passes ints to cuda::memory_order, so the hard-coded mapping is consistent at runtime.
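To make the mapping concrete, here is a minimal sketch of the pattern the header relies on: an integer id coming from the Python side is cast straight to cuda::memory_order on device. The function name below is illustrative and not part of atomic.h:

```cpp
#include <cuda/atomic>

// 0=relaxed, 1=consume, 2=acquire, 3=release, 4=acq_rel, 5=seq_cst,
// matching the C++ std::memory_order enumeration used by libcu++.
template <typename T>
__device__ __forceinline__ T fetch_add_with_order(T *address, T value,
                                                  int memory_order) {
  cuda::atomic_ref<T, cuda::thread_scope_device> aref(*address);
  return aref.fetch_add(value, cuda::memory_order(memory_order));
}
```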
```cpp
#include <cuda/atomic>
#include <cutlass/numeric_types.h>
```
Missing include for std::is_same_v
You use std::is_same_v but don’t include <type_traits>.
Apply:
```diff
 #ifndef __CUDACC_RTC__
 #include <cuda_runtime.h>
 #endif
 #include <cuda/atomic>
 #include <cutlass/numeric_types.h>
+#include <type_traits>
```
🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 7 to 9, the file uses
std::is_same_v but doesn't include <type_traits>; add an #include for
<type_traits> near the existing includes (e.g., alongside <cuda/atomic> and
<cutlass/numeric_types.h>) so the template trait is available and the code
compiles.
```cpp
#define TL_DEVICE __forceinline__ __device__
```
Guard TL_DEVICE to avoid redefinition with common.h
atomic.h defines TL_DEVICE, and common.h defines it too. Add a guard here.
```diff
-#define TL_DEVICE __forceinline__ __device__
+#ifndef TL_DEVICE
+#define TL_DEVICE __forceinline__ __device__
+#endif
```
🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 13-14, TL_DEVICE is
unconditionally defined which conflicts with common.h; wrap the definition in a
macro guard so it is only defined if not already present (use #ifndef TL_DEVICE
... #define TL_DEVICE __forceinline__ __device__ ... #endif) to prevent
redefinition collisions.
src/tl_templates/cuda/atomic.h (Outdated)
```cpp
TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) {
  atomicAdd(reinterpret_cast<half2 *>(address),
            static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}

TL_DEVICE half2 AtomicAddx2Ret(half_t *address, half_t *val) {
  return atomicAdd(reinterpret_cast<half2 *>(address),
                   static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) {
  atomicAdd(
      reinterpret_cast<__nv_bfloat162 *>(address),
      static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}

TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *address, bfloat16_t *val) {
  return atomicAdd(
      reinterpret_cast<__nv_bfloat162 *>(address),
      static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}
#endif

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
TL_DEVICE void AtomicAddx2(float *address, float *val) {
  atomicAdd(reinterpret_cast<float2 *>(address),
            static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}

TL_DEVICE float2 AtomicAddx2Ret(float *address, float *val) {
  return atomicAdd(reinterpret_cast<float2 *>(address),
                   static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}

TL_DEVICE void AtomicAddx4(float *address, float *val) {
  atomicAdd(reinterpret_cast<float4 *>(address),
            static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}

TL_DEVICE float4 AtomicAddx4Ret(float *address, float *val) {
  return atomicAdd(reinterpret_cast<float4 *>(address),
                   static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}
#endif
```
Vector atomic add overloads: alignment/aliasing assumptions
The reinterpret_casts assume address alignment to half2/__nv_bfloat162/float2/float4. If callers pass misaligned pointers (e.g., odd index), behavior is undefined.
- Document alignment requirements.
- Optionally assert alignment in debug builds:
```cpp
assert(reinterpret_cast<uintptr_t>(address) % alignof(half2) == 0);
```
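A hedged sketch of how such a check could be packaged as a debug-only helper; the TL_ASSERT_ALIGNED macro name is hypothetical and not part of the existing header:

```cpp
#include <cassert>
#include <cstdint>

// Compiles away in release builds; in debug builds it catches callers that
// pass a pointer not aligned to the vector type being reinterpreted.
#ifndef NDEBUG
#define TL_ASSERT_ALIGNED(ptr, VecT)                                          \
  assert(reinterpret_cast<uintptr_t>(ptr) % alignof(VecT) == 0)
#else
#define TL_ASSERT_ALIGNED(ptr, VecT) ((void)0)
#endif

// Example use at the top of AtomicAddx2(half_t *address, half_t *val):
//   TL_ASSERT_ALIGNED(address, half2);
```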
src/tl_templates/cuda/atomic.h (Outdated)
```cpp
template <typename T> TL_DEVICE T AtomicLoad(T *address, int memory_order) {
  cuda::atomic_ref<T, cuda::thread_scope_device> aref(*address);
  return aref.load(cuda::memory_order(memory_order));
}

template <typename T1, typename T2>
TL_DEVICE void AtomicStore(T1 *address, T2 value, int memory_order) {
  using NT1 = typename normalize_atomic_type<T1>::type;
  cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
  aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
}
```
(No newline at end of file.)
🧩 Analysis chain
AtomicLoad/Store for half/bfloat16 types
Using cuda::atomic_ref<T> with T = half/__nv_bfloat16 may not be supported across all CUDA versions. You special-case half/bf16 elsewhere; consider mirroring here or documenting minimum CUDA version.
- Add static_asserts or SFINAE constraints for supported types.
- Alternatively, normalize to an integer storage type for load/store and cast.
🏁 Script executed:
```bash
#!/bin/bash
# Search for any prior handling patterns for half/bfloat atomics.
rg -nP 'atomic_ref<.*half|bfloat' -C2 src
```

Length of output: 39589
🏁 Script executed:
```bash
#!/bin/bash
set -euo pipefail
sed -n '1,260p' src/tl_templates/cuda/atomic.h
```

Length of output: 6616
Special-case half/__nv_bfloat16 in AtomicLoad/AtomicStore (src/tl_templates/cuda/atomic.h:173-183)
AtomicLoad currently instantiates cuda::atomic_ref with T not normalized; cuda::atomic_ref for half/__nv_bfloat16 is not supported across all CUDA versions — this file already special-cases half/__nv_bfloat16 for other atomics; mirror that here.
- Use normalize_atomic_type and if constexpr to route half/__nv_bfloat16 to the intrinsic-based path (or provide dedicated overloads) instead of cuda::atomic_ref.
- Add static_asserts/SFINAE to prevent instantiation on unsupported CUDA versions or document the minimum CUDA version required.
🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 173 to 183,
AtomicLoad/AtomicStore currently instantiate cuda::atomic_ref with the raw T
which breaks for half and __nv_bfloat16 on some CUDA versions; change these
functions to use normalize_atomic_type to obtain the normalized underlying type
for atomic_ref and add an if constexpr branch (or dedicated overloads) that
detects half/__nv_bfloat16 and routes those to the existing intrinsic-based path
instead of using cuda::atomic_ref; also add a static_assert or SFINAE guard that
prevents instantiation when the CUDA toolchain lacks support (or document the
minimum CUDA version) so consumers get a clear compile-time error.
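As a rough illustration of the static_assert/SFINAE option suggested above (rather than the full intrinsic-based routing), the sketch below is an assumption about how such a guard could look; it is not the actual patch, and the exact set of supported types depends on the CUDA toolchain:

```cpp
#include <cuda/atomic>
#include <type_traits>

// Guarded variant: refuse to instantiate cuda::atomic_ref for sub-word types
// (e.g. __half / __nv_bfloat16), which some CUDA versions do not support,
// and point users at the dedicated fp16/bf16 paths instead.
template <typename T>
__device__ __forceinline__ T AtomicLoadGuarded(T *address, int memory_order) {
  static_assert(sizeof(T) >= 4,
                "AtomicLoad: sub-word types need a specialized fp16/bf16 path");
  cuda::atomic_ref<T, cuda::thread_scope_device> aref(*address);
  return aref.load(cuda::memory_order(memory_order));
}
```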
```cpp
#endif

#include <cuda/atomic>
#include "atomic.h"
```
Macro redefinition risk: TL_DEVICE defined in both atomic.h and common.h
atomic.h defines TL_DEVICE, and common.h defines it again (Line 30). This can cause redefinition warnings/errors.
Guard the definition in atomic.h:
```diff
-#define TL_DEVICE __forceinline__ __device__
+#ifndef TL_DEVICE
+#define TL_DEVICE __forceinline__ __device__
+#endif
```

(See atomic.h comment for precise patch.)
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h (include at line 7) atomic.h already defines
TL_DEVICE and common.h redeclares it again around line 30, causing redefinition
warnings; fix this by updating atomic.h to guard its TL_DEVICE definition with
an include-guard style conditional (e.g. wrap the TL_DEVICE define in atomic.h
with #ifndef TL_DEVICE / #define TL_DEVICE ... #endif) so repeated includes
don’t redefine it, and remove or stop redeclaring TL_DEVICE in common.h if
present.
```python
    assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
    src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
    dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
    extent = max(src_extent, dst_extent)
```
Bug: incorrect extent computation (uses Python list max, not element-wise)
extent = max(src_extent, dst_extent) performs lexicographic list comparison, not per-dimension max or selection. This can produce wrong shapes and mis-legalized tile regions.
Minimal and safe fix: prefer the deduced dst extent when available, else src extent.
```diff
-    assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
-    src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
-    dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
-    extent = max(src_extent, dst_extent)
+    assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
+    src_extent = list(src_extent) if src_extent is not None else [1] * len(dst_extent)
+    dst_extent = list(dst_extent) if dst_extent is not None else [1] * len(src_extent)
+    # Prefer destination extents when both are present; otherwise, take the one we have.
+    extent = dst_extent if dst_extent is not None else src_extent
```
🤖 Prompt for AI Agents
In tilelang/language/atomic.py around lines 199 to 203, the code uses extent =
max(src_extent, dst_extent) which does a lexicographic list comparison instead
of choosing the intended per-dimension or preferred extent; replace that with a
deterministic selection: after normalizing src_extent and dst_extent to lists,
set extent to dst_extent if it was provided (non-empty), otherwise set extent to
src_extent. Ensure extent is a list (already normalized above) and remove the
use of Python's max on lists.
- Introduced a new configuration option `tl.disable_thread_storage_sync` to control the automatic insertion of thread synchronization barriers in shared memory access.
- Updated the `ThreadSync` pass to check this configuration and bypass synchronization if disabled.
- Enhanced documentation in `builtin.h` and `pass_config.py` to clarify the purpose and usage of the new option.

- Simplified the retrieval of the thread storage sync configuration in the `ThreadSync` pass by removing unnecessary intermediate variables.
- Ensured that the inclusion of `builtin.h` is consistent by moving it to the appropriate location in the file.

- Updated atomic operations in CUDA templates to remove unnecessary address_of calls, enhancing performance and readability.
- Refactored atomic operation signatures in tilelang's atomic.py to accept references instead of pointers.
- Added new atomic operations and corresponding test cases for atomic add, max, min, and load/store functionalities in the testing suite.
- Updated the TVM subproject to the latest commit for better compatibility.

- Modified the `heads` parameter in both `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py` and `example_mha_sink_fwd_bhsd_wgmma_pipelined.py` from 1 to 32 to enhance performance in attention mechanisms.
- Ensured consistency across example scripts for improved usability and testing.
Actionable comments posted: 0
🧹 Nitpick comments (2)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
357-361: Align CLI defaults with the updated function defaults.
`main()` now defaults to batch=1 and seq lengths of 256, yet the argparse defaults (see Lines 423-427) still advertise the old 8/4096 configuration. Unless the mismatch is intentional, please update the parser defaults — or add a brief note — so CLI users aren't surprised by running a much heavier setup than the in-module defaults imply.

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)
367-372: Keep the CLI defaults in sync with the new signature defaults.
With `main()` now defaulting to heads=32 and seq_q/seq_kv=256, the argparse defaults at Lines 438-444 still point to the previous 64/4096 values. Please bring the parser defaults in line (or document the difference) so command-line runs match the lighter configuration you just introduced.
📜 Review details
📒 Files selected for processing (3)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
- src/op/builtin.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
main (133-189), main (367-435)

examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
main (121-171), main (232-284)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: format-check
- GitHub Check: bot-task
🔇 Additional comments (1)
src/op/builtin.cc (1)
19-33: New pass config flag registration looks good.

The Bool option matches the new constant and downstream usage, keeping pass behavior opt-in without regressions.
- Simplified the extraction of buffer loads for atomic add operations by removing unnecessary address_of calls, improving code clarity and performance.
- Updated the data type retrieval for vectorization size calculation to directly access the buffer load node, enhancing efficiency.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/transform/atomicadd_vectorize.cc (2)
53-75: Off‑by‑one args.size() check and lack of backward‑compat parsing (BufferLoad vs address_of) can crash or miss vectorization.
- node->args.size() is checked for >= 2 but args[2] is accessed → out‑of‑bounds when only 2 args exist.
- Only direct BufferLoad is handled; older IR that wrapped arguments in address_of(BufferLoad) will be ignored, reducing vectorization opportunities.
Fix both by requiring >= 3 and extracting BufferLoad from either direct or address_of forms.
Apply this diff within this block:
```diff
-  if (node->op == builtin::call_extern() && node->args.size() >= 2) {
+  if (node->op == builtin::call_extern() && node->args.size() >= 3) {
     if (const auto *func_name = node->args[0].as<StringImmNode>()) {
       if (func_name->value == "AtomicAdd") {
-        const BufferLoadNode *buffer_load_dst =
-            node->args[1].as<BufferLoadNode>();
-        const BufferLoadNode *buffer_load_src =
-            node->args[2].as<BufferLoadNode>();
+        auto extract_bl = [&](const PrimExpr &arg) -> const BufferLoadNode * {
+          if (const auto *bl = arg.as<BufferLoadNode>()) return bl;
+          if (const auto *c = arg.as<CallNode>()) {
+            if (c->op == builtin::address_of() && c->args.size() == 1) {
+              return c->args[0].as<BufferLoadNode>();
+            }
+          }
+          return nullptr;
+        };
+        const BufferLoadNode *buffer_load_dst = extract_bl(node->args[1]);
+        const BufferLoadNode *buffer_load_src = extract_bl(node->args[2]);
         if (buffer_load_src && buffer_load_src->buffer.defined() &&
             buffer_load_dst && buffer_load_dst->buffer.defined()) {
           Buffer dst_buffer = buffer_load_dst->buffer;
           Array<PrimExpr> indices_dst = buffer_load_dst->indices;
           UpdateVectorSize(indices_dst, dst_buffer);
           Buffer src_buffer = buffer_load_src->buffer;
           Array<PrimExpr> indices_src = buffer_load_src->indices;
           UpdateVectorSize(indices_src, src_buffer);
         }
       }
     }
   }
```
208-222: Same args.size() bug and parsing gap in the rewriter; vectorization silently bails.
- Uses args[2] with only >= 2 guard.
- Only accepts direct BufferLoad; older address_of(BufferLoad) forms won’t match and vectorization won’t apply.
Harden the check and support both forms.
Apply this diff:
```diff
-  if (node->op == builtin::call_extern() && node->args.size() >= 2) {
+  if (node->op == builtin::call_extern() && node->args.size() >= 3) {
     if (const auto *func_name = node->args[0].as<StringImmNode>()) {
       if (func_name->value == "AtomicAdd") {
@@
-        const BufferLoadNode *old_dst_node =
-            node->args[1].as<BufferLoadNode>();
-        const BufferLoadNode *old_value_node =
-            node->args[2].as<BufferLoadNode>();
+        auto extract_bl = [&](const PrimExpr &arg) -> const BufferLoadNode * {
+          if (const auto *bl = arg.as<BufferLoadNode>()) return bl;
+          if (const auto *c = arg.as<CallNode>()) {
+            if (c->op == builtin::address_of() && c->args.size() == 1) {
+              return c->args[0].as<BufferLoadNode>();
+            }
+          }
+          return nullptr;
+        };
+        const BufferLoadNode *old_dst_node = extract_bl(node->args[1]);
+        const BufferLoadNode *old_value_node = extract_bl(node->args[2]);
         if (!old_dst_node || !old_value_node) {
           return StmtExprMutator::VisitExpr_(node);
         }
```
🧹 Nitpick comments (1)
src/transform/atomicadd_vectorize.cc (1)
53-75: DRY: extract a shared helper to parse BufferLoad from AtomicAdd args.

You now duplicate BufferLoad extraction logic in planner and rewriter. Consider a small static helper to centralize support for both direct and address_of forms.
Example helper (place in an anonymous namespace near the top of this file):
```cpp
static inline const BufferLoadNode* ExtractBufferLoadArg(const PrimExpr& arg) {
  if (const auto* bl = arg.as<BufferLoadNode>()) return bl;
  if (const auto* c = arg.as<CallNode>()) {
    if (c->op == builtin::address_of() && c->args.size() == 1) {
      return c->args[0].as<BufferLoadNode>();
    }
  }
  return nullptr;
}
```

Then use ExtractBufferLoadArg(node->args[i]) in both sites.
Also applies to: 208-222
📜 Review details
📒 Files selected for processing (1)
- src/transform/atomicadd_vectorize.cc (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (1)
src/transform/loop_partition.cc (8)
node (41-48), node (41-41), node (49-56), node (49-49), node (112-121), node (112-112), node (171-179), node (171-171)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: bot-task
- GitHub Check: format-check
- GitHub Check: format-check
🔇 Additional comments (1)
src/transform/atomicadd_vectorize.cc (1)
263-276: AtomicAddx2/x4 address_of usage is correct
The CUDA headers define AtomicAddx2/x4 as

```cpp
TL_DEVICE void AtomicAddx2(half_t* ref, half_t* val);
TL_DEVICE void AtomicAddx4(float* ref, float* val);
```

so passing `address_of(BufferLoad)` for both dst and value matches the ABI, and the vectorizer emits the non-returning variants with `dtype=Handle`, consistent with single-lane atomic calls.
```diff
   const auto *func_name = call->args[0].as<StringImmNode>();
   if (func_name->value == "AtomicAdd") {
-    DataType dtype =
-        call->args[1].as<CallNode>()->args[0].as<BufferLoadNode>()->dtype;
+    DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
     vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
   }
```
Potential null deref: assumes args[1] is BufferLoad; old IR may still pass address_of(BufferLoad).
Dereferencing as()->dtype without guarding will crash when the arg is address_of(...).
Apply this diff to safely derive dtype from either form and only update when found:
```diff
-      DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
-      vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+      if (const auto *bl = call->args[1].as<BufferLoadNode>()) {
+        vectorize_size_max =
+            GetVectorizeSizeMax(compute_capability, bl->dtype);
+      } else if (const auto *addr = call->args[1].as<CallNode>()) {
+        if (addr->op == builtin::address_of() && addr->args.size() == 1) {
+          if (const auto *bl2 = addr->args[0].as<BufferLoadNode>()) {
+            vectorize_size_max =
+                GetVectorizeSizeMax(compute_capability, bl2->dtype);
+          }
+        }
+      }
```

🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 329 to 333, the code
assumes call->args[1] is a BufferLoadNode and dereferences it directly; change
this to detect and handle both BufferLoadNode and AddressOfNode(AddressOf
wrapping a BufferLoad), safely extract the inner BufferLoadNode when present,
check for null after as<> casts, and only call GetVectorizeSizeMax and assign
vectorize_size_max when a valid dtype is obtained.
- Introduced a new `loop_break` function in `customize.py` to allow breaking out of loops, returning a call to the `tl.loop_break` intrinsic.
- Updated the `sync_threads` function in `builtin.py` to accept optional parameters for `barrier_id` and `arrive_count`, improving its flexibility for thread synchronization.
- Added necessary imports in `__init__.py` to include the new `loop_break` function for broader accessibility.
Actionable comments posted: 0
🧹 Nitpick comments (4)
tilelang/language/__init__.py (1)
73-73: Remove unused noqa directive (RUF100).

The linter reports this noqa as unused. Safe to drop.

Apply this diff:

```diff
-    loop_break,  # noqa: F401
+    loop_break,
```

tilelang/language/customize.py (2)
- loop_break, # noqa: F401 + loop_break,tilelang/language/customize.py (2)
4-6: Drop unnecessary noqa on atomic re-exports (RUF100).The
# noqa: F401is reported as unused. Remove it.Apply this diff:
-from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 +from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store
154-161: Add return type annotation to loop_break for consistency.

Other helpers (e.g., dp4a) annotate return types. Aligning improves readability and tooling.

Apply this diff:

```diff
-def loop_break():
+def loop_break() -> PrimExpr:
     """Break out of the current loop.

     Returns:
         tir.Call: A call to the `tl.loop_break` intrinsic.
     """
     return T.call_intrin("handle", op.Op.get("tl.loop_break"))
```

tilelang/language/builtin.py (1)
-def loop_break(): +def loop_break() -> PrimExpr: """Break out of the current loop. Returns: tir.Call: A call to the `tl.loop_break` intrinsic. """ return T.call_intrin("handle", op.Op.get("tl.loop_break"))tilelang/language/builtin.py (1)
333-341: Refinesync_threadssignature and add guard forarrive_count
- Annotate
barrier_idandarrive_countasOptional[Union[int, PrimExpr]]and importOptional, Union.- Raise
ValueErrorifarrive_countis provided withoutbarrier_id.Apply:
-def sync_threads(barrier_id: int = None, arrive_count: int = None): +def sync_threads(barrier_id: Optional[Union[int, PrimExpr]] = None, + arrive_count: Optional[Union[int, PrimExpr]] = None): """Synchronize all threads in a block. """ - args = [] + args: list = [] + if arrive_count is not None and barrier_id is None: + raise ValueError("arrive_count requires barrier_id") if barrier_id is not None: args.append(barrier_id) if arrive_count is not None: args.append(arrive_count)And update imports:
from typing import Optional, Union, AnyVerified no existing calls pass
arrive_countalone.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
tilelang/language/__init__.py(1 hunks)tilelang/language/builtin.py(1 hunks)tilelang/language/customize.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/language/__init__.py (1)
tilelang/language/customize.py (1)
loop_break(154-160)
tilelang/language/builtin.py (2)
src/transform/inject_tma_barrier.cc (6)
barrier_id(107-109)barrier_id(107-107)barrier_id(111-115)barrier_id(111-111)barrier_id(193-202)barrier_id(193-193)tilelang/language/tir/op.py (1)
call_intrin(119-144)
tilelang/language/customize.py (3)
src/op/builtin.h (1)
tvm(13-348)tilelang/language/atomic.py (7)
atomic_max(20-63)atomic_min(66-111)atomic_add(114-223)atomic_addx2(226-261)atomic_addx4(264-299)atomic_load(302-338)atomic_store(341-391)tilelang/language/tir/op.py (1)
call_intrin(119-144)
🪛 Ruff (0.13.1)
tilelang/language/__init__.py
73-73: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/language/builtin.py
333-333: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
333-333: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
tilelang/language/customize.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-amd
as title.
Summary by CodeRabbit
New Features
Bug Fixes
Refactor
Tests
Examples
Chores