@@ -413,22 +413,23 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
 
     // check if loop body contains a "pure" buffer store (i.e., direct
     // assignment, not compound update)
-    bool has_pure_buffer_store = false;
+    std::vector<Buffer> store_shared_global_buffers, store_fragment_buffers, store_buffers;
+    // Buffers whose scope is above fragment level:
+    // global, shared, shared.dyn.
+    // These can be used to analyze the replication case.
     PostOrderVisit(root_, [&](const ObjectRef &obj) {
       if (const auto *store = obj.as<BufferStoreNode>()) {
-        // Check if the value is a direct load from another buffer (i.e., b[i]
-        // = a[i])
-        if (const auto *load = store->value.as<BufferLoadNode>()) {
-          has_pure_buffer_store = true;
-        }
+        auto buffer = store->buffer;
+        if (buffer.scope() == "shared" || buffer.scope() == "shared.dyn" ||
+            buffer.scope() == "global") {
+          store_shared_global_buffers.emplace_back(buffer);
+        } else if (buffer.scope() == "local.fragment") {
+          store_fragment_buffers.emplace_back(buffer);
+        }
+        store_buffers.emplace_back(buffer);
       }
     });
 
     if (read_source_buffer.defined() && allow_layout_propgate) {
       loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
-      // // Loop don't need to be replicated.
-      // if (!is_one(loop_layout_->ReplicateExtent()))
-      //   loop_layout_ = loop_layout_->DeReplicate();
     }
 
     if (!loop_layout_.defined()) {
@@ -477,16 +478,40 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
                  << loop_layout_->DebugOutput() << '\n';
     }
-    if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
-        !has_pure_buffer_store) {
-      auto inv = loop_layout_->Inverse();
-      Array<PrimExpr> fwd;
-      for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
-        fwd.push_back(0);
-      fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
-      auto rep = inv->Forward(fwd).back();
-      AddPredicate(EQ(rep, 0));
-    }
+
+    // Lambda that guards replicated accesses:
+    // - When a loop layout replicates a fragment buffer (rep > 1), each thread
+    //   observes the same fragment elements. Blindly storing to shared/global
+    //   memory in that case would add the same value multiple times.
+    // - We therefore restrict the store so that only the replica with rep == 0
+    //   performs the update (e.g. global[i] += fragment[i] only fires once).
+    // Trigger conditions for this guard:
+    // 1) There are cross-thread stores targeting shared/global memory (no
+    //    fragment stores in this branch; atomic_add and similar remain TODO).
+    // 2) The loop layout replicate extent is greater than 1, inferred from the
+    //    thread bounds captured in the layout.
+
+    [this, &store_shared_global_buffers, &store_fragment_buffers,
+     &has_cross_thread_access, &T]() {
+      if (is_one(loop_layout_->ReplicateExtent())) return;
+      if (!has_cross_thread_access) return;
+
+      if (!store_fragment_buffers.empty()) {
+        ICHECK(store_shared_global_buffers.empty())
+            << "Invalid layout: cannot have both fragment and shared store "
+               "buffers in a replicated loop layout.";
+        return;
+      } else {
+        // Now the stores target global or shared memory,
+        // or go through T.call_extern / T.call_intrin ...
+        auto inv = loop_layout_->Inverse();
+        Array<PrimExpr> fwd;
+        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
+          fwd.push_back(0);
+        fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
+        auto rep = inv->Forward(fwd).back();
+        AddPredicate(EQ(rep, 0));
+      }
+    }();
   } else {
     return {};
   }
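
To make the guard concrete, here is a minimal standalone C++ sketch of the idea (not TVM/TileLang code; `accumulate`, `rep_extent`, and `guard_on_rep0` are illustrative names): when a loop layout replicates an iteration across `rep_extent` threads, an unguarded `global[i] += frag[i]` accumulates the same fragment value once per replica, while predicating the store on `rep == 0` keeps exactly one update, which is what `AddPredicate(EQ(rep, 0))` achieves in the pass.

```cpp
// Standalone sketch, not part of the pass: models why replicated cross-thread
// stores need the `rep == 0` predicate. All names here are illustrative.
#include <cassert>
#include <cstddef>
#include <vector>

// Each logical index i is materialized by `rep_extent` thread replicas that
// all observe the same fragment value. Without a guard, `global[i] += frag[i]`
// fires once per replica; with the guard, only replica 0 performs the update.
std::vector<int> accumulate(const std::vector<int> &frag, int rep_extent,
                            bool guard_on_rep0) {
  std::vector<int> global(frag.size(), 0);
  for (std::size_t i = 0; i < frag.size(); ++i) {
    for (int rep = 0; rep < rep_extent; ++rep) { // replicated copies of iteration i
      if (guard_on_rep0 && rep != 0)
        continue;           // the predicate the pass adds: EQ(rep, 0)
      global[i] += frag[i]; // non-atomic store to shared/global memory
    }
  }
  return global;
}

int main() {
  const std::vector<int> frag = {1, 2, 3};
  auto unguarded = accumulate(frag, /*rep_extent=*/4, /*guard_on_rep0=*/false);
  auto guarded = accumulate(frag, /*rep_extent=*/4, /*guard_on_rep0=*/true);
  assert(unguarded[0] == 4); // value duplicated 4x without the predicate
  assert(guarded[0] == 1);   // exactly one update with the predicate
  return 0;
}
```

The fragment-only branch in the diff returns without adding a predicate, presumably because replicated writes to `local.fragment` buffers stay private to each replica and therefore cannot double-count; only stores to memory shared across threads need the `rep == 0` restriction.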