1111#include < tvm/tir/transform.h>
1212
1313#include < utility>
14+ #include < unordered_set>
1415
1516#include " ../op/builtin.h"
1617
@@ -153,20 +154,47 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
153154 }
154155
155156 std::unordered_set<const BufferNode *> consumer_used, producer_used;
157+ std::unordered_map<const BufferNode *, size_t > first_write_index;
158+ std::unordered_map<const BufferNode *, size_t > last_read_index;
156159 for (size_t i = 0 ; i < seq_stmt.size (); i++) {
157- if (roles[i] == Role::kProducer ) {
158- for (BufferRegion br : writes[i])
160+ // Warp-specialized lowering may tag stages as kBoth, so treat them as
161+ // producing and consuming to keep the liveness information intact.
162+ bool is_producer = roles[i] == Role::kProducer || roles[i] == Role::kBoth ;
163+ bool is_consumer = roles[i] == Role::kConsumer || roles[i] == Role::kBoth ;
164+ if (is_producer) {
165+ for (BufferRegion br : writes[i]) {
159166 producer_used.insert (br->buffer .get ());
160- } else {
161- for (BufferRegion br : reads[i])
167+ }
168+ }
169+ if (is_consumer) {
170+ for (BufferRegion br : reads[i]) {
162171 consumer_used.insert (br->buffer .get ());
172+ }
173+ }
174+ for (BufferRegion br : writes[i]) {
175+ const BufferNode *buf = br->buffer .get ();
176+ if (!first_write_index.count (buf)) {
177+ first_write_index[buf] = i;
178+ }
179+ }
180+ for (BufferRegion br : reads[i]) {
181+ last_read_index[br->buffer .get ()] = i;
163182 }
164183 }
165184 Array<Buffer> versioned_buffers;
166185 for (Buffer buffer : scoped_buffers) {
167186 if (consumer_used.count (buffer.get ()) &&
168187 producer_used.count (buffer.get ())) {
169188 versioned_buffers.push_back (buffer);
189+ continue ;
190+ }
191+ // Fallback: if we saw a write before a later read, the buffer spans
192+ // multiple stages even if role classification missed one side.
193+ auto it_w = first_write_index.find (buffer.get ());
194+ auto it_r = last_read_index.find (buffer.get ());
195+ if (it_w != first_write_index.end () && it_r != last_read_index.end () &&
196+ it_w->second < it_r->second ) {
197+ versioned_buffers.push_back (buffer);
170198 }
171199 }
172200 return versioned_buffers;
@@ -197,31 +225,112 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
197225 }
198226 }
199227 block.CopyOnWrite ()->alloc_buffers = std::move (alloc_buffers);
228+ // Record the updated alloc list to recover buffers whose LCA is the block.
229+ block_alloc_buffers_[op->block .get ()] = block->alloc_buffers ;
200230 block_realize.CopyOnWrite ()->block = block;
201231 return block_realize;
202232 }
203233
234+ Stmt VisitStmt_ (const BlockNode *op) final {
235+ stmt_stack_.push_back (op);
236+ Stmt stmt = StmtExprMutator::VisitStmt_ (op);
237+ stmt_stack_.pop_back ();
238+ return stmt;
239+ }
240+
204241 Stmt VisitStmt_ (const ForNode *op) final {
242+ stmt_stack_.push_back (op);
205243 loop_stack_.emplace_back (op->loop_var , op->extent );
206244 auto num_stages_anno = op->annotations .Get (" num_stages" );
207245 if (!num_stages_anno) {
208246 auto for_node = StmtExprMutator::VisitStmt_ (op);
209247 loop_stack_.pop_back ();
248+ stmt_stack_.pop_back ();
210249 return for_node;
211250 }
212251
213252 ICHECK (num_stages_anno->as <IntImmNode>());
214253 int num_stages = static_cast <int >(num_stages_anno->as <IntImmNode>()->value );
215254
216- const SeqStmtNode *pipeline_body_seq = op->body .as <SeqStmtNode>();
217- CHECK (pipeline_body_seq) << " ValueError: The body of the software pipeline "
218- " should be SeqStmt, got "
219- << op->body ->GetTypeKey ();
255+ Stmt pipeline_body_root{nullptr };
256+ if (const auto *realize = op->body .as <BlockRealizeNode>()) {
257+ const auto &block = realize->block ;
258+ for (const auto &buffer : block->alloc_buffers ) {
259+ ICHECK (buffer->IsInstance <BufferNode>());
260+ buffer_data_to_buffer_.Set (buffer->data , buffer);
261+ }
262+ pipeline_body_root = block->body ;
263+ } else {
264+ pipeline_body_root = op->body ;
265+ }
220266
221- Array<Buffer> scoped_buffers = {};
267+ const SeqStmtNode *pipeline_body_seq = nullptr ;
268+ {
269+ // Traverse trivial wrappers (let/if) to find the actual SeqStmt body.
270+ Stmt current = pipeline_body_root;
271+ while (true ) {
272+ if (const auto *seq_stmt = current.as <SeqStmtNode>()) {
273+ pipeline_body_seq = seq_stmt;
274+ break ;
275+ }
276+ if (const auto *if_then_else = current.as <IfThenElseNode>()) {
277+ ICHECK (!if_then_else->else_case .defined ())
278+ << " MultiVersionBuffer: Can't handle the body of the loop "
279+ " because the IfThenElse node has an else branch" ;
280+ current = if_then_else->then_case ;
281+ continue ;
282+ }
283+ if (const auto *let_stmt = current.as <LetStmtNode>()) {
284+ current = let_stmt->body ;
285+ continue ;
286+ }
287+ LOG (FATAL)
288+ << " MultiVersionBuffer: Can't handle the body of the loop because "
289+ << " it is not a SeqStmt, IfThenElse without else, "
290+ << " or LetStmt wrapping them, but got "
291+ << current->GetTypeKey ();
292+ }
293+ }
294+ ICHECK (pipeline_body_seq != nullptr );
295+
296+ Array<Buffer> scoped_buffers;
297+ std::unordered_set<const BufferNode *> seen;
222298 for (auto [buffer, stmt] : buffer_lca_) {
223- if (stmt.defined () && stmt.value ().get () == op)
299+ if (!stmt.defined ())
300+ continue ;
301+ const StmtNode *lca = stmt.value ().get ();
302+ bool in_scope = false ;
303+ for (const StmtNode *ancestor : stmt_stack_) {
304+ if (ancestor == lca) {
305+ in_scope = true ;
306+ break ;
307+ }
308+ }
309+ if (!in_scope)
310+ continue ;
311+ // Only double-buffer shared allocations; locals do not need versioning.
312+ auto scope = buffer.scope ();
313+ if (!(scope == " shared" || scope == " shared.dyn" ))
314+ continue ;
315+ if (seen.insert (buffer.get ()).second ) {
224316 scoped_buffers.push_back (buffer);
317+ }
318+ }
319+ for (auto it = stmt_stack_.rbegin (); it != stmt_stack_.rend (); ++it) {
320+ if (!(*it)->IsInstance <BlockNode>())
321+ continue ;
322+ const auto *block = static_cast <const BlockNode *>(*it);
323+ auto map_it = block_alloc_buffers_.find (block);
324+ if (map_it == block_alloc_buffers_.end ())
325+ continue ;
326+ for (const Buffer &buffer : map_it->second ) {
327+ auto scope = buffer.scope ();
328+ if (!(scope == " shared" || scope == " shared.dyn" ))
329+ continue ;
330+ if (seen.insert (buffer.get ()).second ) {
331+ scoped_buffers.push_back (buffer);
332+ }
333+ }
225334 }
226335
227336 Array<Buffer> versioned_buffers =
@@ -240,6 +349,7 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
240349 version_index_ = FloorMod (linear_index, num_stages);
241350 auto for_node = StmtExprMutator::VisitStmt_ (op);
242351 loop_stack_.pop_back ();
352+ stmt_stack_.pop_back ();
243353
244354 return for_node;
245355 }
@@ -312,9 +422,13 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
312422
313423 PrimExpr version_index_;
314424 std::vector<std::pair<Var, PrimExpr>> loop_stack_;
425+ // Track ancestor statements to query whether an LCA is inside the current loop.
426+ std::vector<const StmtNode *> stmt_stack_;
315427 Map<Var, Buffer> buffer_data_to_buffer_;
316428 Map<Buffer, Optional<Stmt>> buffer_lca_;
317429 Map<Buffer, Buffer> buffer_remap_;
430+ // Remember each block's alloc list so the loop can see buffers defined in parents.
431+ std::unordered_map<const BlockNode *, Array<Buffer>> block_alloc_buffers_;
318432};
319433
320434using namespace tir ::transform;
0 commit comments