Skip to content

Commit eed320f

Browse files
[Bugfix] Recover code for flexible parallel (tile-ai#1032)
* Recover the flexible-parallel processing code
* Lint fix

Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
1 parent 1e8f0b1 commit eed320f

File tree

1 file changed

+57
-48
lines changed

1 file changed

+57
-48
lines changed

src/op/parallel.cc

Lines changed: 57 additions and 48 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,10 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
307307
// (const index frag_a interacts with non-const index frag_b)
308308
// - No propagation needed: shared_a[i] = frag_a[0]
309309
// (const index frag_a with non-fragment buffer)
310+
310311
bool allow_layout_propgate =
311-
fragment_buffers.size() > const_index_fragment_buffer.size();
312+
const_index_fragment_buffer.empty() ||
313+
(fragment_buffers.size() > const_index_fragment_buffer.size());
312314

313315
// Step 1: try to infer loop's partition from a source fragment
314316
Buffer source_buffer, read_source_buffer;
@@ -361,7 +363,15 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
361363
PrimExpr loop_var_to_thread =
362364
src_layout->ForwardThread(indice_map_[buffer], rep);
363365
loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
364-
366+
PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
367+
if (auto opt_var = objref.as<Var>();
368+
opt_var && inner_vars_.count(*opt_var)) {
369+
std::ostringstream oss;
370+
oss << "loop_var_to_thread = " << loop_var_to_thread
371+
<< "contains inner var" << *opt_var;
372+
throw LayoutConflictException(oss.str());
373+
}
374+
});
365375
result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
366376
->BindThreadRange(T.thread_bounds);
367377
}
@@ -379,57 +389,46 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
379389
if (source_buffer.defined() && allow_layout_propgate) {
380390
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
381391
} else if (level == InferLevel::kFree) {
392+
// For free layout inference
393+
// If replication exists and buffer has cross-thread shared memory access,
394+
// add predicate
395+
bool has_cross_thread_access = false;
396+
PostOrderVisit(root_, [&](const ObjectRef &obj) {
397+
if (const auto *store = obj.as<BufferStoreNode>()) {
398+
// check if scope is shared or global
399+
if (store->buffer.scope() == "shared" ||
400+
store->buffer.scope() == "shared.dyn" ||
401+
store->buffer.scope() == "global") {
402+
has_cross_thread_access = true;
403+
}
404+
} else if (const auto *load = obj.as<BufferLoadNode>()) {
405+
// check if scope is shared or global
406+
if (load->buffer.scope() == "shared" ||
407+
load->buffer.scope() == "shared.dyn" ||
408+
load->buffer.scope() == "global") {
409+
has_cross_thread_access = true;
410+
}
411+
}
412+
});
413+
414+
// check if loop body contains a "pure" buffer store (i.e., direct
415+
// assignment, not compound update)
416+
bool has_pure_buffer_store = false;
417+
PostOrderVisit(root_, [&](const ObjectRef &obj) {
418+
if (const auto *store = obj.as<BufferStoreNode>()) {
419+
// Check if the value is a direct load from another buffer (i.e., b[i]
420+
// = a[i])
421+
if (const auto *load = store->value.as<BufferLoadNode>()) {
422+
has_pure_buffer_store = true;
423+
}
424+
}
425+
});
426+
382427
if (read_source_buffer.defined() && allow_layout_propgate) {
383428
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
384429
// // Loop don't need to be replicated.
385430
// if (!is_one(loop_layout_->ReplicateExtent()))
386431
// loop_layout_ = loop_layout_->DeReplicate();
387-
388-
// For free layout inference
389-
// If replication exists and buffer has cross-thread shared memory access,
390-
// add predicate
391-
bool has_cross_thread_access = false;
392-
PostOrderVisit(root_, [&](const ObjectRef &obj) {
393-
if (const auto *store = obj.as<BufferStoreNode>()) {
394-
// check if scope is shared or global
395-
if (store->buffer.scope() == "shared" ||
396-
store->buffer.scope() == "shared.dyn" ||
397-
store->buffer.scope() == "global") {
398-
has_cross_thread_access = true;
399-
}
400-
} else if (const auto *load = obj.as<BufferLoadNode>()) {
401-
// check if scope is shared or global
402-
if (load->buffer.scope() == "shared" ||
403-
load->buffer.scope() == "shared.dyn" ||
404-
load->buffer.scope() == "global") {
405-
has_cross_thread_access = true;
406-
}
407-
}
408-
});
409-
410-
// check if loop body contains a "pure" buffer store (i.e., direct
411-
// assignment, not compound update)
412-
bool has_pure_buffer_store = false;
413-
PostOrderVisit(root_, [&](const ObjectRef &obj) {
414-
if (const auto *store = obj.as<BufferStoreNode>()) {
415-
// Check if the value is a direct load from another buffer (i.e., b[i]
416-
// = a[i])
417-
if (const auto *load = store->value.as<BufferLoadNode>()) {
418-
has_pure_buffer_store = true;
419-
}
420-
}
421-
});
422-
423-
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
424-
!has_pure_buffer_store) {
425-
auto inv = loop_layout_->Inverse();
426-
Array<PrimExpr> fwd;
427-
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
428-
fwd.push_back(0);
429-
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
430-
auto rep = inv->Forward(fwd).back();
431-
AddPredicate(EQ(rep, 0));
432-
}
433432
}
434433

435434
if (!loop_layout_.defined()) {
@@ -478,6 +477,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
478477
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
479478
<< loop_layout_->DebugOutput() << '\n';
480479
}
480+
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
481+
!has_pure_buffer_store) {
482+
auto inv = loop_layout_->Inverse();
483+
Array<PrimExpr> fwd;
484+
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
485+
fwd.push_back(0);
486+
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
487+
auto rep = inv->Forward(fwd).back();
488+
AddPredicate(EQ(rep, 0));
489+
}
481490
} else {
482491
return {};
483492
}

0 commit comments

Comments (0)