Conversation

@LeiWang1999 LeiWang1999 commented Sep 24, 2025

as title.

Summary by CodeRabbit

  • New Features

    • CUDA device-side atomic utilities (scalar & vector x2/x4), memory-ordered atomic load/store; language-level atomic APIs; new loop_break export; new pass-config flag and builtin option to disable thread-storage synchronization.
  • Bug Fixes

    • Treat "local.var" buffers as local to ensure correct boundary/padding handling for safe memory accesses.
  • Refactor

    • Atomic helpers relocated to a dedicated module; CUDA atomic implementation consolidated; extern atomic add call argument handling simplified.
  • Tests

    • Comprehensive JIT tests for atomic operations added.
  • Examples

    • Reduced default sizes in several attention example scripts.
  • Chores

    • Updated third-party submodule commit reference.

- Introduced atomic functions including AtomicMax, AtomicMin, AtomicAdd, and their return variants for various data types.
- Implemented support for half, bfloat16, and float types with appropriate memory ordering.
- Moved atomic-related utilities from common.h to the new atomic.h file for better organization.
- Added Python bindings for atomic operations in tilelang, including atomic_max, atomic_min, atomic_add, and their vectorized counterparts (a usage sketch follows this list).
- Updated customize.py to utilize the new atomic functions, enhancing modularity and maintainability.
- Reformatted atomic operation implementations in atomic.h for better code clarity.
- Adjusted function signatures in tilelang's atomic.py to enhance readability by aligning parameters.
- Cleaned up unnecessary whitespace and comments in customize.py to streamline the codebase.
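
A rough usage sketch of the new language-level API (not taken from this PR's diff): the parameter names (return_prev, memory_order) follow the summaries above, while the kernel scaffolding (T.Kernel, T.Parallel, T.Tensor) is assumed from typical TileLang examples and may differ in detail.

import tilelang.language as T

# Sketch only: kernel scaffolding assumed from typical TileLang examples.
def build_atomic_sum(N: int = 1024, block: int = 128, dtype: str = "float32"):

    @T.prim_func
    def main(A: T.Tensor((N,), dtype), Out: T.Tensor((1,), dtype),
             Prev: T.Tensor((N,), dtype)):
        with T.Kernel(T.ceildiv(N, block), threads=block) as bx:
            for i in T.Parallel(block):
                # Plain atomic add: side effect only.
                T.atomic_add(Out[0], A[bx * block + i])
                # New in this PR: also fetch the value held before the add.
                Prev[bx * block + i] = T.atomic_add(
                    Out[0], A[bx * block + i], return_prev=True)

    return main
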
coderabbitai bot commented Sep 24, 2025

Caution

Review failed

The pull request is closed.

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a standalone CUDA atomic header and removes atomic helpers from common.h; introduces a TileLang Python atomic module (re-exported from customize.py); broadens local buffer scope detection; adds a pass flag to opt out of thread-storage synchronization; relaxes AtomicAdd argument/validation paths; adds atomic tests; updates example defaults and a submodule pointer.

Changes

Cohort / File(s) Summary of changes
CUDA atomics modularization
src/tl_templates/cuda/atomic.h, src/tl_templates/cuda/common.h
New atomic.h implements type normalization, cuda_cast, templated device atomic ops (AtomicMax/Min/Add and Ret variants), vectorized x2/x4 adds, and AtomicLoad/AtomicStore with arch-gated specializations; common.h now includes "atomic.h" and removed its prior atomic helper implementations.
Tile language atomic API split
tilelang/language/atomic.py, tilelang/language/customize.py, tilelang/language/__init__.py
New tilelang/language/atomic.py implements atomic_max/min/add, atomic_addx2/x4, atomic_load/store, and memory-order mapping; customize.py now imports/re-exports those functions and adds loop_break(); __init__ exports loop_break.
Safe-memory local scope tweak
src/transform/legalize_safe_memory_access.cc
SafeMemorysRewriter::IsLocalBuffer treats scope "local.var" as local for boundary/padding handling.
Thread storage sync opt-out
src/op/builtin.h, src/op/builtin.cc, src/transform/thread_storage_sync.cc, tilelang/transform/pass_config.py
Added kDisableThreadStorageSync constant and TL_DISABLE_THREAD_STORAGE_SYNC pass-config key; thread_storage_sync pass reads the flag and early-returns (skipping TileLangThreadSync) when set; public pass-config option registered.
AtomicAdd extern/arg & vectorize relaxation
src/op/atomic_add.cc, src/transform/atomicadd_vectorize.cc
MakeSIMTLoop now passes dst value directly to extern (removed address_of(dst)); vectorize planner/rewriter relax validation that previously required address_of(...) wrappers and now accepts direct BufferLoad args; dtype extraction simplified.
Tests for atomic ops
testing/python/language/test_tilelang_language_atomic_add.py
New JIT-wrapped kernels, runners, and tests covering atomic_add, atomic_max, atomic_min, atomic_load/store, memory-order behavior, x2/x4 adds, return-prev variants, and comprehensive verification.
Examples defaults reduced
examples/attention_sink/*.py
Reduced default args for several attention example entrypoints (batch, heads, seq_q, seq_kv).
Submodule update
3rdparty/tvm
Updated pinned commit hash for the submodule.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant TL_PY as tilelang.language.atomic
  participant TilePath as Tile-Region Path
  participant Extern as Extern Ops (C++)

  User->>TL_PY: atomic_add(dst, value, memory_order?, return_prev?)
  alt extents unknown / extern path
    TL_PY->>Extern: AtomicAdd / AtomicAddRet(dst, value, mo_id)
    Extern-->>TL_PY: (prev/new) value
  else extents known (tile-region path)
    TL_PY->>TilePath: Convert Buffer/Region → tile-region add
    Note right of TilePath: return_prev unsupported
    TilePath-->>TL_PY: side-effect only
  end
  TL_PY-->>User: Result or expression
sequenceDiagram
  autonumber
  participant Kernel as CUDA Kernel
  participant CUDA_ATOM as src/tl_templates/cuda/atomic.h
  participant HW as CUDA Hardware

  Kernel->>CUDA_ATOM: AtomicAdd<T1,T2>(addr, val, memory_order)
  alt half/bfloat16/vectorized (arch-guarded)
    CUDA_ATOM->>CUDA_ATOM: choose specialized path (reinterpret/intrinsic)
  else general types
    CUDA_ATOM->>CUDA_ATOM: use cuda::atomic_ref with memory_order
  end
  CUDA_ATOM->>HW: issue atomic operation
  HW-->>CUDA_ATOM: optional old value
  CUDA_ATOM-->>Kernel: void or previous value

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • tzj-fxz

Poem

"I nibble bytes where atoms convene,
half and bfloat dance in CUDA's sheen.
Externs hum, kernels hop with glee,
local.var snug, barriers optional — whee! 🐇✨"

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title “[Language] Support atomic add with ret” directly reflects the primary objective of adding atomic add operations that return a value in the language and concisely conveys the main change without extraneous details.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6495f76 and 06f8e3b.

📒 Files selected for processing (1)
  • testing/python/language/test_tilelang_language_atomic_add.py (3 hunks)


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@gemini-code-assist

Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the atomic operation capabilities within the TileLang framework, primarily by adding support for atomic operations that return the previous value. This functionality is implemented in new CUDA device functions and exposed through a refactored and centralized Python API. The changes improve the expressiveness and utility of atomic operations for complex parallel programming patterns, while also enhancing code organization and ensuring correct memory access handling for local variables.

Highlights

  • Return-Value Atomic Operations: Introduced CUDA device functions (AtomicMaxRet, AtomicMinRet, AtomicAddRet, AtomicAddx2Ret, AtomicAddx4Ret) that return the value of the memory location before the atomic operation was performed.
  • Python API for Atomic Return Values: The tilelang.language.atomic module now exposes these return-value atomic operations via a return_prev parameter in functions like atomic_max, atomic_min, atomic_add, atomic_addx2, and atomic_addx4 (see the sketch after these highlights).
  • CUDA Atomic Utilities Refactoring: Moved all CUDA atomic device functions and helper types from src/tl_templates/cuda/common.h to a new, dedicated header src/tl_templates/cuda/atomic.h for better modularity.
  • Python Atomic API Refactoring: Consolidated Python wrappers for atomic operations into a new tilelang/language/atomic.py module, importing them into tilelang/language/customize.py.
  • Expanded Local Buffer Recognition: Updated the IsLocalBuffer check to include local.var scope, ensuring correct memory access legalization for such buffers.
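
To make the return_prev plumbing concrete, a minimal dispatch sketch follows. The extern names and argument order come from the walkthrough above; the helper name and exact call shape are illustrative, not the actual atomic.py code.

import tilelang.language as T

def _atomic_add_extern(dst, value, mo_id: int, return_prev: bool = False):
    # Hypothetical helper: pick the returning or void extern symbol.
    func_name = "AtomicAddRet" if return_prev else "AtomicAdd"
    # The returning variant yields the element type; the void variant a handle.
    return_type = dst.dtype if return_prev else "handle"
    return T.call_extern(return_type, func_name, dst, value, mo_id)
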

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for atomic operations that return the previous value, such as atomic_add_ret. The changes are well-structured, with the C++ CUDA implementations nicely refactored into a new atomic.h header and the Python bindings into a new atomic.py module.

The core logic for the new ...Ret functions in both C++ and Python seems correct. However, I've found a couple of critical issues in the Python wrappers for vectorized atomics (atomic_addx2 and atomic_addx4) concerning incorrect return type handling when return_prev is set to True. These could lead to incorrect behavior or runtime errors. Please see my detailed comments.

    >>> atomic_addx2(global_grads[i, j:j+2], grads[i, j:j+2])
    """
    func_name = "AtomicAddx2Ret" if return_prev else "AtomicAddx2"
    return_type = dst.dtype if return_prev else "handle"

critical

The return_type is incorrect when return_prev=True. dst.dtype provides the scalar element type (e.g., "float16"), but the underlying AtomicAddx2Ret C++ function returns a vector type (e.g., half2). The return_type should be a vector type string, which can be constructed by appending "x2" to the scalar type.

Suggested change
-return_type = dst.dtype if return_prev else "handle"
+return_type = dst.dtype + "x2" if return_prev else "handle"

    >>> atomic_addx4(rgba_dst, rgba_add) # Atomic blend of all 4 channels
    """
    func_name = "AtomicAddx4Ret" if return_prev else "AtomicAddx4"
    return_type = "float4" if "float" in str(dst.dtype).lower() else "handle"

critical

The logic for determining return_type is incorrect for two reasons:

  1. It does not depend on return_prev. When return_prev=False, AtomicAddx4 is called which returns void, but return_type could be incorrectly set to "float4". It should be "handle".
  2. When return_prev=True for a non-float dst.dtype, it returns a "handle" instead of the previous value. Since atomic_addx4 is only implemented for float in C++, this case should probably raise an error instead of silently returning a wrong type.

The suggestion below fixes the first issue. For the second issue, you might consider raising a TypeError.

Suggested change
return_type = "float4" if "float" in str(dst.dtype).lower() else "handle"
return_type = "float4" if return_prev and "float" in str(dst.dtype).lower() else "handle"

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 8

🧹 Nitpick comments (4)
tilelang/language/customize.py (1)

6-6: Drop unused noqa and explicitly re-export via all

Ruff flags the # noqa: F401 as unused. Make the re-exports explicit and remove the noqa to keep linters happy.

Apply:

-from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store  # noqa: F401
+from .atomic import (
+    atomic_max,
+    atomic_min,
+    atomic_add,
+    atomic_addx2,
+    atomic_addx4,
+    atomic_load,
+    atomic_store,
+)
+
+__all__ = [
+    "atomic_max",
+    "atomic_min",
+    "atomic_add",
+    "atomic_addx2",
+    "atomic_addx4",
+    "atomic_load",
+    "atomic_store",
+]
tilelang/language/atomic.py (1)

216-225: Behavior note: return_prev unsupported for tile-region path

Throwing NotImplementedError is fine for now. Consider documenting at call sites to avoid surprises.

  • Add a short note in the docstring of atomic_add near the tile-region description indicating return_prev is only supported for scalar/addressed path.
src/tl_templates/cuda/atomic.h (1)

33-41: bfloat16/half intrinsics headers

__float2half and __float2bfloat16 require proper headers. Relying on cuda_runtime.h transitively may be brittle.

Include explicit headers (behind the RTC guard as needed):

 #ifndef __CUDACC_RTC__
 #include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#if defined(__CUDA_ARCH__) || defined(__CUDA_ARCH_LIST__)
+#include <cuda_bfloat16.h>
+#endif
 #endif

Please confirm CI covers both NVCC and NVRTC builds across supported CUDA versions.

src/transform/legalize_safe_memory_access.cc (1)

238-240: Treating "local.var" as local — LGTM; audit other scope branches

IsLocalBuffer change is correct. Repo search shows many direct comparisons to "local" or "local.fragment" that won't match "local.var" — update those checks or use IsLocalBuffer where appropriate.

Files to audit (non‑exhaustive): src/transform/loop_partition.cc, src/transform/layout_inference.cc, src/transform/common/loop_parallel_transform_utils.h, src/op/parallel.cc, src/op/reduce.cc, src/op/gemm.cc, src/op/gemm_sp.cc, src/op/fill.cc, src/op/copy.cc, src/op/atomic_add.cc

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fa4fd0b and dede99f.

📒 Files selected for processing (5)
  • src/tl_templates/cuda/atomic.h (1 hunks)
  • src/tl_templates/cuda/common.h (1 hunks)
  • src/transform/legalize_safe_memory_access.cc (1 hunks)
  • tilelang/language/atomic.py (1 hunks)
  • tilelang/language/customize.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/legalize_safe_memory_access.cc (1)
src/transform/storage_rewrite.cc (2)
  • scope (674-678)
  • scope (674-674)
tilelang/language/atomic.py (3)
tilelang/language/tir/op.py (3)
  • call_extern (172-194)
  • address_of (463-479)
  • call_intrin (119-144)
tilelang/language/frame.py (2)
  • has_let_value (189-198)
  • get_let_value (201-210)
tilelang/language/customize.py (4)
  • region (9-28)
  • buffer_to_tile_region (31-43)
  • buffer_region_to_tile_region (71-93)
  • buffer_load_to_tile_region (46-68)
tilelang/language/customize.py (2)
src/op/atomic_add.h (1)
  • tvm (12-84)
tilelang/language/atomic.py (7)
  • atomic_max (20-64)
  • atomic_min (67-113)
  • atomic_add (116-225)
  • atomic_addx2 (228-263)
  • atomic_addx4 (266-301)
  • atomic_load (304-341)
  • atomic_store (344-395)
🪛 Ruff (0.13.1)
tilelang/language/customize.py

6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: format-check
  • GitHub Check: format-check
🔇 Additional comments (3)
tilelang/language/customize.py (1)

4-6: Re-export move LGTM

Switching to re-exporting from tilelang.language.atomic keeps the public API stable and simplifies this module.

src/tl_templates/cuda/atomic.h (1)

23-27: Incorrect — __CUDA_ARCH_LIST__ is an nvcc predefined macro

nvcc defines __CUDA_ARCH_LIST__ (a comma-separated list of target __CUDA_ARCH__ values); the repo already uses it across src/tl_templates/cuda/*. Keep the checks as-is — only add a fallback to __CUDA_ARCH__ or provide a compatibility define if you need non-nvcc compiler support.

Likely an incorrect or invalid review comment.

tilelang/language/atomic.py (1)

10-17: Memory-order mapping is correct — no change required

cuda::memory_order uses the C/C++ memory-order enum with integer values 0..5 for relaxed→seq_cst, matching _MEMORY_ORDER_ID_MAP; src/tl_templates/cuda/atomic.h passes ints to cuda::memory_order, so the hard-coded mapping is consistent at runtime.
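
For reference, a sketch of the mapping this implies (key spellings are assumed; the integer values are the standard C/C++ memory_order enumerators that cuda::memory_order(int) accepts):

# Assumed contents of _MEMORY_ORDER_ID_MAP; the values are the
# std::memory_order enumerators passed through to cuda::memory_order(int).
_MEMORY_ORDER_ID_MAP = {
    "relaxed": 0,
    "consume": 1,
    "acquire": 2,
    "release": 3,
    "acq_rel": 4,
    "seq_cst": 5,
}
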

Comment on lines +7 to +9
#include <cuda/atomic>
#include <cutlass/numeric_types.h>


⚠️ Potential issue

Missing include for std::is_same_v

You use std::is_same_v but don’t include <type_traits>.

Apply:

 #ifndef __CUDACC_RTC__
 #include <cuda_runtime.h>
 #endif
 
 #include <cuda/atomic>
 #include <cutlass/numeric_types.h>
+#include <type_traits>
📝 Committable suggestion


Suggested change
-#include <cuda/atomic>
-#include <cutlass/numeric_types.h>
+#include <cuda/atomic>
+#include <cutlass/numeric_types.h>
+#include <type_traits>
🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 7 to 9, the file uses
std::is_same_v but doesn't include <type_traits>; add an #include for
<type_traits> near the existing includes (e.g., alongside <cuda/atomic> and
<cutlass/numeric_types.h>) so the template trait is available and the code
compiles.

Comment on lines +13 to +14
#define TL_DEVICE __forceinline__ __device__


⚠️ Potential issue

Guard TL_DEVICE to avoid redefinition with common.h

atomic.h defines TL_DEVICE, and common.h defines it too. Add a guard here.

-#define TL_DEVICE __forceinline__ __device__
+#ifndef TL_DEVICE
+#define TL_DEVICE __forceinline__ __device__
+#endif
📝 Committable suggestion


Suggested change
-#define TL_DEVICE __forceinline__ __device__
+#ifndef TL_DEVICE
+#define TL_DEVICE __forceinline__ __device__
+#endif
🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 13-14, TL_DEVICE is
unconditionally defined which conflicts with common.h; wrap the definition in a
macro guard so it is only defined if not already present (use #ifndef TL_DEVICE
... #define TL_DEVICE __forceinline__ __device__ ... #endif) to prevent
redefinition collisions.

Comment on lines 127 to 171
TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) {
  atomicAdd(reinterpret_cast<half2 *>(address),
            static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}

TL_DEVICE half2 AtomicAddx2Ret(half_t *address, half_t *val) {
  return atomicAdd(reinterpret_cast<half2 *>(address),
                   static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) {
  atomicAdd(
      reinterpret_cast<__nv_bfloat162 *>(address),
      static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}

TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *address, bfloat16_t *val) {
  return atomicAdd(
      reinterpret_cast<__nv_bfloat162 *>(address),
      static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}
#endif

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
TL_DEVICE void AtomicAddx2(float *address, float *val) {
  atomicAdd(reinterpret_cast<float2 *>(address),
            static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}

TL_DEVICE float2 AtomicAddx2Ret(float *address, float *val) {
  return atomicAdd(reinterpret_cast<float2 *>(address),
                   static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}

TL_DEVICE void AtomicAddx4(float *address, float *val) {
  atomicAdd(reinterpret_cast<float4 *>(address),
            static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}

TL_DEVICE float4 AtomicAddx4Ret(float *address, float *val) {
  return atomicAdd(reinterpret_cast<float4 *>(address),
                   static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}
#endif

⚠️ Potential issue

Vector atomic add overloads: alignment/aliasing assumptions

The reinterpret_casts assume address alignment to half2/__nv_bfloat162/float2/float4. If callers pass misaligned pointers (e.g., odd index), behavior is undefined.

  • Document alignment requirements.
  • Optionally assert alignment in debug builds:
    assert(reinterpret_cast<uintptr_t>(address) % alignof(half2) == 0);

Comment on lines 173 to 183
template <typename T> TL_DEVICE T AtomicLoad(T *address, int memory_order) {
  cuda::atomic_ref<T, cuda::thread_scope_device> aref(*address);
  return aref.load(cuda::memory_order(memory_order));
}

template <typename T1, typename T2>
TL_DEVICE void AtomicStore(T1 *address, T2 value, int memory_order) {
  using NT1 = typename normalize_atomic_type<T1>::type;
  cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
  aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
}

⚠️ Potential issue

🧩 Analysis chain

AtomicLoad/Store for half/bfloat16 types

Using cuda::atomic_ref<T> with T = half/__nv_bfloat16 may not be supported across all CUDA versions. You special-case half/bf16 elsewhere; consider mirroring here or documenting minimum CUDA version.

  • Add static_asserts or SFINAE constraints for supported types.
  • Alternatively, normalize to an integer storage type for load/store and cast.

🏁 Script executed:

#!/bin/bash
# Search for any prior handling patterns for half/bfloat atomics.
rg -nP 'atomic_ref<.*half|bfloat' -C2 src

Length of output: 39589


🏁 Script executed:

#!/bin/bash
set -euo pipefail
sed -n '1,260p' src/tl_templates/cuda/atomic.h

Length of output: 6616


Special-case half/__nv_bfloat16 in AtomicLoad/AtomicStore (src/tl_templates/cuda/atomic.h:173-183)

AtomicLoad currently instantiates cuda::atomic_ref with T not normalized; cuda::atomic_ref for half/__nv_bfloat16 is not supported across all CUDA versions — this file already special-cases half/__nv_bfloat16 for other atomics; mirror that here.

  • Use normalize_atomic_type and if constexpr to route half/__nv_bfloat16 to the intrinsic-based path (or provide dedicated overloads) instead of cuda::atomic_ref.
  • Add static_asserts/SFINAE to prevent instantiation on unsupported CUDA versions or document the minimum CUDA version required.
🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 173 to 183,
AtomicLoad/AtomicStore currently instantiate cuda::atomic_ref with the raw T
which breaks for half and __nv_bfloat16 on some CUDA versions; change these
functions to use normalize_atomic_type to obtain the normalized underlying type
for atomic_ref and add an if constexpr branch (or dedicated overloads) that
detects half/__nv_bfloat16 and routes those to the existing intrinsic-based path
instead of using cuda::atomic_ref; also add a static_assert or SFINAE guard that
prevents instantiation when the CUDA toolchain lacks support (or document the
minimum CUDA version) so consumers get a clear compile-time error.

#endif

#include <cuda/atomic>
#include "atomic.h"

⚠️ Potential issue

Macro redefinition risk: TL_DEVICE defined in both atomic.h and common.h

atomic.h defines TL_DEVICE, and common.h defines it again (Line 30). This can cause redefinition warnings/errors.

Guard the definition in atomic.h:

-#define TL_DEVICE __forceinline__ __device__
+#ifndef TL_DEVICE
+#define TL_DEVICE __forceinline__ __device__
+#endif

(See atomic.h comment for precise patch.)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h (include at line 7) atomic.h already defines
TL_DEVICE and common.h redeclares it again around line 30, causing redefinition
warnings; fix this by updating atomic.h to guard its TL_DEVICE definition with
an include-guard style conditional (e.g. wrap the TL_DEVICE define in atomic.h
with #ifndef TL_DEVICE / #define TL_DEVICE ... #endif) so repeated includes
don’t redefine it, and remove or stop redeclaring TL_DEVICE in common.h if
present.

Comment on lines +199 to +203
    assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
    src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
    dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
    extent = max(src_extent, dst_extent)


⚠️ Potential issue

Bug: incorrect extent computation (uses Python list max, not element-wise)

extent = max(src_extent, dst_extent) performs lexicographic list comparison, not per-dimension max or selection. This can produce wrong shapes and mis-legalized tile regions.

Minimal and safe fix: prefer the deduced dst extent when available, else src extent.

-    assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
-    src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
-    dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
-    extent = max(src_extent, dst_extent)
+    assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
+    src_extent = list(src_extent) if src_extent is not None else [1] * len(dst_extent)
+    dst_extent = list(dst_extent) if dst_extent is not None else [1] * len(src_extent)
+    # Prefer destination extents when both are present; otherwise, take the one we have.
+    extent = dst_extent if dst_extent is not None else src_extent
📝 Committable suggestion


Suggested change
-    assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
-    src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
-    dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
-    extent = max(src_extent, dst_extent)
+    assert src_extent or dst_extent, "Can't deduce atomicadd extents from args"
+    src_extent = list(src_extent) if src_extent is not None else [1] * len(dst_extent)
+    dst_extent = list(dst_extent) if dst_extent is not None else [1] * len(src_extent)
+    # Prefer destination extents when both are present; otherwise, take the one we have.
+    extent = dst_extent if dst_extent is not None else src_extent
🤖 Prompt for AI Agents
In tilelang/language/atomic.py around lines 199 to 203, the code uses extent =
max(src_extent, dst_extent) which does a lexicographic list comparison instead
of choosing the intended per-dimension or preferred extent; replace that with a
deterministic selection: after normalizing src_extent and dst_extent to lists,
set extent to dst_extent if it was provided (non-empty), otherwise set extent to
src_extent. Ensure extent is a list (already normalized above) and remove the
use of Python's max on lists.

- Introduced a new configuration option `tl.disable_thread_storage_sync` to control the automatic insertion of thread synchronization barriers for shared memory accesses (a usage sketch follows these notes).
- Updated the `ThreadSync` pass to check this configuration and bypass synchronization if disabled.
- Enhanced documentation in `builtin.h` and `pass_config.py` to clarify the purpose and usage of the new option.
- Simplified the retrieval of the thread storage sync configuration in the `ThreadSync` pass by removing unnecessary intermediate variables.
- Ensured that the inclusion of `builtin.h` is consistent by moving it to the appropriate location in the file.
- Updated atomic operations in CUDA templates to remove unnecessary address_of calls, enhancing performance and readability.
- Refactored atomic operation signatures in tilelang's atomic.py to accept references instead of pointers.
- Added new atomic operations and corresponding test cases for atomic add, max, min, and load/store functionalities in the testing suite.
- Updated the TVM subproject to the latest commit for better compatibility.
- Modified the `heads` parameter in both `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py` and `example_mha_sink_fwd_bhsd_wgmma_pipelined.py` from 1 to 32 to enhance performance in attention mechanisms.
- Ensured consistency across example scripts for improved usability and testing.
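
A sketch of how the opt-out flag could be supplied (the config key is the one registered by this PR; driving it through a plain TVM PassContext is an assumption about the intended entry point):

import tvm

# Assumption: the TileLang lowering pipeline reads the flag from the
# prevailing PassContext; "tl.disable_thread_storage_sync" is the key
# registered in builtin.cc / pass_config.py.
with tvm.transform.PassContext(config={"tl.disable_thread_storage_sync": True}):
    # Build/lower the TileLang program here; with the flag set, the
    # thread_storage_sync pass returns early and inserts no barriers.
    pass
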
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)

357-361: Align CLI defaults with the updated function defaults.
main() now defaults to batch=1 and seq lengths of 256, yet the argparse defaults (see Lines 423‑427) still advertise the old 8/4096 configuration. Unless the mismatch is intentional, please update the parser defaults—or add a brief note—so CLI users aren’t surprised by running a much heavier setup than the in-module defaults imply.

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)

367-372: Keep the CLI defaults in sync with the new signature defaults.
With main() now defaulting to heads=32 and seq_q/seq_kv=256, the argparse defaults at Lines 438‑444 still point to the previous 64/4096 values. Please bring the parser defaults in line (or document the difference) so command-line runs match the lighter configuration you just introduced.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e11b4c4 and ac32bc1.

📒 Files selected for processing (3)
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • src/op/builtin.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
  • main (133-189)
  • main (367-435)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
  • main (121-171)
  • main (232-284)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: format-check
  • GitHub Check: bot-task
🔇 Additional comments (1)
src/op/builtin.cc (1)

19-33: New pass config flag registration looks good.

The Bool option matches the new constant and downstream usage, keeping pass behavior opt-in without regressions.

- Simplified the extraction of buffer loads for atomic add operations by removing unnecessary address_of calls, improving code clarity and performance.
- Updated the data type retrieval for vectorization size calculation to directly access the buffer load node, enhancing efficiency.
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/transform/atomicadd_vectorize.cc (2)

53-75: Off‑by‑one args.size() check and lack of backward‑compat parsing (BufferLoad vs address_of) can crash or miss vectorization.

  • node->args.size() is checked for >= 2 but args[2] is accessed → out‑of‑bounds when only 2 args exist.
  • Only direct BufferLoad is handled; older IR that wrapped arguments in address_of(BufferLoad) will be ignored, reducing vectorization opportunities.

Fix both by requiring >= 3 and extracting BufferLoad from either direct or address_of forms.

Apply this diff within this block:

-    if (node->op == builtin::call_extern() && node->args.size() >= 2) {
+    if (node->op == builtin::call_extern() && node->args.size() >= 3) {
       if (const auto *func_name = node->args[0].as<StringImmNode>()) {
         if (func_name->value == "AtomicAdd") {
-          const BufferLoadNode *buffer_load_dst =
-              node->args[1].as<BufferLoadNode>();
-          const BufferLoadNode *buffer_load_src =
-              node->args[2].as<BufferLoadNode>();
+          auto extract_bl = [&](const PrimExpr &arg) -> const BufferLoadNode * {
+            if (const auto *bl = arg.as<BufferLoadNode>()) return bl;
+            if (const auto *c = arg.as<CallNode>()) {
+              if (c->op == builtin::address_of() && c->args.size() == 1) {
+                return c->args[0].as<BufferLoadNode>();
+              }
+            }
+            return nullptr;
+          };
+          const BufferLoadNode *buffer_load_dst = extract_bl(node->args[1]);
+          const BufferLoadNode *buffer_load_src = extract_bl(node->args[2]);
           if (buffer_load_src && buffer_load_src->buffer.defined() &&
               buffer_load_dst && buffer_load_dst->buffer.defined()) {
 
             Buffer dst_buffer = buffer_load_dst->buffer;
             Array<PrimExpr> indices_dst = buffer_load_dst->indices;
             UpdateVectorSize(indices_dst, dst_buffer);
             Buffer src_buffer = buffer_load_src->buffer;
             Array<PrimExpr> indices_src = buffer_load_src->indices;
             UpdateVectorSize(indices_src, src_buffer);
           }
         }
       }
     }

208-222: Same args.size() bug and parsing gap in the rewriter; vectorization silently bails.

  • Uses args[2] with only >= 2 guard.
  • Only accepts direct BufferLoad; older address_of(BufferLoad) forms won’t match and vectorization won’t apply.

Harden the check and support both forms.

Apply this diff:

-      if (node->op == builtin::call_extern() && node->args.size() >= 2) {
+      if (node->op == builtin::call_extern() && node->args.size() >= 3) {
         if (const auto *func_name = node->args[0].as<StringImmNode>()) {
           if (func_name->value == "AtomicAdd") {
@@
-            const BufferLoadNode *old_dst_node =
-                node->args[1].as<BufferLoadNode>();
-            const BufferLoadNode *old_value_node =
-                node->args[2].as<BufferLoadNode>();
+            auto extract_bl = [&](const PrimExpr &arg) -> const BufferLoadNode * {
+              if (const auto *bl = arg.as<BufferLoadNode>()) return bl;
+              if (const auto *c = arg.as<CallNode>()) {
+                if (c->op == builtin::address_of() && c->args.size() == 1) {
+                  return c->args[0].as<BufferLoadNode>();
+                }
+              }
+              return nullptr;
+            };
+            const BufferLoadNode *old_dst_node = extract_bl(node->args[1]);
+            const BufferLoadNode *old_value_node = extract_bl(node->args[2]);
             if (!old_dst_node || !old_value_node) {
               return StmtExprMutator::VisitExpr_(node);
             }
🧹 Nitpick comments (1)
src/transform/atomicadd_vectorize.cc (1)

53-75: DRY: extract a shared helper to parse BufferLoad from AtomicAdd args.

You now duplicate BufferLoad extraction logic in planner and rewriter. Consider a small static helper to centralize support for both direct and address_of forms.

Example helper (place in an anonymous namespace near the top of this file):

static inline const BufferLoadNode* ExtractBufferLoadArg(const PrimExpr& arg) {
  if (const auto* bl = arg.as<BufferLoadNode>()) return bl;
  if (const auto* c = arg.as<CallNode>()) {
    if (c->op == builtin::address_of() && c->args.size() == 1) {
      return c->args[0].as<BufferLoadNode>();
    }
  }
  return nullptr;
}

Then use ExtractBufferLoadArg(node->args[i]) in both sites.

Also applies to: 208-222

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ac32bc1 and 68e8fed.

📒 Files selected for processing (1)
  • src/transform/atomicadd_vectorize.cc (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (1)
src/transform/loop_partition.cc (8)
  • node (41-48)
  • node (41-41)
  • node (49-56)
  • node (49-49)
  • node (112-121)
  • node (112-112)
  • node (171-179)
  • node (171-171)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: bot-task
  • GitHub Check: format-check
  • GitHub Check: format-check
🔇 Additional comments (1)
src/transform/atomicadd_vectorize.cc (1)

263-276: AtomicAddx2/x4 address_of usage is correct
The CUDA headers define AtomicAddx2/x4 as

TL_DEVICE void AtomicAddx2(half_t* ref, half_t* val);
TL_DEVICE void AtomicAddx4(float* ref, float* val);

so passing address_of(BufferLoad) for both dst and value matches the ABI, and the vectorizer emits the non-returning variants with dtype=Handle, consistent with single-lane atomic calls.

Comment on lines 329 to 333
const auto *func_name = call->args[0].as<StringImmNode>();
if (func_name->value == "AtomicAdd") {
-  DataType dtype =
-      call->args[1].as<CallNode>()->args[0].as<BufferLoadNode>()->dtype;
+  DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
  vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
}

⚠️ Potential issue | 🔴 Critical

Potential null deref: assumes args[1] is BufferLoad; old IR may still pass address_of(BufferLoad).

Dereferencing as()->dtype without guarding will crash when the arg is address_of(...).

Apply this diff to safely derive dtype from either form and only update when found:

-          DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
-          vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+          if (const auto *bl = call->args[1].as<BufferLoadNode>()) {
+            vectorize_size_max =
+                GetVectorizeSizeMax(compute_capability, bl->dtype);
+          } else if (const auto *addr = call->args[1].as<CallNode>()) {
+            if (addr->op == builtin::address_of() && addr->args.size() == 1) {
+              if (const auto *bl2 = addr->args[0].as<BufferLoadNode>()) {
+                vectorize_size_max =
+                    GetVectorizeSizeMax(compute_capability, bl2->dtype);
+              }
+            }
+          }
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 329 to 333, the code
assumes call->args[1] is a BufferLoadNode and dereferences it directly; change
this to detect and handle both BufferLoadNode and AddressOfNode(AddressOf
wrapping a BufferLoad), safely extract the inner BufferLoadNode when present,
check for null after as<> casts, and only call GetVectorizeSizeMax and assign
vectorize_size_max when a valid dtype is obtained.

- Introduced a new `loop_break` function in `customize.py` to allow breaking out of loops, returning a call to the `tl.loop_break` intrinsic (a usage sketch follows these notes).
- Updated the `sync_threads` function in `builtin.py` to accept optional parameters for `barrier_id` and `arrive_count`, improving its flexibility for thread synchronization.
- Added necessary imports in `__init__.py` to include the new `loop_break` function for broader accessibility.
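
A fragment-level sketch of the two helpers in use (the surrounding loop and the flag buffer are assumed, and availability as T.loop_break / T.sync_threads is inferred from the export notes above):

import tilelang.language as T

# Inside a kernel body; `found` is assumed to be a 1-element flag buffer.
for _k in T.serial(64):
    if found[0] != 0:
        T.loop_break()  # emits the tl.loop_break intrinsic

T.sync_threads()                                # plain block-wide barrier
T.sync_threads(barrier_id=1, arrive_count=128)  # named-barrier form added here
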
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (4)
tilelang/language/__init__.py (1)

73-73: Remove unused noqa directive (RUF100).

The linter reports this noqa as unused. Safe to drop.

Apply this diff:

-    loop_break,  # noqa: F401
+    loop_break,
tilelang/language/customize.py (2)

4-6: Drop unnecessary noqa on atomic re-exports (RUF100).

The # noqa: F401 is reported as unused. Remove it.

Apply this diff:

-from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store  # noqa: F401
+from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store

154-161: Add return type annotation to loop_break for consistency.

Other helpers (e.g., dp4a) annotate return types. Aligning improves readability and tooling.

Apply this diff:

-def loop_break():
+def loop_break() -> PrimExpr:
     """Break out of the current loop.

     Returns:
         tir.Call: A call to the `tl.loop_break` intrinsic.
     """
     return T.call_intrin("handle", op.Op.get("tl.loop_break"))
tilelang/language/builtin.py (1)

333-341: Refine sync_threads signature and add guard for arrive_count

  • Annotate barrier_id and arrive_count as Optional[Union[int, PrimExpr]] and import Optional, Union.
  • Raise ValueError if arrive_count is provided without barrier_id.

Apply:

-def sync_threads(barrier_id: int = None, arrive_count: int = None):
+def sync_threads(barrier_id: Optional[Union[int, PrimExpr]] = None,
+                 arrive_count: Optional[Union[int, PrimExpr]] = None):
     """Synchronize all threads in a block.
     """
-    args = []
+    args: list = []
+    if arrive_count is not None and barrier_id is None:
+        raise ValueError("arrive_count requires barrier_id")
     if barrier_id is not None:
         args.append(barrier_id)
     if arrive_count is not None:
         args.append(arrive_count)

And update imports:

from typing import Optional, Union, Any

Verified no existing calls pass arrive_count alone.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 68e8fed and 6495f76.

📒 Files selected for processing (3)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/builtin.py (1 hunks)
  • tilelang/language/customize.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/language/__init__.py (1)
tilelang/language/customize.py (1)
  • loop_break (154-160)
tilelang/language/builtin.py (2)
src/transform/inject_tma_barrier.cc (6)
  • barrier_id (107-109)
  • barrier_id (107-107)
  • barrier_id (111-115)
  • barrier_id (111-111)
  • barrier_id (193-202)
  • barrier_id (193-193)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
tilelang/language/customize.py (3)
src/op/builtin.h (1)
  • tvm (13-348)
tilelang/language/atomic.py (7)
  • atomic_max (20-63)
  • atomic_min (66-111)
  • atomic_add (114-223)
  • atomic_addx2 (226-261)
  • atomic_addx4 (264-299)
  • atomic_load (302-338)
  • atomic_store (341-391)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
🪛 Ruff (0.13.1)
tilelang/language/__init__.py

73-73: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/builtin.py

333-333: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


333-333: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/language/customize.py

6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-amd
