@@ -36,21 +36,18 @@ namespace tir {
3636class ScriptCompleter : public StmtMutator {
3737 public:
3838 explicit ScriptCompleter (Map<Var, Buffer>* buffer_var_map) : buffer_var_map_(buffer_var_map) {}
39- /* ! \brief Whether the stmt contains at least one block. */
40- bool contains_block = false ;
4139
4240 private:
4341 Map<Var, Buffer>* buffer_var_map_;
44- Stmt VisitStmt_ (const BlockRealizeNode* op) override {
45- contains_block = true ;
42+ Stmt VisitStmt_ (const BlockRealizeNode* op) final {
4643 for (const PrimExpr& value : op->iter_values ) {
4744 CHECK (value.dtype ().is_int ())
4845 << " BlockRealize iter_value expected a IntImm, but got " << value.dtype ();
4946 }
5047 return StmtMutator::VisitStmt_ (op);
5148 }
5249
53- Stmt VisitStmt_ (const BlockNode* op) override {
50+ Stmt VisitStmt_ (const BlockNode* op) final {
5451 // Buffers allocated in the block can be accessed by its body.
5552 for (const auto & alloc_buffer : op->alloc_buffers ) {
5653 buffer_var_map_->Set (alloc_buffer->data , alloc_buffer);
@@ -59,7 +56,12 @@ class ScriptCompleter : public StmtMutator {
5956 const Buffer& target_buffer = match_buffer->buffer ;
6057 buffer_var_map_->Set (target_buffer->data , target_buffer);
6158 }
59+
60+ bool is_root_block = this ->is_root_block_ ;
61+ this ->is_root_block_ = false ;
6262 Block block = Downcast<Block>(StmtMutator::VisitStmt_ (op));
63+ this ->is_root_block_ = is_root_block;
64+
6365 // Remove buffers allocated inside block to detect its access region
6466 for (const auto & alloc_buffer : op->alloc_buffers ) {
6567 buffer_var_map_->erase (alloc_buffer->data );
@@ -85,15 +87,19 @@ class ScriptCompleter : public StmtMutator {
8587 << " ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or "
8688 " direct access by buffer data. Please annotation the access region manually" ;
8789 auto n = CopyOnWrite (block.operator ->());
88- if (mask & 1 ) n->reads = reads;
89- if (mask & 2 ) n->writes = writes;
90+ if (!is_root_block) {
91+ if (mask & 1 ) n->reads = reads;
92+ if (mask & 2 ) n->writes = writes;
93+ }
9094 n->annotations = op->annotations ;
9195 n->annotations .erase (attr::script_parsing_detect_access);
9296 return Block (n);
9397 } else {
9498 return std::move (block);
9599 }
96100 }
101+
102+ bool is_root_block_ = true ;
97103};
98104
99105PrimFunc ScriptComplete (PrimFunc func, const Array<Buffer>& root_allocates) {
0 commit comments