@@ -272,6 +272,8 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
272272 };
273273 // Step 1: try to infer loop's partition from a source fragment
274274 Buffer source_buffer, read_source_buffer;
275+ Buffer replicated_write_buffer; // Backup: fully replicated write buffer
276+
275277 for (const auto &[buffer, indices] : indice_map_) {
276278 if (T.layout_map .count (buffer)) {
277279 // skip reducers with rep=ALL
@@ -280,9 +282,26 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
280282 continue ;
281283
282284 auto frag = T.layout_map [buffer].as <Fragment>().value ();
285+ bool is_fully_replicated = buffer_is_completed_replicated (buffer);
286+ bool is_reducer = reducer_info_map_.count (buffer->data );
283287
284288 if (buffer_is_write_.count (buffer)) {
285- source_buffer = buffer;
289+ // Allow fully replicated write buffers if they are reducers
290+ // (reducers need replication for correctness)
291+ if (!is_fully_replicated || is_reducer) {
292+ // Prefer non-replicated write buffers, but also allow replicated
293+ // reducers
294+ if (!source_buffer.defined () ||
295+ (!is_fully_replicated && source_buffer.defined ())) {
296+ source_buffer = buffer;
297+ }
298+ } else if (!replicated_write_buffer.defined ()) {
299+ // Keep fully replicated NON-reducer write buffer as backup
300+ replicated_write_buffer = buffer;
301+ DLOG (INFO) << " Found fully replicated non-reducer write buffer "
302+ << buffer << " as backup for loop layout inference"
303+ << ' \n ' ;
304+ }
286305 } else {
287306 // Keep the buffer with largest number of indices
288307 // (which means the inference based on that buffer is more accurate)
@@ -308,6 +327,18 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
308327 Fragment src_layout = T.layout_map [buffer].as <Fragment>().value ();
309328 DLOG (INFO) << " [compute_loop_layout_from_buffer] infer from buffer `"
310329 << buffer << " ` of layout " << src_layout->DebugOutput () << ' \n ' ;
330+
331+ // Check if this buffer is a reducer
332+ bool is_reducer = reducer_info_map_.count (buffer->data );
333+
334+ // Defensive check: warn if attempting to infer from fully replicated buffer
335+ // But allow it for reducers (they need replication for correctness)
336+ if (src_layout->IsCompletedReplicated () && !is_reducer) {
337+ DLOG (WARNING)
338+ << " Attempting to infer loop layout from fully replicated buffer "
339+ << buffer << " , this may cause incorrect replication propagation" ;
340+ }
341+
311342 Fragment result;
312343 if (IsCommonAccessIndice (buffer)) {
313344 result = src_layout;
@@ -318,6 +349,34 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
318349 PrimExpr loop_var_to_thread =
319350 src_layout->ForwardThread (indice_map_[buffer], rep);
320351 loop_var_to_thread = analyzer_.Simplify (loop_var_to_thread);
352+
353+ // Check if loop_var_to_thread only depends on rep (not on loop_vars)
354+ // This indicates a fully replicated buffer that shouldn't determine loop
355+ // layout UNLESS it's a reducer (reducers need full replication)
356+ bool uses_loop_var = false ;
357+ PostOrderVisit (loop_var_to_thread, [&](const ObjectRef &objref) {
358+ if (auto var = objref.as <Var>()) {
359+ for (const auto &loop_var : loop_vars_) {
360+ if (var->same_as (loop_var->var )) {
361+ uses_loop_var = true ;
362+ break ;
363+ }
364+ }
365+ }
366+ });
367+
368+ if (!uses_loop_var && !is_reducer) {
369+ // loop_var_to_thread only depends on rep, not on loop variables
370+ // And it's not a reducer, so this is likely an index offset case
371+ DLOG (WARNING) << " Buffer " << buffer
372+ << " is fully replicated (not a reducer). "
373+ << " Cannot use it to infer loop layout. "
374+ << " loop_var_to_thread = " << loop_var_to_thread
375+ << " (only depends on rep)" ;
376+ throw LayoutConflictException (" Buffer is fully replicated and cannot "
377+ " be used for layout inference" );
378+ }
379+
321380 PostOrderVisit (loop_var_to_thread, [&](const ObjectRef &objref) {
322381 if (auto opt_var = objref.as <Var>();
323382 opt_var && inner_vars_.count (*opt_var)) {
@@ -334,6 +393,13 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
334393 << result->DebugOutput () << ' \n ' ;
335394 return result;
336395 };
396+
397+ // Try to infer loop layout from buffers in order of preference:
398+ // 1. Non-replicated write buffer (most reliable)
399+ // 2. Non-replicated read buffer
400+ // 3. Fully replicated write buffer (backup, may cause issues)
401+ // 4. Free inference mode (no source buffer)
402+
337403 if (source_buffer.defined ()) {
338404 loop_layout_ = compute_loop_layout_from_buffer (source_buffer);
339405 } else if (level == InferLevel::kFree ) {
@@ -388,7 +454,25 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
388454 auto rep = inv->Forward (fwd).back ();
389455 AddPredicate (EQ (rep, 0 ));
390456 }
391- } else {
457+ } else if (replicated_write_buffer.defined ()) {
458+ // Backup: try to use fully replicated write buffer
459+ // This may cause replication propagation, but it's better than failing
460+ DLOG (WARNING) << " Using fully replicated buffer "
461+ << replicated_write_buffer
462+ << " for loop layout inference as no other source buffer "
463+ " is available" ;
464+ try {
465+ loop_layout_ = compute_loop_layout_from_buffer (replicated_write_buffer);
466+ } catch (const LayoutConflictException &e) {
467+ // If fails, fall back to free mode
468+ DLOG (WARNING) << " Failed to infer from replicated buffer: " << e.what ()
469+ << " . Falling back to free mode" ;
470+ replicated_write_buffer = Buffer (); // Clear to trigger free mode below
471+ }
472+ }
473+
474+ if (!loop_layout_.defined ()) {
475+ // No source buffer available, use free mode inference
392476 // Vectorize Size must be aware of the buffer_remap
393477 // As the pass will do post processing to the layout
394478 auto maybe_remapped_root_ =
0 commit comments