Merged
96 changes: 78 additions & 18 deletions src/op/parallel.cc
@@ -5,6 +5,7 @@

#include "parallel.h"

#include <algorithm>
#include <tvm/tir/op.h>

#include "../layout/utils.h"
@@ -413,22 +414,24 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,

// check if loop body contains a "pure" buffer store (i.e., direct
// assignment, not compound update)
bool has_pure_buffer_store = false;
std::vector<Buffer> store_shared_global_buffers, store_fragment_buffers;
// Buffers whose scope is above fragment level:
// global, shared, shared.dyn,
// which can be used to analyze the replication case.
PostOrderVisit(root_, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// Check if the value is a direct load from another buffer (i.e., b[i]
// = a[i])
if (const auto *load = store->value.as<BufferLoadNode>()) {
has_pure_buffer_store = true;
auto buffer = store->buffer;
if (buffer.scope() == "shared" || buffer.scope() == "shared.dyn" ||
buffer.scope() == "global") {
store_shared_global_buffers.emplace_back(buffer);
} else if (buffer.scope() == "local.fragment") {
store_fragment_buffers.emplace_back(buffer);
}
}
});

if (read_source_buffer.defined() && allow_layout_propgate) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
// // Loop doesn't need to be replicated.
// if (!is_one(loop_layout_->ReplicateExtent()))
// loop_layout_ = loop_layout_->DeReplicate();
}

if (!loop_layout_.defined()) {
@@ -477,16 +480,73 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
<< loop_layout_->DebugOutput() << '\n';
}
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
!has_pure_buffer_store) {
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}

// Lambda that guards replicated accesses:
// - When a loop layout replicates a fragment buffer (rep > 1), each thread
// observes the same fragment elements. Blindly storing to shared/global
// memory in that case would add the same value multiple times.
// - We therefore restrict the store so that only the replica with rep == 0
// performs the update (e.g. global[i] += fragment[i] only fires once).
// Trigger conditions for this guard:
// 1) There are cross-thread stores targeting shared/global memory (no
// fragment stores in this branch; atomic_add and similar remain TODO).
// 2) The loop layout replicate extent is greater than 1, inferred from the
// thread bounds captured in the layout.

[this, &store_shared_global_buffers, &store_fragment_buffers,
&has_cross_thread_access, &const_index_fragment_buffer, &T]() {
if (is_one(loop_layout_->ReplicateExtent()))
return;
if (!has_cross_thread_access)
return;

if (!store_fragment_buffers.empty()) {
// Iterate replicated fragment stores: when the fragment index is a
// constant (e.g. fragment[0]), every thread touches the same slot, so
// the rep == 0 predicate is unnecessary. Example:
//   for i in T.Parallel(...):
//     shared[i] = ...
//     fragment[0] = ...
bool replicate_is_from_dynamic_index_fragment = false;
for (const auto &fragment : store_fragment_buffers) {
if (!T.layout_map.count(fragment)) {
continue;
}

auto fragment_layout = T.layout_map[fragment].as<Fragment>().value();
if (is_one(fragment_layout->ReplicateExtent()))
continue;

if (analyzer_.CanProveEqual(fragment_layout->ReplicateExtent(),
loop_layout_->ReplicateExtent()))
continue;
if (std::find(const_index_fragment_buffer.begin(),
const_index_fragment_buffer.end(),
fragment) == const_index_fragment_buffer.end()) {
replicate_is_from_dynamic_index_fragment = true;
}
}

if (!replicate_is_from_dynamic_index_fragment)
return;

ICHECK(store_shared_global_buffers.empty())
<< "Invalid layout: cannot have both fragment and shared store "
"buffers "
"in replicated loop layout.";
return;
} else {
// Now the store targets global or shared memory,
// or a T.call_extern / T.call_intrin ...
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}
}();
} else {
return {};
}
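Note on the hunk above: when a loop layout replicates values across threads, every replica observes the same fragment element, so an unconditional store or accumulation into shared/global memory scales the result by the replicate extent; the added predicate lets only the replica with rep == 0 perform the update. A minimal standalone C++ sketch of that failure mode and the rep == 0 fix follows; rep_extent and fragment_value are made-up illustration values, not part of the PR.

// Sketch (not part of the PR): why replicated accesses need the rep == 0
// predicate. Assume a fragment value is replicated across rep_extent
// threads of the loop layout; the numbers below are illustrative only.
#include <cassert>

int main() {
  const int rep_extent = 4;      // assumed replicate extent of the loop layout
  const int fragment_value = 7;  // value every replica observes

  int unguarded = 0, guarded = 0;
  for (int rep = 0; rep < rep_extent; ++rep) {
    unguarded += fragment_value;  // every replica stores: result scaled by 4
    if (rep == 0)
      guarded += fragment_value;  // predicate EQ(rep, 0): stored exactly once
  }
  assert(unguarded == rep_extent * fragment_value);  // 28, the wrong answer
  assert(guarded == fragment_value);                 // 7, the intended answer
  return 0;
}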
5 changes: 2 additions & 3 deletions src/transform/lower_intrin.cc
@@ -160,9 +160,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
// == truncdiv(a + b*c, b) - c
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
// Skip analyzer simplification so we preserve straightforward div
// expressions.
PrimExpr offset_numerator = op->a + op->b * ceildiv;
PrimExpr offset_numerator =
analyzer_->Simplify(op->a + op->b * ceildiv);
return truncdiv(offset_numerator, op->b) - ceildiv;
}

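For context, the hunk above keeps the existing floordiv-to-truncdiv rewrite and only re-simplifies the shifted numerator. A small standalone check of the identity it relies on, floordiv(a, b) == truncdiv(a + b*c, b) - c with c = truncdiv((b - 1) - min, b) and min a lower bound on a, is sketched below; the test ranges are arbitrary values chosen for illustration, not taken from the PR.

// Sketch (not part of the PR): numerically checks the rewrite used above,
//   floordiv(a, b) == truncdiv(a + b*c, b) - c,  c = truncdiv((b-1) - min, b),
// assuming b > 0 and min <= a, so the shifted numerator is non-negative.
#include <cassert>
#include <cstdio>

int main() {
  const int min = -64;  // assumed constant lower bound on a (const_int_bound)
  for (int b = 1; b <= 7; ++b) {
    int c = ((b - 1) - min) / b;  // operands are non-negative, so '/' is truncdiv
    for (int a = min; a <= 64; ++a) {
      int floor_div = (a >= 0) ? a / b : -((-a + b - 1) / b);  // reference floordiv
      int rewritten = (a + b * c) / b - c;  // numerator >= 0, truncdiv == floordiv
      assert(floor_div == rewritten);
    }
  }
  std::puts("floordiv rewrite holds on the tested range");
  return 0;
}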