Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 57 additions & 48 deletions src/op/parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,10 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// (const index frag_a interacts with non-const index frag_b)
// - No propagation needed: shared_a[i] = frag_a[0]
// (const index frag_a with non-fragment buffer)

bool allow_layout_propgate =
fragment_buffers.size() > const_index_fragment_buffer.size();
const_index_fragment_buffer.empty() ||
(fragment_buffers.size() > const_index_fragment_buffer.size());

// Step 1: try to infer loop's partition from a source fragment
Buffer source_buffer, read_source_buffer;
Expand Down Expand Up @@ -361,7 +363,15 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep);
loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);

PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
if (auto opt_var = objref.as<Var>();
opt_var && inner_vars_.count(*opt_var)) {
std::ostringstream oss;
oss << "loop_var_to_thread = " << loop_var_to_thread
<< "contains inner var" << *opt_var;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix spacing in error message.

The error message is missing a space before "contains", which will concatenate the expression value with the text.

Apply this diff:

-        << "contains inner var" << *opt_var;
+        << " contains inner var " << *opt_var;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
<< "contains inner var" << *opt_var;
<< " contains inner var " << *opt_var;
🤖 Prompt for AI Agents
In src/op/parallel.cc around line 369, the error message concatenates the
previous expression value with the string "contains inner var" due to a missing
leading space; update the string so there is a space before "contains" (e.g.,
change to " contains inner var") ensuring the formatted output separates the
expression value and the message.

throw LayoutConflictException(oss.str());
}
});
result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
->BindThreadRange(T.thread_bounds);
}
Expand All @@ -379,57 +389,46 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
if (source_buffer.defined() && allow_layout_propgate) {
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
} else if (level == InferLevel::kFree) {
// For free layout inference
// If replication exists and buffer has cross-thread shared memory access,
// add predicate
bool has_cross_thread_access = false;
PostOrderVisit(root_, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// check if scope is shared or global
if (store->buffer.scope() == "shared" ||
store->buffer.scope() == "shared.dyn" ||
store->buffer.scope() == "global") {
has_cross_thread_access = true;
}
} else if (const auto *load = obj.as<BufferLoadNode>()) {
// check if scope is shared or global
if (load->buffer.scope() == "shared" ||
load->buffer.scope() == "shared.dyn" ||
load->buffer.scope() == "global") {
has_cross_thread_access = true;
}
}
});

// check if loop body contains a "pure" buffer store (i.e., direct
// assignment, not compound update)
bool has_pure_buffer_store = false;
PostOrderVisit(root_, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// Check if the value is a direct load from another buffer (i.e., b[i]
// = a[i])
if (const auto *load = store->value.as<BufferLoadNode>()) {
has_pure_buffer_store = true;
}
}
});

if (read_source_buffer.defined() && allow_layout_propgate) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
// // Loop don't need to be replicated.
// if (!is_one(loop_layout_->ReplicateExtent()))
// loop_layout_ = loop_layout_->DeReplicate();

// For free layout inference
// If replication exists and buffer has cross-thread shared memory access,
// add predicate
bool has_cross_thread_access = false;
PostOrderVisit(root_, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// check if scope is shared or global
if (store->buffer.scope() == "shared" ||
store->buffer.scope() == "shared.dyn" ||
store->buffer.scope() == "global") {
has_cross_thread_access = true;
}
} else if (const auto *load = obj.as<BufferLoadNode>()) {
// check if scope is shared or global
if (load->buffer.scope() == "shared" ||
load->buffer.scope() == "shared.dyn" ||
load->buffer.scope() == "global") {
has_cross_thread_access = true;
}
}
});

// check if loop body contains a "pure" buffer store (i.e., direct
// assignment, not compound update)
bool has_pure_buffer_store = false;
PostOrderVisit(root_, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// Check if the value is a direct load from another buffer (i.e., b[i]
// = a[i])
if (const auto *load = store->value.as<BufferLoadNode>()) {
has_pure_buffer_store = true;
}
}
});

if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
!has_pure_buffer_store) {
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}
}

if (!loop_layout_.defined()) {
Expand Down Expand Up @@ -478,6 +477,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
<< loop_layout_->DebugOutput() << '\n';
}
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
!has_pure_buffer_store) {
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}
} else {
return {};
}
Expand Down
Loading