
Conversation

@tzj-fxz
Contributor

@tzj-fxz tzj-fxz commented Oct 20, 2025

Summary by CodeRabbit

  • New Features
    • Added memory-order configuration for atomic add operations, letting callers specify memory ordering semantics.
    • Vectorized and multi-element atomic add variants now accept an optional memory-order parameter.
    • Atomic add calls (including runtime/vectorized paths) propagate the memory-order value so it affects execution behavior.

@tzj-fxz tzj-fxz requested a review from LeiWang1999 October 20, 2025 14:11
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Oct 20, 2025

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 66.67%, below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
  • Title Check (❓ Inconclusive): The PR title "[BugFix] Add memory order argument for non-vectorized atomic add" covers a substantial part of the changeset but appears incomplete. The changes do add memory_order support to non-vectorized atomic add (via atomic_add.cc/h and builtin.cc), which the title describes accurately. However, the PR also makes significant changes to the vectorized atomic add paths: memory_order parameters are added to AtomicAddx2, AtomicAddx4, and related functions in cuda/atomic.h, and the vectorization pipeline in atomicadd_vectorize.cc and atomic.py is modified. By stating "non-vectorized," the title narrows the focus to only part of the changeset. Either drop the "non-vectorized" qualifier ("Add memory order argument for atomic add") or mention both paths explicitly so the title summarizes the complete set of changes for reviewers and future maintainers.
✅ Passed checks (1 passed)
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bc37ea6 and 8374859.

📒 Files selected for processing (6)
  • src/op/atomic_add.cc (2 hunks)
  • src/op/atomic_add.h (3 hunks)
  • src/op/builtin.cc (1 hunks)
  • src/tl_templates/cuda/atomic.h (3 hunks)
  • src/transform/atomicadd_vectorize.cc (2 hunks)
  • tilelang/language/atomic.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/op/atomic_add.h (4)
src/op/reduce.h (2)
  • SEqualReduce (41-43)
  • SEqualReduce (111-115)
src/op/copy.h (2)
  • SEqualReduce (117-122)
  • SEqualReduce (341-347)
src/op/gemm.h (2)
  • SEqualReduce (44-48)
  • SEqualReduce (147-160)
src/op/parallel.h (1)
  • SEqualReduce (80-84)
src/op/atomic_add.cc (1)
src/transform/atomicadd_vectorize.cc (4)
  • node (186-210)
  • node (186-186)
  • node (212-266)
  • node (212-212)
tilelang/language/atomic.py (1)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (9)
src/op/builtin.cc (1)

298-301: LGTM!

Correctly updates arity to accommodate the new memory_order parameter.

src/tl_templates/cuda/atomic.h (2)

108-116: LGTM!

Correctly restricts the fast-path to relaxed memory ordering, falling back to cuda::atomic_ref with explicit ordering for stricter semantics.


123-133: LGTM!

Correctly applies the same memory_order gating to the return-previous-value variant.

src/op/atomic_add.h (4)

25-25: LGTM!

New field follows the established pattern for atomic operation parameters.


45-46: LGTM!

Reflection API correctly exposes the new field as read-only.


49-56: LGTM!

Equality comparison correctly includes memory_order.


58-66: LGTM!

Hash computation correctly includes memory_order.

src/op/atomic_add.cc (2)

58-68: LGTM!

Constructor correctly reinterprets argument positions to accommodate memory_order, with appropriate default.


290-292: LGTM!

Correctly propagates memory_order to the atomic operation call.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
src/transform/atomicadd_vectorize.cc (1)

230-233: Complete the IntImm construction fix from previous review.

The condition >= 3 is now correct, but the default IntImm(0) still needs the explicit DataType argument as noted in the previous review.

Apply this diff:

       const IntImm memory_order =
-          node->args.size() >= 3 ? Downcast<IntImm>(node->args[2]) : IntImm(0);
+          node->args.size() >= 3 ? Downcast<IntImm>(node->args[2]) : IntImm(DataType::Int(32), 0);
🧹 Nitpick comments (2)
src/tl_templates/cuda/atomic.h (2)

108-110: Fix spacing around operator.

Missing space before && operator.

