-
Notifications
You must be signed in to change notification settings - Fork 290
[Reducer] Introduce alloc_reducer to separate inter and intra warp reduction
#757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
03bc5f6
e3c4619
fc44d44
4713bd8
deacc33
c30423c
4c0f978
6e40711
38bdc75
a23a370
3c8e3bc
19c0fc8
5401c94
fc2f55f
d9443a1
4c1aa26
d664a15
c1b68b9
67c1263
2431568
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -129,6 +129,14 @@ TVM_DLL const Op &tma_load_im2col(); | |||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||
| TVM_DLL const Op &tma_store(); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| /*! | ||||||||||||||||||||||||||||||
| * \brief tvm intrinsics for barrier initialization fence | ||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||
| * ptx_fence_barrier_init() | ||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||
| const Op &ptx_fence_barrier_init(); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
Comment on lines
+132
to
+139
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainExport macro missing on new intrinsic (will break dynamic linking/FFI) All neighboring intrinsics use TVM_DLL. The new declaration lacks it, risking missing symbol export from the shared library. Apply: -const Op &ptx_fence_barrier_init();
+TVM_DLL const Op &ptx_fence_barrier_init();Verify the op is also defined/registered (to avoid undefined reference at link/runtime): 🏁 Script executed: #!/bin/bash
# Expect at least one definition/registration for the new op
rg -n 'ptx_fence_barrier_init' -C2Length of output: 339 🏁 Script executed: #!/bin/bash
# Search for any implementation or registration of the new intrinsic
rg -n 'ptx_fence_barrier_init' -C2
rg -n 'ptx_fence_barrier_init' -C2 --glob '*.cc'
rg -n 'ptx_fence_barrier_init' -C2 --glob '*.cpp'
rg -n 'ptx_fence_barrier_init' -C2 --glob '*.c'Length of output: 784 #!/bin/bash
# List potential implementation files under src/op
echo "=== Source files in src/op with C/C++ extensions ==="
find src/op -type f | grep -E "\.(cc|cpp|c|cu|cuh)$"
# Search ignoring gitignore for any mention of the intrinsic name
echo "=== Searching for 'ptx_fence_barrier_init' across all files ==="
rg -uu -n "ptx_fence_barrier_init" -C2 .
# Broader search for any 'ptx_fence' identifiers
echo "=== Searching for 'ptx_fence' across all files ==="
rg -uu -n "ptx_fence" -C2 .
# Search for any global registrations in src/op
echo "=== Searching for TVM_REGISTER_GLOBAL in src/op ==="
rg -n "TVM_REGISTER_GLOBAL" -C2 src/op🏁 Script executed: #!/bin/bash
# Force search across all files (including those ignored by .gitignore) for the intrinsic name
echo "=== Full repo search for 'ptx_fence_barrier_init' ==="
rg -uu -n "ptx_fence_barrier_init" -C2 .
# Broader search for any fence barrier related symbols
echo "=== Full repo search for 'fence_barrier_init' ==="
rg -uu -n "fence_barrier_init" -C2 .
# Check for global registration patterns in the entire repo
echo "=== Search for TVM_REGISTER_GLOBAL ==="
rg -uu -n "TVM_REGISTER_GLOBAL" -C2 .Length of output: 1617 Add export macro and provide implementation/registration for ptx_fence_barrier_init The new intrinsic declaration in src/op/builtin.h is missing the TVM_DLL export specifier and has no matching implementation or TVM_REGISTER_GLOBAL call, which will lead to undefined references at link/runtime. • File src/op/builtin.h (line 138): the declaration - const Op &ptx_fence_barrier_init();
+ TVM_DLL const Op &ptx_fence_barrier_init();needs the TVM_DLL prefix to ensure the symbol is exported. 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||
| /*! | ||||||||||||||||||||||||||||||
| * \brief tvm intrinsics for mbarrier wait with parity bit | ||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,101 @@ | ||||||||||||||||||||||||||||||||||||
| /*! | ||||||||||||||||||||||||||||||||||||
| * \file src/op/finalize_reducer.cc | ||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||
| * Define finalize_reducer operator. | ||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| #include "finalize_reducer.h" | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| #include <tvm/arith/iter_affine_map.h> | ||||||||||||||||||||||||||||||||||||
| #include <tvm/tir/builtin.h> | ||||||||||||||||||||||||||||||||||||
| #include <tvm/tir/op.h> | ||||||||||||||||||||||||||||||||||||
| #include <tvm/tir/op_attr_types.h> | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| #include "../target/utils.h" | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| namespace tvm { | ||||||||||||||||||||||||||||||||||||
| namespace tl { | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| using namespace tir; | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) { | ||||||||||||||||||||||||||||||||||||
| auto node = make_object<FinalizeReducerOpNode>(); | ||||||||||||||||||||||||||||||||||||
| node->reducer = vmap[GetVarFromAccessPtr(args[0])]; | ||||||||||||||||||||||||||||||||||||
| node->op = (ReducerOpType)*as_const_int(args[1]); | ||||||||||||||||||||||||||||||||||||
| data_ = std::move(node); | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+21
to
+26
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate builder arguments and vmap lookup. Prevent OOB access and undefined reducer buffer lookup. -FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
+FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
auto node = make_object<FinalizeReducerOpNode>();
- node->reducer = vmap[GetVarFromAccessPtr(args[0])];
- node->op = (ReducerOpType)*as_const_int(args[1]);
+ ICHECK_EQ(args.size(), 2) << "FinalizeReducer expects exactly 2 arguments";
+ Var data_var = GetVarFromAccessPtr(args[0]);
+ ICHECK(vmap.count(data_var)) << "Unknown reducer buffer var in access_ptr";
+ node->reducer = vmap[data_var];
+ const int64_t* op_i64 = as_const_int(args[1]);
+ ICHECK(op_i64) << "Second argument must be a constant integer (ReducerOpType)";
+ node->op = static_cast<ReducerOpType>(*op_i64);
data_ = std::move(node);
}📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T, | ||||||||||||||||||||||||||||||||||||
| arith::Analyzer *analyzer) const { | ||||||||||||||||||||||||||||||||||||
| auto buffer = T.buffer_remap[reducer]; | ||||||||||||||||||||||||||||||||||||
| auto opt_layout = T.layout_map.Get(reducer); | ||||||||||||||||||||||||||||||||||||
| ICHECK(opt_layout); | ||||||||||||||||||||||||||||||||||||
| ICHECK(opt_layout->as<Fragment>()); | ||||||||||||||||||||||||||||||||||||
| auto layout = opt_layout->as<Fragment>().value(); | ||||||||||||||||||||||||||||||||||||
| Array<PrimExpr> indices_0; | ||||||||||||||||||||||||||||||||||||
| indices_0.reserve(layout->OutputDim()); | ||||||||||||||||||||||||||||||||||||
| for (int i = 0; i < layout->OutputDim(); ++i) | ||||||||||||||||||||||||||||||||||||
| indices_0.push_back(Var("__finred_" + std::to_string(i))); | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| const int64_t *p_extent = as_const_int(layout->ReplicateExtent()); | ||||||||||||||||||||||||||||||||||||
| ICHECK(p_extent); | ||||||||||||||||||||||||||||||||||||
| int extent = *p_extent, scale = 1; | ||||||||||||||||||||||||||||||||||||
| ICHECK(extent == 1 || extent == *as_const_int(T.thread_bounds->extent)) | ||||||||||||||||||||||||||||||||||||
| << "Illegal finalize_reducer: extent=" << extent | ||||||||||||||||||||||||||||||||||||
| << "; T.thread_bounds=" << T.thread_bounds; | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
Comment on lines
+40
to
+46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid null deref on thread_bounds->extent. Cache total threads. Ensure extent is an IntImm and reuse it. - const int64_t *p_extent = as_const_int(layout->ReplicateExtent());
+ const int64_t *p_extent = as_const_int(layout->ReplicateExtent());
ICHECK(p_extent);
- int extent = *p_extent, scale = 1;
- ICHECK(extent == 1 || extent == *as_const_int(T.thread_bounds->extent))
+ int extent = *p_extent;
+ const int64_t* p_total = as_const_int(T.thread_bounds->extent);
+ ICHECK(p_total) << "T.thread_bounds->extent must be a constant integer";
+ int total_threads = static_cast<int>(*p_total);
+ ICHECK(extent == 1 || extent == total_threads)
<< "Illegal finalize_reducer: extent=" << extent
<< "; T.thread_bounds=" << T.thread_bounds;📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||
| if (extent == 1) | ||||||||||||||||||||||||||||||||||||
| return Evaluate(0); | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| std::array op_names{"tl::SumOp", "tl::MaxOp", "tl::MinOp"}; | ||||||||||||||||||||||||||||||||||||
| auto op_str = op_names[(int)op]; | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| // adopted from ReduceOp | ||||||||||||||||||||||||||||||||||||
| int reducing_threads = extent; | ||||||||||||||||||||||||||||||||||||
| std::stringstream ss; | ||||||||||||||||||||||||||||||||||||
| auto thread_offset = T.thread_bounds->min; | ||||||||||||||||||||||||||||||||||||
| if (TargetIsHopper(T.target)) { | ||||||||||||||||||||||||||||||||||||
| auto all_threads = T.thread_bounds->extent; | ||||||||||||||||||||||||||||||||||||
| ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1 | ||||||||||||||||||||||||||||||||||||
| << ", " << thread_offset << ", " << all_threads << ">::run_hopper"; | ||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||
| ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1 | ||||||||||||||||||||||||||||||||||||
| << ", " << thread_offset << ">::run"; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| Array<PrimExpr> thread_reduce_args = {StringImm(ss.str()), | ||||||||||||||||||||||||||||||||||||
| BufferLoad(buffer, indices_0)}; | ||||||||||||||||||||||||||||||||||||
| if (reducing_threads >= 32) { | ||||||||||||||||||||||||||||||||||||
| PrimExpr workspace = | ||||||||||||||||||||||||||||||||||||
| T.AddWorkspace(*as_const_int(T.thread_bounds->extent), buffer->dtype); | ||||||||||||||||||||||||||||||||||||
| thread_reduce_args.push_back(workspace); | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| auto call = Call(buffer->dtype, builtin::call_extern(), thread_reduce_args); | ||||||||||||||||||||||||||||||||||||
| Stmt body = BufferStore(buffer, call, indices_0); | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| // make the outer spatial loop | ||||||||||||||||||||||||||||||||||||
| for (int i = layout->OutputDim() - 1; i >= 0; i--) { | ||||||||||||||||||||||||||||||||||||
| body = For(indices_0[i].as<Var>().value(), 0, layout->OutputShape()[i], | ||||||||||||||||||||||||||||||||||||
| ForKind::kParallel, body); | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| return body; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T, | ||||||||||||||||||||||||||||||||||||
| InferLevel level) const { | ||||||||||||||||||||||||||||||||||||
| LayoutMap layout_map; | ||||||||||||||||||||||||||||||||||||
| layout_map.Set(reducer, T.layout_map.Get(reducer).value()); | ||||||||||||||||||||||||||||||||||||
| return layout_map; | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| TileOperator FinalizeReducerOpNode::Clone() const { | ||||||||||||||||||||||||||||||||||||
| auto node = make_object<FinalizeReducerOpNode>(*this); | ||||||||||||||||||||||||||||||||||||
| return TileOperator(node); | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer) | ||||||||||||||||||||||||||||||||||||
| .set_num_inputs(1) | ||||||||||||||||||||||||||||||||||||
| .set_attr<TCallEffectKind>("TCallEffectKind", | ||||||||||||||||||||||||||||||||||||
| Integer(CallEffectKind::kOpaque)); | ||||||||||||||||||||||||||||||||||||
| } // namespace tl | ||||||||||||||||||||||||||||||||||||
| } // namespace tvm | ||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| // Copyright (c) Tile-AI Corporation. | ||
| // Licensed under the MIT License. | ||
|
|
||
| /*! | ||
| * \file src/op/finalize_reducer.h | ||
| * \brief Define finalize_reducer operator. | ||
| */ | ||
|
|
||
| #ifndef TVM_TL_OP_FINALIZE_REDUCER_H_ | ||
| #define TVM_TL_OP_FINALIZE_REDUCER_H_ | ||
|
|
||
| #include "../transform/layout_reducer.h" | ||
| #include "./operator.h" | ||
|
|
||
| namespace tvm { | ||
| namespace tl { | ||
|
|
||
| using namespace tir; | ||
|
|
||
| class FinalizeReducerOpNode : public TileOperatorNode { | ||
| public: | ||
| tir::Buffer reducer; | ||
| ReducerOpType op; | ||
|
|
||
| static constexpr const char *_type_key = "tl.FinalizeReducerOp"; | ||
| TVM_DECLARE_FINAL_OBJECT_INFO(FinalizeReducerOpNode, TileOperatorNode); | ||
|
|
||
| Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; | ||
| LayoutMap InferLayout(const LayoutInferArgs &T, | ||
| InferLevel level) const override; | ||
| static const Op &Get(); | ||
| TileOperator Clone() const; | ||
| }; | ||
|
|
||
| class FinalizeReducerOp : public TileOperator { | ||
| public: | ||
| TVM_DEFINE_OBJECT_REF_METHODS(FinalizeReducerOp, TileOperator, | ||
| FinalizeReducerOpNode); | ||
| TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap); | ||
| static const Op &Get(); | ||
| }; | ||
|
|
||
| } // namespace tl | ||
| } // namespace tvm | ||
|
|
||
| #endif // TVM_TL_OP_FINALIZE_REDUCER_H_ |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -124,6 +124,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| p->loop_vars_.push_back( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto reducer_info_map = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (reducer_info_map) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto &&[buffer, info] : reducer_info_map.value()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| p->reducer_info_map_.Set(buffer, info); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| StmtExprVisitor::VisitStmt_(op); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+127
to
133
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potential null deref when reading reducer_info annotation Calling Apply: - auto reducer_info_map =
- op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
- if (reducer_info_map) {
- for (auto &&[buffer, info] : reducer_info_map.value())
- p->reducer_info_map_.Set(buffer, info);
- }
+ if (auto obj = op->annotations.Get(attr::kReducerInfo)) {
+ if (auto reducer_info_map = obj.value().as<Map<Var, ReducerInfo>>()) {
+ for (auto &&[buffer, info] : reducer_info_map.value()) {
+ p->reducer_info_map_.Set(buffer, info);
+ }
+ }
+ }📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
Comment on lines
+127
to
133
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check for null annotation before casting. The code assumes Apply this diff to add null checking: - auto reducer_info_map =
- op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
- if (reducer_info_map) {
- for (auto &&[buffer, info] : reducer_info_map.value())
- p->reducer_info_map_.Set(buffer, info);
- }
+ if (op->annotations.count(attr::kReducerInfo)) {
+ auto reducer_info_map =
+ op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
+ if (reducer_info_map) {
+ for (auto &&[buffer, info] : reducer_info_map.value())
+ p->reducer_info_map_.Set(buffer, info);
+ }
+ }📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -202,6 +208,11 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Buffer source_buffer, read_source_buffer; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (const auto &[buffer, indices] : indice_map_) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (T.layout_map.count(buffer)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // skip reducers with rep=ALL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (auto info = reducer_info_map_.Get(buffer->data); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| info && info.value()->rep == ReducerRepType::ALL) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto frag = T.layout_map[buffer].as<Fragment>().value(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (buffer_is_write_.count(buffer)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| source_buffer = buffer; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -298,6 +309,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int vector_size = GetVectorizeSize(maybe_remapped_root_); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr loop_total_size = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (Stmt l = root_; l.as<For>().has_value(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| l = l.as<For>().value()->body) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loop_total_size = loop_total_size * l.as<For>().value()->extent; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| while (!analyzer_.CanProve( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| floormod(loop_total_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T.thread_bounds->extent * vector_size) == 0) && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector_size > 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector_size /= 2; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Check if coalesced_width is defined | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (auto coalesced_width = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| root_->annotations.Get(tl::attr::coalesced_width)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -343,11 +364,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (const auto &[buffer, _] : indice_map_) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (T.layout_map.count(buffer)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto fragment = T.layout_map[buffer].as<Fragment>().value(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // TODO: Add thread checks for replicated cases | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // need to wildcard match the rhs with lhs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!is_one(loop_layout_->ReplicateExtent()) || | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| !is_one(fragment->ReplicateExtent())) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto vars = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); }); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!ProveFragmentContains(loop_layout_, fragment, vars, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1134,10 +1134,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { | |
| } else if (op->op.same_as(tl::sync_grid())) { | ||
| this->need_cooperative_groups_ = true; | ||
| this->PrintIndent(); | ||
| this->stream << "cooperative_groups::grid_group grid = " | ||
| "cooperative_groups::this_grid();\n"; | ||
| this->PrintIndent(); | ||
| this->stream << "grid.sync();\n"; | ||
| this->stream << "cooperative_groups::this_grid().sync();\n"; | ||
| } else if (op->op.same_as(tl::loop_break())) { | ||
|
Comment on lines
1135
to
1138
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainDouble-check cooperative launch + device linking in build/runtime Grid-wide sync requires cooperative launches and separable compilation/device linking. This change doesn’t alter that requirement, but worth verifying CI/build and runtime paths still satisfy it.
Run this quick scan to confirm: 🏁 Script executed: #!/bin/bash
set -euo pipefail
echo "Occurrences of inline grid sync in CUDA codegen (context):"
rg -nC2 'this_grid\(\)\.sync\(\)' src || true
echo
echo "Check for cooperative kernel launch usage (CUDA/Driver API):"
rg -nP 'cudaLaunchCooperativeKernel|cuLaunchCooperativeKernel' -C2 -g '!**/third_party/**' || true
echo
echo "Check build files for device linking/separable compilation:"
rg -nP '(-rdc=|relocatable-device-code|SEPARABLE_COMPILATION|CMAKE_CUDA_SEPARABLE_COMPILATION)' -n -S -C2 -- \
CMakeLists.txt '**/*.cmake' || trueLength of output: 9214 🏁 Script executed: #!/bin/bash
set -euo pipefail
echo "Listing CMake build files (depth <= 3):"
find . -maxdepth 3 -type f \( -name 'CMakeLists.txt' -o -name '*.cmake' \)
echo
echo "Searching for -rdc flags in CMake build files:"
rg -nP '(-rdc\s*=\s*true|-rdc)' -g '*.cmake' -g 'CMakeLists.txt' || true
echo
echo "Searching for CMAKE_CUDA_SEPARABLE_COMPILATION:"
rg -nP 'CMAKE_CUDA_SEPARABLE_COMPILATION' -g '*.cmake' -g 'CMakeLists.txt' || true
echo
echo "Searching for relocatable-device-code keywords:"
rg -nP 'relocatable-device-code|SEPARABLE_COMPILATION' -g '*.cmake' -g 'CMakeLists.txt' || trueLength of output: 786 Ensure CUDA separable compilation and cooperative launches are properly enabled We didn’t find any Please update your build and CI to include:
At runtime, you’re already using
With these in place, 🤖 Prompt for AI Agents |
||
| this->PrintIndent(); | ||
| this->stream << "break;\n"; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Export the new intrinsic with TVM_DLL for consistency and visibility
All neighboring intrinsics are declared with
TVM_DLL. Omitting it forptx_fence_barrier_initmay prevent the symbol from being exported/visible when linking dynamically. Please align with the rest.Optionally, consider grouping this declaration next to other barrier-related intrinsics (e.g.,
tma_store_arrive,tma_store_wait) for discoverability.Run this quick check to ensure the intrinsic is implemented/registered and has a Python wrapper if needed:
🏁 Script executed:
Length of output: 190
Add TVM_DLL and register ptx_fence_barrier_init for end-to-end visibility
I don’t see any registration for
"tl.ptx_fence_barrier_init"or a Python binding anywhere in the repo—only the bare declaration insrc/op/builtin.h. To make this new intrinsic fully functional and exportable, please:Export the symbol
In
src/op/builtin.h(lines 135–138), update the declaration:Register the op in C++
Add a
TVM_REGISTER_GLOBAL("tl.ptx_fence_barrier_init")(e.g. insrc/op/builtin.cc) with the appropriateset_body_typed<…>(…)implementation so that the intrinsic can be looked up at runtime.Expose a Python binding
In the Python frontend (under
python/tvm/…or TOPI), add a wrapper forptx_fence_barrier_initso that it’s available to users scripting in Python.Optional: group with related intrinsics
Consider moving this declaration next to other barrier- or TMA-related intrinsics (e.g.
tma_store_arrive,tma_store_wait) for better discoverability.Without these steps, the symbol won’t be exported from the shared library nor available in the Python API, leading to link-time or runtime failures.
📝 Committable suggestion