Skip to content

Commit 6464a9d

Browse files
committed
updt
1 parent 5c4533f commit 6464a9d

File tree

1 file changed

+86
-2
lines changed

1 file changed

+86
-2
lines changed

src/op/parallel.cc

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
272272
};
273273
// Step 1: try to infer loop's partition from a source fragment
274274
Buffer source_buffer, read_source_buffer;
275+
Buffer replicated_write_buffer; // Backup: fully replicated write buffer
276+
275277
for (const auto &[buffer, indices] : indice_map_) {
276278
if (T.layout_map.count(buffer)) {
277279
// skip reducers with rep=ALL
@@ -280,9 +282,26 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
280282
continue;
281283

282284
auto frag = T.layout_map[buffer].as<Fragment>().value();
285+
bool is_fully_replicated = buffer_is_completed_replicated(buffer);
286+
bool is_reducer = reducer_info_map_.count(buffer->data);
283287

284288
if (buffer_is_write_.count(buffer)) {
285-
source_buffer = buffer;
289+
// Allow fully replicated write buffers if they are reducers
290+
// (reducers need replication for correctness)
291+
if (!is_fully_replicated || is_reducer) {
292+
// Prefer non-replicated write buffers, but also allow replicated
293+
// reducers
294+
if (!source_buffer.defined() ||
295+
(!is_fully_replicated && source_buffer.defined())) {
296+
source_buffer = buffer;
297+
}
298+
} else if (!replicated_write_buffer.defined()) {
299+
// Keep fully replicated NON-reducer write buffer as backup
300+
replicated_write_buffer = buffer;
301+
DLOG(INFO) << "Found fully replicated non-reducer write buffer "
302+
<< buffer << " as backup for loop layout inference"
303+
<< '\n';
304+
}
286305
} else {
287306
// Keep the buffer with largest number of indices
288307
// (which means the inference based on that buffer is more accurate)
@@ -308,6 +327,18 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
308327
Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
309328
DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
310329
<< buffer << "` of layout " << src_layout->DebugOutput() << '\n';
330+
331+
// Check if this buffer is a reducer
332+
bool is_reducer = reducer_info_map_.count(buffer->data);
333+
334+
// Defensive check: warn if attempting to infer from fully replicated buffer
335+
// But allow it for reducers (they need replication for correctness)
336+
if (src_layout->IsCompletedReplicated() && !is_reducer) {
337+
DLOG(WARNING)
338+
<< "Attempting to infer loop layout from fully replicated buffer "
339+
<< buffer << ", this may cause incorrect replication propagation";
340+
}
341+
311342
Fragment result;
312343
if (IsCommonAccessIndice(buffer)) {
313344
result = src_layout;
@@ -318,6 +349,34 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
318349
PrimExpr loop_var_to_thread =
319350
src_layout->ForwardThread(indice_map_[buffer], rep);
320351
loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
352+
353+
// Check if loop_var_to_thread only depends on rep (not on loop_vars)
354+
// This indicates a fully replicated buffer that shouldn't determine loop
355+
// layout UNLESS it's a reducer (reducers need full replication)
356+
bool uses_loop_var = false;
357+
PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
358+
if (auto var = objref.as<Var>()) {
359+
for (const auto &loop_var : loop_vars_) {
360+
if (var->same_as(loop_var->var)) {
361+
uses_loop_var = true;
362+
break;
363+
}
364+
}
365+
}
366+
});
367+
368+
if (!uses_loop_var && !is_reducer) {
369+
// loop_var_to_thread only depends on rep, not on loop variables
370+
// And it's not a reducer, so this is likely an index offset case
371+
DLOG(WARNING) << "Buffer " << buffer
372+
<< " is fully replicated (not a reducer). "
373+
<< "Cannot use it to infer loop layout. "
374+
<< "loop_var_to_thread = " << loop_var_to_thread
375+
<< " (only depends on rep)";
376+
throw LayoutConflictException("Buffer is fully replicated and cannot "
377+
"be used for layout inference");
378+
}
379+
321380
PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
322381
if (auto opt_var = objref.as<Var>();
323382
opt_var && inner_vars_.count(*opt_var)) {
@@ -334,6 +393,13 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
334393
<< result->DebugOutput() << '\n';
335394
return result;
336395
};
396+
397+
// Try to infer loop layout from buffers in order of preference:
398+
// 1. Non-replicated write buffer (most reliable)
399+
// 2. Non-replicated read buffer
400+
// 3. Fully replicated write buffer (backup, may cause issues)
401+
// 4. Free inference mode (no source buffer)
402+
337403
if (source_buffer.defined()) {
338404
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
339405
} else if (level == InferLevel::kFree) {
@@ -388,7 +454,25 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
388454
auto rep = inv->Forward(fwd).back();
389455
AddPredicate(EQ(rep, 0));
390456
}
391-
} else {
457+
} else if (replicated_write_buffer.defined()) {
458+
// Backup: try to use fully replicated write buffer
459+
// This may cause replication propagation, but it's better than failing
460+
DLOG(WARNING) << "Using fully replicated buffer "
461+
<< replicated_write_buffer
462+
<< " for loop layout inference as no other source buffer "
463+
"is available";
464+
try {
465+
loop_layout_ = compute_loop_layout_from_buffer(replicated_write_buffer);
466+
} catch (const LayoutConflictException &e) {
467+
// If fails, fall back to free mode
468+
DLOG(WARNING) << "Failed to infer from replicated buffer: " << e.what()
469+
<< ". Falling back to free mode";
470+
replicated_write_buffer = Buffer(); // Clear to trigger free mode below
471+
}
472+
}
473+
474+
if (!loop_layout_.defined()) {
475+
// No source buffer available, use free mode inference
392476
// Vectorize Size must be aware of the buffer_remap
393477
// As the pass will do post processing to the layout
394478
auto maybe_remapped_root_ =

0 commit comments

Comments
 (0)