Skip to content

Commit d0648e5

Browse files
committed
fix for wgmma pipeline with let binding
1 parent 0da83a8 commit d0648e5

File tree

1 file changed

+124
-10
lines changed

1 file changed

+124
-10
lines changed

src/transform/multi_version_buffer_rewriter.cc

Lines changed: 124 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <tvm/tir/transform.h>
1212

1313
#include <utility>
14+
#include <unordered_set>
1415

1516
#include "../op/builtin.h"
1617

@@ -153,20 +154,47 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
153154
}
154155

155156
std::unordered_set<const BufferNode *> consumer_used, producer_used;
157+
std::unordered_map<const BufferNode *, size_t> first_write_index;
158+
std::unordered_map<const BufferNode *, size_t> last_read_index;
156159
for (size_t i = 0; i < seq_stmt.size(); i++) {
157-
if (roles[i] == Role::kProducer) {
158-
for (BufferRegion br : writes[i])
160+
// Warp-specialized lowering may tag stages as kBoth, so treat them as
161+
// producing and consuming to keep the liveness information intact.
162+
bool is_producer = roles[i] == Role::kProducer || roles[i] == Role::kBoth;
163+
bool is_consumer = roles[i] == Role::kConsumer || roles[i] == Role::kBoth;
164+
if (is_producer) {
165+
for (BufferRegion br : writes[i]) {
159166
producer_used.insert(br->buffer.get());
160-
} else {
161-
for (BufferRegion br : reads[i])
167+
}
168+
}
169+
if (is_consumer) {
170+
for (BufferRegion br : reads[i]) {
162171
consumer_used.insert(br->buffer.get());
172+
}
173+
}
174+
for (BufferRegion br : writes[i]) {
175+
const BufferNode *buf = br->buffer.get();
176+
if (!first_write_index.count(buf)) {
177+
first_write_index[buf] = i;
178+
}
179+
}
180+
for (BufferRegion br : reads[i]) {
181+
last_read_index[br->buffer.get()] = i;
163182
}
164183
}
165184
Array<Buffer> versioned_buffers;
166185
for (Buffer buffer : scoped_buffers) {
167186
if (consumer_used.count(buffer.get()) &&
168187
producer_used.count(buffer.get())) {
169188
versioned_buffers.push_back(buffer);
189+
continue;
190+
}
191+
// Fallback: if we saw a write before a later read, the buffer spans
192+
// multiple stages even if role classification missed one side.
193+
auto it_w = first_write_index.find(buffer.get());
194+
auto it_r = last_read_index.find(buffer.get());
195+
if (it_w != first_write_index.end() && it_r != last_read_index.end() &&
196+
it_w->second < it_r->second) {
197+
versioned_buffers.push_back(buffer);
170198
}
171199
}
172200
return versioned_buffers;
@@ -197,31 +225,112 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
197225
}
198226
}
199227
block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
228+
// Record the updated alloc list to recover buffers whose LCA is the block.
229+
block_alloc_buffers_[op->block.get()] = block->alloc_buffers;
200230
block_realize.CopyOnWrite()->block = block;
201231
return block_realize;
202232
}
203233

