@@ -307,8 +307,10 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   //   (const index frag_a interacts with non-const index frag_b)
   // - No propagation needed: shared_a[i] = frag_a[0]
   //   (const index frag_a with non-fragment buffer)
+
   bool allow_layout_propgate =
-      fragment_buffers.size() > const_index_fragment_buffer.size();
+      const_index_fragment_buffer.empty() ||
+      (fragment_buffers.size() > const_index_fragment_buffer.size());
 
   // Step 1: try to infer loop's partition from a source fragment
   Buffer source_buffer, read_source_buffer;
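
To make the relaxed condition concrete, here is a minimal standalone sketch (the free function and the counts are illustrative, not part of the patch): propagation is now also allowed when no fragment buffer is accessed through a constant index, instead of requiring a strict majority of non-const-index fragments.

```cpp
#include <cassert>
#include <cstddef>

// Mirrors the new allow_layout_propgate condition: propagate unless every
// fragment access is const-indexed with nothing else to infer from.
bool AllowLayoutPropagate(std::size_t fragment_buffers,
                          std::size_t const_index_fragments) {
  return const_index_fragments == 0 ||
         fragment_buffers > const_index_fragments;
}

int main() {
  assert(AllowLayoutPropagate(1, 0));  // no const-index access: always allowed
  assert(AllowLayoutPropagate(2, 1));  // frag_b[i] = frag_a[0]: frag_b drives inference
  assert(!AllowLayoutPropagate(1, 1)); // shared_a[i] = frag_a[0]: nothing to propagate
  return 0;
}
```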
@@ -361,7 +363,15 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       PrimExpr loop_var_to_thread =
           src_layout->ForwardThread(indice_map_[buffer], rep);
       loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
-
+      PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
+        if (auto opt_var = objref.as<Var>();
+            opt_var && inner_vars_.count(*opt_var)) {
+          std::ostringstream oss;
+          oss << "loop_var_to_thread = " << loop_var_to_thread
+              << " contains inner var " << *opt_var;
+          throw LayoutConflictException(oss.str());
+        }
+      });
       result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
                    ->BindThreadRange(T.thread_bounds);
     }
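
The new guard rejects a thread mapping that still mentions an inner (non-parallel) loop variable, instead of silently producing an invalid Fragment. A rough standalone equivalent, assuming TVM's `PostOrderVisit` and a stand-in exception type (the helper name and the pointer-keyed set are mine, not the patch's):

```cpp
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <sstream>
#include <stdexcept>
#include <unordered_set>

// Stand-in for tilelang's LayoutConflictException.
struct LayoutConflictException : std::runtime_error {
  using std::runtime_error::runtime_error;
};

// Throws if thread_expr still references any inner loop variable; such an
// expression cannot serve as a loop-var -> thread mapping.
void CheckNoInnerVars(
    const tvm::PrimExpr &thread_expr,
    const std::unordered_set<const tvm::tir::VarNode *> &inner_vars) {
  tvm::tir::PostOrderVisit(thread_expr, [&](const tvm::runtime::ObjectRef &objref) {
    if (const auto *var = objref.as<tvm::tir::VarNode>();
        var && inner_vars.count(var)) {
      std::ostringstream oss;
      oss << "loop_var_to_thread = " << thread_expr
          << " contains inner var " << var->name_hint;
      throw LayoutConflictException(oss.str());
    }
  });
}
```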
@@ -379,57 +389,46 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   if (source_buffer.defined() && allow_layout_propgate) {
     loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
   } else if (level == InferLevel::kFree) {
+    // For free layout inference:
+    // if replication exists and the buffer has cross-thread shared memory
+    // access, add a predicate
+    bool has_cross_thread_access = false;
+    PostOrderVisit(root_, [&](const ObjectRef &obj) {
+      if (const auto *store = obj.as<BufferStoreNode>()) {
+        // check if scope is shared or global
+        if (store->buffer.scope() == "shared" ||
+            store->buffer.scope() == "shared.dyn" ||
+            store->buffer.scope() == "global") {
+          has_cross_thread_access = true;
+        }
+      } else if (const auto *load = obj.as<BufferLoadNode>()) {
+        // check if scope is shared or global
+        if (load->buffer.scope() == "shared" ||
+            load->buffer.scope() == "shared.dyn" ||
+            load->buffer.scope() == "global") {
+          has_cross_thread_access = true;
+        }
+      }
+    });
+
+    // check if the loop body contains a "pure" buffer store (i.e., a direct
+    // assignment, not a compound update)
+    bool has_pure_buffer_store = false;
+    PostOrderVisit(root_, [&](const ObjectRef &obj) {
+      if (const auto *store = obj.as<BufferStoreNode>()) {
+        // check if the value is a direct load from another buffer
+        // (i.e., b[i] = a[i])
+        if (store->value.as<BufferLoadNode>()) {
+          has_pure_buffer_store = true;
+        }
+      }
+    });
+
     if (read_source_buffer.defined() && allow_layout_propgate) {
       loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
       // // Loop doesn't need to be replicated.
       // if (!is_one(loop_layout_->ReplicateExtent()))
       //   loop_layout_ = loop_layout_->DeReplicate();
-
-      // For free layout inference
-      // If replication exists and buffer has cross-thread shared memory access,
-      // add predicate
-      bool has_cross_thread_access = false;
-      PostOrderVisit(root_, [&](const ObjectRef &obj) {
-        if (const auto *store = obj.as<BufferStoreNode>()) {
-          // check if scope is shared or global
-          if (store->buffer.scope() == "shared" ||
-              store->buffer.scope() == "shared.dyn" ||
-              store->buffer.scope() == "global") {
-            has_cross_thread_access = true;
-          }
-        } else if (const auto *load = obj.as<BufferLoadNode>()) {
-          // check if scope is shared or global
-          if (load->buffer.scope() == "shared" ||
-              load->buffer.scope() == "shared.dyn" ||
-              load->buffer.scope() == "global") {
-            has_cross_thread_access = true;
-          }
-        }
-      });
-
-      // check if loop body contains a "pure" buffer store (i.e., direct
-      // assignment, not compound update)
-      bool has_pure_buffer_store = false;
-      PostOrderVisit(root_, [&](const ObjectRef &obj) {
-        if (const auto *store = obj.as<BufferStoreNode>()) {
-          // Check if the value is a direct load from another buffer (i.e., b[i]
-          // = a[i])
-          if (const auto *load = store->value.as<BufferLoadNode>()) {
-            has_pure_buffer_store = true;
-          }
-        }
-      });
-
-      if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
-          !has_pure_buffer_store) {
-        auto inv = loop_layout_->Inverse();
-        Array<PrimExpr> fwd;
-        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
-          fwd.push_back(0);
-        fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
-        auto rep = inv->Forward(fwd).back();
-        AddPredicate(EQ(rep, 0));
-      }
     }
 
     if (!loop_layout_.defined()) {
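
Hoisting the two scans out of the `read_source_buffer` branch means `has_cross_thread_access` and `has_pure_buffer_store` are computed once for the whole `kFree` path. The scope test both scans share is simple; a minimal sketch (the free function is mine, and `local.fragment` appears only as a typical non-cross-thread scope):

```cpp
#include <cassert>
#include <string>

// A store/load participates in cross-thread traffic when its buffer lives in
// a scope visible to more than one thread.
bool IsCrossThreadScope(const std::string &scope) {
  return scope == "shared" || scope == "shared.dyn" || scope == "global";
}

int main() {
  assert(IsCrossThreadScope("shared.dyn"));      // dynamic shared memory counts
  assert(!IsCrossThreadScope("local.fragment")); // per-thread fragments do not
  return 0;
}
```

The pure-store test is, on my reading, what exempts plain copies like `b[i] = a[i]` from the predicate added below: when the stored value is a direct `BufferLoad`, every replica writes the same value, so redundant writes are benign, whereas a compound update executed by several replicas would race.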
@@ -478,6 +477,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
                  << loop_layout_->DebugOutput() << '\n';
     }
+    if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
+        !has_pure_buffer_store) {
+      auto inv = loop_layout_->Inverse();
+      Array<PrimExpr> fwd;
+      for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
+        fwd.push_back(0);
+      fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
+      auto rep = inv->Forward(fwd).back();
+      AddPredicate(EQ(rep, 0));
+    }
   } else {
     return {};
   }
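
The predicate block itself moved below the fallback that defines `loop_layout_` (presumably via `PlanLoopPartition`, per the DLOG), so it now also applies when the layout was freshly planned rather than propagated from `read_source_buffer`. What it computes: with a replicate extent R > 1, each logical iteration runs on R threads, and inverting the layout with all loop outputs zeroed recovers this thread's replica index; `EQ(rep, 0)` then lets exactly one replica perform the cross-thread store. A numeric analogue (the modulo layout and the helper are invented for illustration; the patch does this symbolically via `Inverse()`/`Forward()`):

```cpp
#include <cassert>

// Hypothetical layout over 64 threads: iteration i maps to thread i % 32, so
// threads t and t + 32 both hold iteration i (replicate extent 2).
int ReplicaIndex(int tid, int thread_min) {
  int t = tid - thread_min; // normalize into the thread range, as the patch does
  return t / 32;            // inverse of the layout: which copy of iteration t % 32
}

int main() {
  assert(ReplicaIndex(5, 0) == 0);  // replica 0: performs the store
  assert(ReplicaIndex(37, 0) == 1); // replica 1: masked off by EQ(rep, 0)
  return 0;
}
```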