Apply this diff:

   if constexpr ((std::is_same_v<NT1, half> ||
-                 std::is_same_v<NT1, __nv_bfloat16>)&&memory_order ==
+                 std::is_same_v<NT1, __nv_bfloat16>) && memory_order ==
                 int(cuda::memory_order_relaxed)) {

123-125: Fix spacing around operator.

Missing space before && operator (same issue as lines 108-110).

Apply this diff:

   if constexpr ((std::is_same_v<NT1, half> ||
-                 std::is_same_v<NT1, __nv_bfloat16>)&&memory_order ==
+                 std::is_same_v<NT1, __nv_bfloat16>) && memory_order ==
                 int(cuda::memory_order_relaxed)) {
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8374859 and 58e8e73.

📒 Files selected for processing (3)
  • src/tl_templates/cuda/atomic.h (3 hunks)
  • src/transform/atomicadd_vectorize.cc (2 hunks)
  • tilelang/language/atomic.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/language/atomic.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (1)
src/transform/loop_vectorize.cc (16)
  • node (60-64)
  • node (60-60)
  • node (66-70)
  • node (66-66)
  • node (77-92)
  • node (77-77)
  • node (95-107)
  • node (95-95)
  • node (109-122)
  • node (109-109)
  • node (124-130)
  • node (124-124)
  • node (132-135)
  • node (132-132)
  • node (137-145)
  • node (137-137)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (1)
src/transform/atomicadd_vectorize.cc (1)

249-249: LGTM!

Correctly adds memory_order to the vectorized extern call arguments.

Comment on lines +135 to 191
// TODO add memory_order for vectorized atomic add
TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  atomicAdd(reinterpret_cast<half2 *>(ref),
            static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}

TL_DEVICE half2
AtomicAddx2Ret(half_t *ref, half_t *val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  return atomicAdd(reinterpret_cast<half2 *>(ref),
                   static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  atomicAdd(
      reinterpret_cast<__nv_bfloat162 *>(ref),
      static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}

TL_DEVICE __nv_bfloat162
AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  return atomicAdd(
      reinterpret_cast<__nv_bfloat162 *>(ref),
      static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}
#endif

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
TL_DEVICE void AtomicAddx2(float *ref, float *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  atomicAdd(reinterpret_cast<float2 *>(ref),
            static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}

TL_DEVICE float2
AtomicAddx2Ret(float *ref, float *val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  return atomicAdd(reinterpret_cast<float2 *>(ref),
                   static_cast<float2>(*reinterpret_cast<float2 *>(val)));
}

TL_DEVICE void AtomicAddx4(float *ref, float *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  atomicAdd(reinterpret_cast<float4 *>(ref),
            static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}

TL_DEVICE float4
AtomicAddx4Ret(float *ref, float *val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  return atomicAdd(reinterpret_cast<float4 *>(ref),
                   static_cast<float4>(*reinterpret_cast<float4 *>(val)));
}

⚠️ Potential issue | 🔴 Critical

Critical: Vectorized functions accept but ignore memory_order parameter.

All vectorized atomic add functions (AtomicAddx2, AtomicAddx4, and their Ret variants) now accept a memory_order parameter but the implementations don't use it - they call the underlying CUDA atomicAdd without passing the memory order. This silently ignores the caller's memory ordering requirements and can lead to concurrency bugs.

The TODO comment on line 135 acknowledges this is incomplete. CUDA's vectorized atomicAdd may not support custom memory orders, but accepting and ignoring the parameter is dangerous.

Choose one of these approaches:

  1. Remove the parameter until proper support is implemented:
-TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val,
-                           int memory_order = int(cuda::memory_order_relaxed)) {
+TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val) {
  2. Assert relaxed-only and fail for other memory orders:
 TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val,
                            int memory_order = int(cuda::memory_order_relaxed)) {
+  // Vectorized atomicAdd only supports relaxed memory order
+  assert(memory_order == int(cuda::memory_order_relaxed));
   atomicAdd(reinterpret_cast<half2 *>(ref),
             static_cast<half2>(*reinterpret_cast<half2 *>(val)));
 }
  3. Implement proper ordering by falling back to element-wise atomic ops when a non-relaxed ordering is required.

Apply the chosen approach to all 8 vectorized functions (lines 136-191).

Committable suggestion skipped: line range outside the PR's diff.

