Commit f093445

re think replicate
1 parent 7043506 commit f093445

File tree

1 file changed: +44 -19 lines

src/op/parallel.cc

Lines changed: 44 additions & 19 deletions
@@ -413,22 +413,23 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
 
   // check if loop body contains a "pure" buffer store (i.e., direct
   // assignment, not compound update)
-  bool has_pure_buffer_store = false;
+  std::vector<Buffer> store_shared_global_buffers, store_fragment_buffers, store_buffers;
+  // Buffers whose scope sits above fragment scope:
+  // global, shared, shared.dyn.
+  // These can be used to analyze the replication case.
   PostOrderVisit(root_, [&](const ObjectRef &obj) {
     if (const auto *store = obj.as<BufferStoreNode>()) {
-      // Check if the value is a direct load from another buffer (i.e., b[i]
-      // = a[i])
-      if (const auto *load = store->value.as<BufferLoadNode>()) {
-        has_pure_buffer_store = true;
-      }
+      auto buffer = store->buffer;
+      if (buffer.scope() == "shared" || buffer.scope() == "shared.dyn" || buffer.scope() == "global") {
+        store_shared_global_buffers.emplace_back(buffer);
+      } else if (buffer.scope() == "local.fragment")
+        store_fragment_buffers.emplace_back(buffer);
+      store_buffers.emplace_back(buffer);
     }
   });
 
   if (read_source_buffer.defined() && allow_layout_propgate) {
     loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
-    // // Loop don't need to be replicated.
-    // if (!is_one(loop_layout_->ReplicateExtent()))
-    //   loop_layout_ = loop_layout_->DeReplicate();
   }
 
   if (!loop_layout_.defined()) {
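
Note on the hunk above: the commit replaces the old "pure buffer store" flag with scope-based bucketing of store targets. As a reading aid, here is a minimal standalone sketch of that bucketing, assuming a stub Buffer with only a scope string (hypothetical names; the real pass walks the TIR loop body with PostOrderVisit over tvm::tir::Buffer):

    #include <string>
    #include <vector>

    struct Buffer { std::string scope; };  // stub standing in for tvm::tir::Buffer

    // Bucket store targets by memory scope, as in the visitor above:
    // scopes above fragments (global, shared, shared.dyn) versus register
    // fragments (local.fragment). Every store also lands in `all`.
    void BucketStores(const std::vector<Buffer> &stores,
                      std::vector<Buffer> &shared_global,
                      std::vector<Buffer> &fragment,
                      std::vector<Buffer> &all) {
      for (const Buffer &b : stores) {
        if (b.scope == "shared" || b.scope == "shared.dyn" || b.scope == "global")
          shared_global.push_back(b);
        else if (b.scope == "local.fragment")
          fragment.push_back(b);
        all.push_back(b);
      }
    }

The second hunk consumes these buckets: fragment stores are exempt from the replication guard, while shared/global stores require it.
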
@@ -477,16 +478,40 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
                  << loop_layout_->DebugOutput() << '\n';
     }
-    if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
-        !has_pure_buffer_store) {
-      auto inv = loop_layout_->Inverse();
-      Array<PrimExpr> fwd;
-      for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
-        fwd.push_back(0);
-      fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
-      auto rep = inv->Forward(fwd).back();
-      AddPredicate(EQ(rep, 0));
-    }
+
+    // Lambda that guards replicated accesses:
+    // - When a loop layout replicates a fragment buffer (rep > 1), each thread
+    //   observes the same fragment elements. Blindly storing to shared/global
+    //   memory in that case would add the same value multiple times.
+    // - We therefore restrict the store so that only the replica with rep == 0
+    //   performs the update (e.g. global[i] += fragment[i] only fires once).
+    // Trigger conditions for this guard:
+    // 1) There are cross-thread stores targeting shared/global memory (no
+    //    fragment stores in this branch; atomic_add and similar remain TODO).
+    // 2) The loop layout replicate extent is greater than 1, inferred from the
+    //    thread bounds captured in the layout.
+
+    [this, &store_shared_global_buffers, &store_fragment_buffers, &has_cross_thread_access, &T]() {
+      if (is_one(loop_layout_->ReplicateExtent())) return;
+      if (!has_cross_thread_access) return;
+
+      if (!store_fragment_buffers.empty()) {
+        ICHECK(store_shared_global_buffers.empty())
+            << "Invalid layout: cannot have both fragment and shared store buffers "
+               "in replicated loop layout.";
+        return;
+      } else {
+        // Now the store targets global or shared memory,
+        // or a T.call_extern / T.call_intrin ...
+        auto inv = loop_layout_->Inverse();
+        Array<PrimExpr> fwd;
+        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
+          fwd.push_back(0);
+        fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
+        auto rep = inv->Forward(fwd).back();
+        AddPredicate(EQ(rep, 0));
+      }
+    }();
   } else {
     return {};
   }
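
To make the guard concrete, here is a minimal standalone sketch (hypothetical numbers and names, not the TVM code above): a loop layout that spreads 64 loop elements over 128 threads has replicate extent 2, so two threads hold each fragment element, and only the replica-0 thread may perform the shared/global store.

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // Hypothetical flattened layout: thread t holds loop element t % kLanes
    // with replica index t / kLanes. This mirrors what the pass recovers from
    // the layout: the replica index of each thread.
    constexpr int kThreads = 128;  // thread extent bound to the parallel loop
    constexpr int kLanes = 64;     // distinct loop elements => replicate extent 2

    int ReplicaOf(int t) { return t / kLanes; }
    int LaneOf(int t) { return t % kLanes; }

    int main() {
      std::vector<float> shared(kLanes, 0.0f);  // shared/global destination
      std::vector<float> frag(kThreads, 1.0f);  // per-thread fragment value

      for (int t = 0; t < kThreads; ++t) {
        // Without the guard, both replicas of each lane would execute the
        // store, double-counting the update. This check plays the role of
        // AddPredicate(EQ(rep, 0)) in the pass.
        if (ReplicaOf(t) == 0)
          shared[LaneOf(t)] += frag[t];
      }

      for (int i = 0; i < kLanes; ++i)
        assert(shared[i] == 1.0f);  // each element updated exactly once
      std::printf("all %d elements stored exactly once\n", kLanes);
    }

In the real pass the replica index is not a simple division: it is recovered by inverting loop_layout_ and evaluating Forward on the thread index, which is why the committed code builds fwd from zeros plus InputPlaceholder(0) - T.thread_bounds->min and takes the last output as rep.
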