234+
Stmt VisitStmt_(const BlockNode *op) final {
235+
stmt_stack_.push_back(op);
236+
Stmt stmt = StmtExprMutator::VisitStmt_(op);
237+
stmt_stack_.pop_back();
238+
return stmt;
239+
}
240+
204241
Stmt VisitStmt_(const ForNode *op) final {
242+
stmt_stack_.push_back(op);
205243
loop_stack_.emplace_back(op->loop_var, op->extent);
206244
auto num_stages_anno = op->annotations.Get("num_stages");
207245
if (!num_stages_anno) {
208246
auto for_node = StmtExprMutator::VisitStmt_(op);
209247
loop_stack_.pop_back();
248+
stmt_stack_.pop_back();
210249
return for_node;
211250
}
212251

213252
ICHECK(num_stages_anno->as<IntImmNode>());
214253
int num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
215254

216-
const SeqStmtNode *pipeline_body_seq = op->body.as<SeqStmtNode>();
217-
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
218-
"should be SeqStmt, got "
219-
<< op->body->GetTypeKey();
255+
Stmt pipeline_body_root{nullptr};
256+
if (const auto *realize = op->body.as<BlockRealizeNode>()) {
257+
const auto &block = realize->block;
258+
for (const auto &buffer : block->alloc_buffers) {
259+
ICHECK(buffer->IsInstance<BufferNode>());
260+
buffer_data_to_buffer_.Set(buffer->data, buffer);
261+
}
262+
pipeline_body_root = block->body;
263+
} else {
264+
pipeline_body_root = op->body;
265+
}
220266

221-
Array<Buffer> scoped_buffers = {};
267+
const SeqStmtNode *pipeline_body_seq = nullptr;
268+
{
269+
// Traverse trivial wrappers (let/if) to find the actual SeqStmt body.
270+
Stmt current = pipeline_body_root;
271+
while (true) {
272+
if (const auto *seq_stmt = current.as<SeqStmtNode>()) {
273+
pipeline_body_seq = seq_stmt;
274+
break;
275+
}
276+
if (const auto *if_then_else = current.as<IfThenElseNode>()) {
277+
ICHECK(!if_then_else->else_case.defined())
278+
<< "MultiVersionBuffer: Can't handle the body of the loop "
279+
"because the IfThenElse node has an else branch";
280+
current = if_then_else->then_case;
281+
continue;
282+
}
283+
if (const auto *let_stmt = current.as<LetStmtNode>()) {
284+
current = let_stmt->body;
285+
continue;
286+
}
287+
LOG(FATAL)
288+
<< "MultiVersionBuffer: Can't handle the body of the loop because "
289+
<< "it is not a SeqStmt, IfThenElse without else, "
290+
<< "or LetStmt wrapping them, but got "
291+
<< current->GetTypeKey();
292+
}
293+
}
294+
ICHECK(pipeline_body_seq != nullptr);
295+
296+
Array<Buffer> scoped_buffers;
297+
std::unordered_set<const BufferNode *> seen;
222298
for (auto [buffer, stmt] : buffer_lca_) {
223-
if (stmt.defined() && stmt.value().get() == op)
299+
if (!stmt.defined())
300+
continue;
301+
const StmtNode *lca = stmt.value().get();
302+
bool in_scope = false;
303+
for (const StmtNode *ancestor : stmt_stack_) {
304+
if (ancestor == lca) {
305+
in_scope = true;
306+
break;
307+
}
308+
}
309+
if (!in_scope)
310+
continue;
311+
// Only double-buffer shared allocations; locals do not need versioning.
312+
auto scope = buffer.scope();
313+
if (!(scope == "shared" || scope == "shared.dyn"))
314+
continue;
315+
if (seen.insert(buffer.get()).second) {
224316
scoped_buffers.push_back(buffer);
317+
}
318+
}
319+
for (auto it = stmt_stack_.rbegin(); it != stmt_stack_.rend(); ++it) {
320+
if (!(*it)->IsInstance<BlockNode>())
321+
continue;
322+
const auto *block = static_cast<const BlockNode *>(*it);
323+
auto map_it = block_alloc_buffers_.find(block);
324+
if (map_it == block_alloc_buffers_.end())
325+
continue;
326+
for (const Buffer &buffer : map_it->second) {
327+
auto scope = buffer.scope();
328+
if (!(scope == "shared" || scope == "shared.dyn"))
329+
continue;
330+
if (seen.insert(buffer.get()).second) {
331+
scoped_buffers.push_back(buffer);
332+
}
333+
}
225334
}
226335

227336
Array<Buffer> versioned_buffers =
@@ -240,6 +349,7 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
240349
version_index_ = FloorMod(linear_index, num_stages);
241350
auto for_node = StmtExprMutator::VisitStmt_(op);
242351
loop_stack_.pop_back();
352+
stmt_stack_.pop_back();
243353

244354
return for_node;
245355
}
@@ -312,9 +422,13 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
312422

313423
PrimExpr version_index_;
314424
std::vector<std::pair<Var, PrimExpr>> loop_stack_;
425+
// Track ancestor statements to query whether an LCA is inside the current loop.
426+
std::vector<const StmtNode *> stmt_stack_;
315427
Map<Var, Buffer> buffer_data_to_buffer_;
316428
Map<Buffer, Optional<Stmt>> buffer_lca_;
317429
Map<Buffer, Buffer> buffer_remap_;
430+
// Remember each block's alloc list so the loop can see buffers defined in parents.
431+
std::unordered_map<const BlockNode *, Array<Buffer>> block_alloc_buffers_;
318432
};
319433

320434
using namespace tir::transform;

0 commit comments

Comments
 (0)