Merged

Commits (20)
03bc5f6
[Enhancement] Introduce finalize_reducer operator and layout reducer …
Aug 25, 2025
e3c4619
Refactor code formatting and improve readability in multiple files
LeiWang1999 Aug 25, 2025
fc44d44
Remove debug print statements from `copy.cc` and `inject_tma_barrier.…
LeiWang1999 Aug 25, 2025
4713bd8
[Enhancement] Disable reuse of small arrays in shared memory allocation
LeiWang1999 Aug 25, 2025
deacc33
Merge branch 'main' of https://github.com/tile-ai/tilelang into reduc…
LeiWang1999 Aug 25, 2025
c30423c
Refactor `setup.py` to remove duplicate logging statements and enhanc…
LeiWang1999 Aug 25, 2025
4c0f978
Refactor `finalize_reducer` and `reduce` functions to remove redundan…
LeiWang1999 Aug 25, 2025
6e40711
bug fix
LeiWang1999 Aug 25, 2025
38bdc75
Add thread checks workaround for replicated cases
kurisu6912 Aug 27, 2025
a23a370
Merge pull request #1 from kurisu6912/kurisu-fix-reducer-0825
LeiWang1999 Aug 27, 2025
3c8e3bc
Remove the is_one check
kurisu6912 Aug 27, 2025
19c0fc8
Merge pull request #2 from kurisu6912/kurisu-fix-reducer-0825
LeiWang1999 Aug 27, 2025
5401c94
fix lint error
kurisu6912 Aug 27, 2025
fc2f55f
Merge branch 'main' of https://github.com/tile-ai/tilelang into reduc…
LeiWang1999 Aug 29, 2025
d9443a1
lint fix
LeiWang1999 Aug 29, 2025
4c1aa26
Update autotune tests to use smaller matrix sizes for improved perfor…
LeiWang1999 Aug 29, 2025
d664a15
Merge branch 'main' of https://github.com/tile-ai/tilelang into reduc…
LeiWang1999 Aug 29, 2025
c1b68b9
[Refactor] Update FinalizeReducer to FinalizeReducerOp and adjust rel…
LeiWang1999 Aug 29, 2025
67c1263
[Refactor] Update atomic operations in common.h and modify test_examp…
LeiWang1999 Aug 29, 2025
2431568
[Refactor] Simplify CopyNode::LowerBulkCopy logic and update test exe…
LeiWang1999 Aug 30, 2025
2 changes: 0 additions & 2 deletions setup.py
@@ -767,8 +767,6 @@ def build_cmake(self, ext):
if self.inplace:
extdir = os.path.abspath('./tilelang/lib/')

-logger.info(f"{extdir=}")

# Prepare arguments for the CMake configuration step.
# -DCMAKE_LIBRARY_OUTPUT_DIRECTORY sets where built libraries go
# -DPYTHON_EXECUTABLE ensures that the correct Python is used
8 changes: 8 additions & 0 deletions src/op/builtin.h
@@ -129,6 +129,14 @@ TVM_DLL const Op &tma_load_im2col();
*/
TVM_DLL const Op &tma_store();

+/*!
+ * \brief tvm intrinsics for barrier initialization fence
+ *
+ * ptx_fence_barrier_init()
+ *
+ */
+const Op &ptx_fence_barrier_init();

Comment on lines +132 to +139

💡 Verification agent

🧩 Analysis chain

Export the new intrinsic with TVM_DLL for consistency and visibility

All neighboring intrinsics are declared with TVM_DLL. Omitting it for ptx_fence_barrier_init may prevent the symbol from being exported/visible when linking dynamically. Please align with the rest.

 /*!
  * \brief tvm intrinsics for barrier initialization fence
  *
  * ptx_fence_barrier_init()
  *
  */
-const Op &ptx_fence_barrier_init();
+TVM_DLL const Op &ptx_fence_barrier_init();

Optionally, consider grouping this declaration next to other barrier-related intrinsics (e.g., tma_store_arrive, tma_store_wait) for discoverability.

Run this quick check to ensure the intrinsic is implemented/registered and has a Python wrapper if needed:


🏁 Script executed:

#!/bin/bash
# Where is the intrinsic referenced/implemented?
rg -nP 'ptx_fence_barrier_init\s*\(' -S

# Is the op name registered/exposed anywhere (C++ or Python)?
rg -nP '"tl\.ptx_fence_barrier_init"' -S

Length of output: 190


Add TVM_DLL and register ptx_fence_barrier_init for end-to-end visibility

I don’t see any registration for "tl.ptx_fence_barrier_init" or a Python binding anywhere in the repo—only the bare declaration in src/op/builtin.h. To make this new intrinsic fully functional and exportable, please:

  • Export the symbol
    In src/op/builtin.h (lines 135–138), update the declaration:

    /*!
     * \brief tvm intrinsics for barrier initialization fence
     *
     * ptx_fence_barrier_init()
     *
     */
    -const Op &ptx_fence_barrier_init();
    +TVM_DLL const Op &ptx_fence_barrier_init();
  • Register the op in C++
    Add a TVM_REGISTER_GLOBAL("tl.ptx_fence_barrier_init") (e.g. in src/op/builtin.cc) with the appropriate set_body_typed<…>(…) implementation so that the intrinsic can be looked up at runtime.

  • Expose a Python binding
    In the Python frontend (under python/tvm/… or TOPI), add a wrapper for ptx_fence_barrier_init so that it’s available to users scripting in Python.

  • Optional: group with related intrinsics
    Consider moving this declaration next to other barrier- or TMA-related intrinsics (e.g. tma_store_arrive, tma_store_wait) for better discoverability.

Without these steps, the symbol won’t be exported from the shared library nor available in the Python API, leading to link-time or runtime failures.


Comment on lines +132 to +139

💡 Verification agent

🧩 Analysis chain

Export macro missing on new intrinsic (will break dynamic linking/FFI)

All neighboring intrinsics use TVM_DLL. The new declaration lacks it, risking missing symbol export from the shared library.

Apply:

-const Op &ptx_fence_barrier_init();
+TVM_DLL const Op &ptx_fence_barrier_init();

Verify the op is also defined/registered (to avoid undefined reference at link/runtime):


🏁 Script executed:

#!/bin/bash
# Expect at least one definition/registration for the new op
rg -n 'ptx_fence_barrier_init' -C2

Length of output: 339


🏁 Script executed:

#!/bin/bash
# Search for any implementation or registration of the new intrinsic
rg -n 'ptx_fence_barrier_init' -C2
rg -n 'ptx_fence_barrier_init' -C2 --glob '*.cc'
rg -n 'ptx_fence_barrier_init' -C2 --glob '*.cpp'
rg -n 'ptx_fence_barrier_init' -C2 --glob '*.c'

Length of output: 784


#!/bin/bash
# List potential implementation files under src/op
echo "=== Source files in src/op with C/C++ extensions ==="
find src/op -type f | grep -E "\.(cc|cpp|c|cu|cuh)$"

# Search ignoring gitignore for any mention of the intrinsic name
echo "=== Searching for 'ptx_fence_barrier_init' across all files ==="
rg -uu -n "ptx_fence_barrier_init" -C2 .

# Broader search for any 'ptx_fence' identifiers
echo "=== Searching for 'ptx_fence' across all files ==="
rg -uu -n "ptx_fence" -C2 .

# Search for any global registrations in src/op
echo "=== Searching for TVM_REGISTER_GLOBAL in src/op ==="
rg -n "TVM_REGISTER_GLOBAL" -C2 src/op

🏁 Script executed:

#!/bin/bash
# Force search across all files (including those ignored by .gitignore) for the intrinsic name
echo "=== Full repo search for 'ptx_fence_barrier_init' ==="
rg -uu -n "ptx_fence_barrier_init" -C2 .

# Broader search for any fence barrier related symbols
echo "=== Full repo search for 'fence_barrier_init' ==="
rg -uu -n "fence_barrier_init" -C2 .

# Check for global registration patterns in the entire repo
echo "=== Search for TVM_REGISTER_GLOBAL ==="
rg -uu -n "TVM_REGISTER_GLOBAL" -C2 .

Length of output: 1617


Add export macro and provide implementation/registration for ptx_fence_barrier_init

The new intrinsic declaration in src/op/builtin.h is missing the TVM_DLL export specifier and has no matching implementation or TVM_REGISTER_GLOBAL call, which will lead to undefined references at link/runtime.

• File src/op/builtin.h (line 138): the declaration

- const Op &ptx_fence_barrier_init();
+ TVM_DLL const Op &ptx_fence_barrier_init();

needs the TVM_DLL prefix to ensure the symbol is exported.
• No implementation or registration found for ptx_fence_barrier_init in the repository—add the corresponding definition in a .cc/.cu file (e.g., src/op/ptx_fence.cc) and register it with TVM_REGISTER_GLOBAL (or the appropriate registration API) following the pattern of other PTX intrinsics.


/*!
* \brief tvm intrinsics for mbarrier wait with parity bit
*
101 changes: 101 additions & 0 deletions src/op/finalize_reducer.cc
@@ -0,0 +1,101 @@
/*!
* \file src/op/finalize_reducer.cc
*
* Define finalize_reducer operator.
*/

#include "finalize_reducer.h"

#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
auto node = make_object<FinalizeReducerOpNode>();
node->reducer = vmap[GetVarFromAccessPtr(args[0])];
node->op = (ReducerOpType)*as_const_int(args[1]);
data_ = std::move(node);
}
Comment on lines +21 to +26

⚠️ Potential issue

Validate builder arguments and vmap lookup.

Prevent OOB access and undefined reducer buffer lookup.

-FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
+FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
   auto node = make_object<FinalizeReducerOpNode>();
-  node->reducer = vmap[GetVarFromAccessPtr(args[0])];
-  node->op = (ReducerOpType)*as_const_int(args[1]);
+  ICHECK_EQ(args.size(), 2) << "FinalizeReducer expects exactly 2 arguments";
+  Var data_var = GetVarFromAccessPtr(args[0]);
+  ICHECK(vmap.count(data_var)) << "Unknown reducer buffer var in access_ptr";
+  node->reducer = vmap[data_var];
+  const int64_t* op_i64 = as_const_int(args[1]);
+  ICHECK(op_i64) << "Second argument must be a constant integer (ReducerOpType)";
+  node->op = static_cast<ReducerOpType>(*op_i64);
   data_ = std::move(node);
 }
🤖 Prompt for AI Agents
In src/op/finalize_reducer.cc around lines 21-26, the constructor assumes args
has at least 2 elements, that args[1] is a valid constant int, and that vmap
contains the reducer key returned by GetVarFromAccessPtr(args[0]); add explicit
validation: check args.size() >= 2 and throw or ICE with a clear message if not;
call as_const_int(args[1]) and verify it returns non-null before dereferencing,
then validate the resulting op index is within the expected enum range; compute
the reducer key from GetVarFromAccessPtr(args[0]) and ensure vmap contains that
key (use find/count) before indexing, throwing a descriptive error if missing;
only after these checks assign node->reducer and node->op and set data_.


Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
auto buffer = T.buffer_remap[reducer];
auto opt_layout = T.layout_map.Get(reducer);
ICHECK(opt_layout);
ICHECK(opt_layout->as<Fragment>());
auto layout = opt_layout->as<Fragment>().value();
Array<PrimExpr> indices_0;
indices_0.reserve(layout->OutputDim());
for (int i = 0; i < layout->OutputDim(); ++i)
indices_0.push_back(Var("__finred_" + std::to_string(i)));

const int64_t *p_extent = as_const_int(layout->ReplicateExtent());
ICHECK(p_extent);
int extent = *p_extent, scale = 1;
ICHECK(extent == 1 || extent == *as_const_int(T.thread_bounds->extent))
<< "Illegal finalize_reducer: extent=" << extent
<< "; T.thread_bounds=" << T.thread_bounds;

Comment on lines +40 to +46

⚠️ Potential issue

Avoid null deref on thread_bounds->extent. Cache total threads.

Ensure extent is an IntImm and reuse it.

-  const int64_t *p_extent = as_const_int(layout->ReplicateExtent());
+  const int64_t *p_extent = as_const_int(layout->ReplicateExtent());
   ICHECK(p_extent);
-  int extent = *p_extent, scale = 1;
-  ICHECK(extent == 1 || extent == *as_const_int(T.thread_bounds->extent))
+  int extent = *p_extent;
+  const int64_t* p_total = as_const_int(T.thread_bounds->extent);
+  ICHECK(p_total) << "T.thread_bounds->extent must be a constant integer";
+  int total_threads = static_cast<int>(*p_total);
+  ICHECK(extent == 1 || extent == total_threads)
       << "Illegal finalize_reducer: extent=" << extent
       << "; T.thread_bounds=" << T.thread_bounds;
🤖 Prompt for AI Agents
In src/op/finalize_reducer.cc around lines 40 to 46, avoid a potential
null-dereference and redundant casts by validating that thread_bounds and its
extent are non-null and an IntImm, cache the total thread count once, and reuse
that value: check T.thread_bounds is non-null, obtain and dynamic_cast/isa the
extent to IntImm (ICHECK with a clear message if not), store the integer extent
in a local variable (e.g. total_threads) and reuse it in the subsequent
checks/logging instead of re-evaluating or re-casting.

if (extent == 1)
return Evaluate(0);

std::array op_names{"tl::SumOp", "tl::MaxOp", "tl::MinOp"};
auto op_str = op_names[(int)op];

// adopted from ReduceOp
int reducing_threads = extent;
std::stringstream ss;
auto thread_offset = T.thread_bounds->min;
if (TargetIsHopper(T.target)) {
auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1
<< ", " << thread_offset << ", " << all_threads << ">::run_hopper";
} else {
ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1
<< ", " << thread_offset << ">::run";
}
Array<PrimExpr> thread_reduce_args = {StringImm(ss.str()),
BufferLoad(buffer, indices_0)};
if (reducing_threads >= 32) {
PrimExpr workspace =
T.AddWorkspace(*as_const_int(T.thread_bounds->extent), buffer->dtype);
thread_reduce_args.push_back(workspace);
}
auto call = Call(buffer->dtype, builtin::call_extern(), thread_reduce_args);
Stmt body = BufferStore(buffer, call, indices_0);

// make the outer spatial loop
for (int i = layout->OutputDim() - 1; i >= 0; i--) {
body = For(indices_0[i].as<Var>().value(), 0, layout->OutputShape()[i],
ForKind::kParallel, body);
}

return body;
}

LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
LayoutMap layout_map;
layout_map.Set(reducer, T.layout_map.Get(reducer).value());
return layout_map;
}

TileOperator FinalizeReducerOpNode::Clone() const {
auto node = make_object<FinalizeReducerOpNode>(*this);
return TileOperator(node);
}

TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
46 changes: 46 additions & 0 deletions src/op/finalize_reducer.h
@@ -0,0 +1,46 @@
// Copyright (c) Tile-AI Corporation.
// Licensed under the MIT License.

/*!
* \file src/op/finalize_reducer.h
* \brief Define finalize_reducer operator.
*/

#ifndef TVM_TL_OP_FINALIZE_REDUCER_H_
#define TVM_TL_OP_FINALIZE_REDUCER_H_

#include "../transform/layout_reducer.h"
#include "./operator.h"

namespace tvm {
namespace tl {

using namespace tir;

class FinalizeReducerOpNode : public TileOperatorNode {
public:
tir::Buffer reducer;
ReducerOpType op;

static constexpr const char *_type_key = "tl.FinalizeReducerOp";
TVM_DECLARE_FINAL_OBJECT_INFO(FinalizeReducerOpNode, TileOperatorNode);

Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
static const Op &Get();
TileOperator Clone() const;
};

class FinalizeReducerOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FinalizeReducerOp, TileOperator,
FinalizeReducerOpNode);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_FINALIZE_REDUCER_H_
26 changes: 21 additions & 5 deletions src/op/parallel.cc
@@ -124,6 +124,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
p->loop_vars_.push_back(
IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+  auto reducer_info_map =
+      op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
+  if (reducer_info_map) {
+    for (auto &&[buffer, info] : reducer_info_map.value())
+      p->reducer_info_map_.Set(buffer, info);
+  }
StmtExprVisitor::VisitStmt_(op);
Comment on lines +127 to 133

⚠️ Potential issue

Potential null deref when reading reducer_info annotation

Calling op->annotations.Get(attr::kReducerInfo)->as<...>() without checking that Get() returned a value can trigger a fatal when the annotation is absent. Use the same safe pattern you used elsewhere in this file for coalesced_width.

Apply:

-  auto reducer_info_map =
-      op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
-  if (reducer_info_map) {
-    for (auto &&[buffer, info] : reducer_info_map.value())
-      p->reducer_info_map_.Set(buffer, info);
-  }
+  if (auto obj = op->annotations.Get(attr::kReducerInfo)) {
+    if (auto reducer_info_map = obj.value().as<Map<Var, ReducerInfo>>()) {
+      for (auto &&[buffer, info] : reducer_info_map.value()) {
+        p->reducer_info_map_.Set(buffer, info);
+      }
+    }
+  }
🤖 Prompt for AI Agents
In src/op/parallel.cc around lines 127 to 133, the code calls
op->annotations.Get(attr::kReducerInfo)->as<...>() directly which can
dereference a null optional if the annotation is absent; change this to first
capture the result of Get() into a local Optional<ObjectRef> (or similar) and
check it before calling as<...>(), then only iterate and Set entries when the
as<...>() call returns a valid Map; mirror the safe pattern used for
coalesced_width (check Get() result, then as(), then use value()) so no null
deref occurs.

Comment on lines +127 to 133

⚠️ Potential issue

Check for null annotation before casting.

The code assumes attr::kReducerInfo exists in annotations and performs an unchecked cast. This could cause runtime errors if the annotation is missing.

Apply this diff to add null checking:

-  auto reducer_info_map =
-      op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
-  if (reducer_info_map) {
-    for (auto &&[buffer, info] : reducer_info_map.value())
-      p->reducer_info_map_.Set(buffer, info);
-  }
+  if (op->annotations.count(attr::kReducerInfo)) {
+    auto reducer_info_map =
+        op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
+    if (reducer_info_map) {
+      for (auto &&[buffer, info] : reducer_info_map.value())
+        p->reducer_info_map_.Set(buffer, info);
+    }
+  }
🤖 Prompt for AI Agents
In src/op/parallel.cc around lines 127 to 133, the code unconditionally
retrieves and casts the attr::kReducerInfo annotation which can be absent and
cause a runtime error; change the logic to first check that op->annotations
contains attr::kReducerInfo (or that Get(attr::kReducerInfo) is non-null) before
calling as<Map<Var, ReducerInfo>>(), only perform the cast and iterate when the
annotation is present, otherwise skip setting reducer_info_map_; ensure you
handle both optional/nullable returns consistently to avoid dereferencing null.

}

@@ -202,6 +208,11 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
Buffer source_buffer, read_source_buffer;
for (const auto &[buffer, indices] : indice_map_) {
if (T.layout_map.count(buffer)) {
+      // skip reducers with rep=ALL
+      if (auto info = reducer_info_map_.Get(buffer->data);
+          info && info.value()->rep == ReducerRepType::ALL)
+        continue;

auto frag = T.layout_map[buffer].as<Fragment>().value();
if (buffer_is_write_.count(buffer)) {
source_buffer = buffer;
@@ -298,6 +309,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);

+  PrimExpr loop_total_size = 1;
+  for (Stmt l = root_; l.as<For>().has_value();
+       l = l.as<For>().value()->body)
+    loop_total_size = loop_total_size * l.as<For>().value()->extent;
+  while (!analyzer_.CanProve(
+             floormod(loop_total_size,
+                      T.thread_bounds->extent * vector_size) == 0) &&
+         vector_size > 1)
+    vector_size /= 2;

// Check if coalesced_width is defined
if (auto coalesced_width =
root_->annotations.Get(tl::attr::coalesced_width)) {
@@ -343,11 +364,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
for (const auto &[buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) {
auto fragment = T.layout_map[buffer].as<Fragment>().value();
-      // TODO: Add thread checks for replicated cases
-      // need to wildcard match the rhs with lhs
-      if (!is_one(loop_layout_->ReplicateExtent()) ||
-          !is_one(fragment->ReplicateExtent()))
-        continue;
auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
if (!ProveFragmentContains(loop_layout_, fragment, vars,
5 changes: 4 additions & 1 deletion src/op/parallel.h
@@ -10,7 +10,8 @@
#include <tvm/tir/stmt_functor.h>

#include "../layout/layout.h"
-#include "operator.h"
+#include "../transform/layout_reducer.h"
+#include "./operator.h"

namespace tvm {
namespace tl {
@@ -112,6 +113,8 @@ class ParallelOpNode : public TileOperatorNode {
Array<IterVar> loop_vars_;
// Analyzer for simplifying and analyzing expressions, mutable for lazy use.
mutable arith::Analyzer analyzer_;
+  // Mapping from buffer to reducer info.
+  Map<Var, ReducerInfo> reducer_info_map_;
};

class ParallelOp : public TileOperator {
4 changes: 2 additions & 2 deletions src/op/reduce.cc
@@ -13,6 +13,7 @@

#include "../layout/utils.h"
#include "../op/parallel.h"
+#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "tir/transforms/ir_utils.h"

@@ -237,9 +238,8 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int reducing_threads = (*extent) * (*scale);
std::stringstream ss;

-  bool has_arch = T.target->attrs.count("arch") > 0;
   auto thread_offset = T.thread_bounds->min;
-  if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
+  if (TargetIsHopper(T.target)) {
auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ", " << thread_offset
5 changes: 1 addition & 4 deletions src/target/codegen_cuda.cc
@@ -1134,10 +1134,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(tl::sync_grid())) {
this->need_cooperative_groups_ = true;
this->PrintIndent();
-    this->stream << "cooperative_groups::grid_group grid = "
-                    "cooperative_groups::this_grid();\n";
-    this->PrintIndent();
-    this->stream << "grid.sync();\n";
+    this->stream << "cooperative_groups::this_grid().sync();\n";
} else if (op->op.same_as(tl::loop_break())) {
Comment on lines 1135 to 1138

💡 Verification agent

🧩 Analysis chain

Double-check cooperative launch + device linking in build/runtime

Grid-wide sync requires cooperative launches and separable compilation/device linking. This change doesn’t alter that requirement, but worth verifying CI/build and runtime paths still satisfy it.

  • Build flags: NVCC should compile with device linking enabled (e.g., -rdc=true or CMake separable compilation).
  • Runtime: Kernels that may hit tl::sync_grid() should be launched with cooperative APIs (e.g., cudaLaunchCooperativeKernel / cuLaunchCooperativeKernel) on devices reporting cooperative support.

Run this quick scan to confirm:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Occurrences of inline grid sync in CUDA codegen (context):"
rg -nC2 'this_grid\(\)\.sync\(\)' src || true

echo
echo "Check for cooperative kernel launch usage (CUDA/Driver API):"
rg -nP 'cudaLaunchCooperativeKernel|cuLaunchCooperativeKernel' -C2 -g '!**/third_party/**' || true

echo
echo "Check build files for device linking/separable compilation:"
rg -nP '(-rdc=|relocatable-device-code|SEPARABLE_COMPILATION|CMAKE_CUDA_SEPARABLE_COMPILATION)' -n -S -C2 -- \
  CMakeLists.txt '**/*.cmake' || true

Length of output: 9214


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Listing CMake build files (depth <= 3):"
find . -maxdepth 3 -type f \( -name 'CMakeLists.txt' -o -name '*.cmake' \)

echo
echo "Searching for -rdc flags in CMake build files:"
rg -nP '(-rdc\s*=\s*true|-rdc)' -g '*.cmake' -g 'CMakeLists.txt' || true

echo
echo "Searching for CMAKE_CUDA_SEPARABLE_COMPILATION:"
rg -nP 'CMAKE_CUDA_SEPARABLE_COMPILATION' -g '*.cmake' -g 'CMakeLists.txt' || true

echo
echo "Searching for relocatable-device-code keywords:"
rg -nP 'relocatable-device-code|SEPARABLE_COMPILATION' -g '*.cmake' -g 'CMakeLists.txt' || true

Length of output: 786


Ensure CUDA separable compilation and cooperative launches are properly enabled

We didn’t find any -rdc=true, CMAKE_CUDA_SEPARABLE_COMPILATION, or equivalent relocatable-device-code settings in your CMake files, which are required whenever device code uses cooperative_groups::this_grid().sync(). Without these, kernels containing grid-wide sync won’t link or launch correctly.

Please update your build and CI to include:

  • In your top-level CMakeLists.txt (or wherever you enable CUDA):
    enable_language(CUDA)
    set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
    # Or explicitly:
    # set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -rdc=true")
  • For targets using CUDA device code:
    target_compile_options(<your_cuda_target>
      PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-rdc=true>
    )
  • Verify your CI/system build invocation passes these flags (e.g. -DCMAKE_CUDA_SEPARABLE_COMPILATION=ON).

At runtime, you’re already using cudaLaunchCooperativeKernel in wrapper.py (lines 328–331), and your codegen emits cooperative_groups::this_grid().sync();. Double-check that:

  • The device reports CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH before attempting a grid sync.
  • All kernels containing grid sync are indeed launched via the cooperative APIs.

With these in place, this_grid().sync() will link and run correctly.

🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 1135 to 1138, the code emits
cooperative_groups::this_grid().sync() but the build/CI likely lacks CUDA
separable compilation and relocatable device code flags; enable separable
compilation in CMake (e.g. enable_language(CUDA) and set
CMAKE_CUDA_SEPARABLE_COMPILATION ON or add -rdc=true to CUDA flags), add
target-specific compile options (target_compile_options(<cuda_target> PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-rdc=true>)), ensure CI build invocations propagate
-DCMAKE_CUDA_SEPARABLE_COMPILATION=ON or equivalent, and confirm at runtime that
CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH is supported and kernels with grid sync
are launched via cooperative launch APIs (as used in wrapper.py).

this->PrintIndent();
this->stream << "break;\n";