@@ -154,23 +154,27 @@ Block MakeReverseCacheStage(const BufferRegion& cache_region, ReverseCacheTouche
   for (const Range& range : cache_region->region) {
     if (touched_info->read) {
       read_access_indices.push_back(Substitute(range->min, var_map));
+      read_access_region.push_back(Range::FromMinExtent(read_access_indices.back(), Integer(1)));
     } else {
       write_access_indices.push_back(Substitute(range->min, var_map));
+      write_access_region.push_back(Range::FromMinExtent(write_access_indices.back(), Integer(1)));
     }
   }
   for (const IterVar& block_var : block_vars) {
     if (touched_info->read) {
       write_access_indices.push_back(block_var->var);
+      write_access_region.push_back(Range::FromMinExtent(write_access_indices.back(), Integer(1)));
     } else {
       read_access_indices.push_back(block_var->var);
+      read_access_region.push_back(Range::FromMinExtent(read_access_indices.back(), Integer(1)));
     }
   }

   // Create New Block
   Block block(
       /*iter_vars*/ std::move(block_vars),
-      /*reads=*/{},
-      /*writes=*/{},
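+      // Record the cache stage's exact access regions: one unit-extent range per access index.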
+      /*reads=*/{BufferRegion(info->read_buffer, read_access_region)},
+      /*writes=*/{BufferRegion(info->write_buffer, write_access_region)},
       /*name_hint*/ cache_region->buffer->name + "_" + storage_scope,
       /*body=*/
       BufferStore(info->write_buffer, BufferLoad(info->read_buffer, read_access_indices),
@@ -187,8 +191,8 @@ Block MakeReverseCacheStage(const BufferRegion& cache_region, ReverseCacheTouche
   // Create surrounding loops
   for (size_t i = loop_vars.size(); i >= 1; --i) {
     body = For(/*loop_var=*/loop_vars[i - 1],
-               /*min=*/touched_info->loop_ranges[i]->min,
-               /*extent=*/touched_info->loop_ranges[i]->extent,
+               /*min=*/touched_info->loop_ranges[i - 1]->min,
+               /*extent=*/touched_info->loop_ranges[i - 1]->extent,
                /*kind=*/ForKind::kSerial,
                /*body=*/body);
   }
@@ -487,6 +491,112 @@ class CacheLocDetector : public StmtVisitor {
   int loc_pos_{-1};
 };
 
+/*! \brief Mutator for ReverseCacheRead. */
+class ReverseCacheReadRewriter : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Rewrite the AST and add a reverse cache_read stage with the information provided.
+   * \param scope_sref The parent scope of this mutation.
+   * \param info The cache stage information.
+   * \param touched_info The reverse cache touched information.
+   * \return The new AST rooted at the original parent scope.
+   */
+  static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info,
+                      ReverseCacheTouchedInfo* touched_info) {
+    ReverseCacheReadRewriter rewriter(scope_sref, info, touched_info);
+    return rewriter(GetRef<Stmt>(scope_sref->stmt));
+  }
+
+ private:
+  explicit ReverseCacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info,
+                                    ReverseCacheTouchedInfo* touched_info)
+      : scope_sref_(scope_sref), info_(info) {
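+    // The cache buffer is indexed directly by the touched block vars, one index per var.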
+    for (const IterVar& iter_var : touched_info->block_vars) {
+      new_indices_.push_back(iter_var->var);
+    }
+  }
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt stmt = StmtMutator::VisitStmt_(loop);
+    // Check the insertion point
+    if (loop == info_->loc_sref->stmt) {
+      // Insert cache stage into the loop if it is the right place
+      ObjectPtr<ForNode> n = make_object<ForNode>(*stmt.as<ForNode>());
+      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+      stmt = Stmt(n);
+    }
+    return stmt;
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    Block old_stmt = GetRef<Block>(block);
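+    // Leave blocks that write the target buffer (its producers) untouched; only consumers are rewritten.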
+    if (block != scope_sref_->stmt &&
+        GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) {
+      return std::move(old_stmt);
+    }
+    // Mutate the body
+    Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
+    // Check the insertion point
+    if (block == info_->loc_sref->stmt) {
+      // Insert cache stage into the block if it is the right place
+      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
+      stmt = Block(n);
+    }
+    // Check if it is the block corresponding to the parent scope
+    if (block == scope_sref_->stmt) {
+      // If so, put buffer allocation on the parent scope
+      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+      n->alloc_buffers.push_back(info_->alloc);
+      stmt = Block(n);
+    } else {
+      // Otherwise, update read regions and match_buffers
+      Array<BufferRegion> reads;
+      for (const BufferRegion& buf_region : block->reads) {
+        if (buf_region->buffer.same_as(info_->read_buffer)) {
+          Array<Range> region;
+          for (const PrimExpr& index : new_indices_) {
+            region.push_back(Range::FromMinExtent(index, Integer(1)));
+          }
+          reads.push_back(BufferRegion(info_->write_buffer, region));
+        } else {
+          reads.push_back(buf_region);
+        }
+      }
+
+      CHECK_EQ(block->match_buffers.size(), 0) << "Not supported yet.";
+      if (!reads.same_as(block->reads)) {
+        ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+        n->reads = std::move(reads);
+        stmt = Block(n);
+      }
+    }
+    info_->block_reuse.Set(old_stmt, stmt);
+    return std::move(stmt);
+  }
+
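+  // Redirect opaque uses of the original buffer's data var to the cache buffer's data var.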
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    if (op == info_->read_buffer->data.get()) {
+      return info_->write_buffer->data;
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    if (load->buffer.same_as(info_->read_buffer)) {
+      ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
+      n->buffer = info_->write_buffer;
+      n->indices = new_indices_;
+      return PrimExpr(n);
+    }
+    return ExprMutator::VisitExpr_(load);
+  }
+
+  const StmtSRef& scope_sref_;
+  CacheStageInfo* info_;
+  Array<PrimExpr> new_indices_;
+};
+
 /*! \brief Mutator for CacheRead. */
 class CacheReadRewriter : public StmtExprMutator {
  public:
@@ -881,35 +991,25 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
   // Step 2. Create CacheStageInfo
   CacheStageInfo info;
   info.read_buffer = read_buffer;
-  // Create the corresponding buffer to be written, i.e. result of cache_read
-  info.write_buffer = WithScope(read_buffer, storage_scope);
-  // Create the corresponding buffer allocation
-  info.alloc = info.write_buffer;
   info.annotations = block->annotations;

   // Step 3. Update cache stage info.
-  BufferRegion cache_region{nullptr};
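+  // Take the cached region straight from this consumer block's read region (no RelaxBufferRegion here).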
+  Optional<BufferRegion> maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer);
+  ICHECK(maybe_region.defined()) << read_buffer
+                                 << " should appear in the block's read region: " << block->reads;
+  BufferRegion cache_region = maybe_region.value();
   if (Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) {
     // Case 1. The buffer is written inside the block.
     StmtSRef write_block_sref = _write_block_sref.value();
     const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block, write_block_sref);
     // Find the producing region
-    BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value();
     StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
-
     // Detect insert position
     CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info);
-    cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref);
   } else {
     // Case 2. The buffer is the input block for the scope.
     info.loc_sref = scope_sref;
     info.loc_pos = 0;
-    if (Optional<BufferRegion> region =
-            GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
-      cache_region = region.value();
-    } else {
-      cache_region = BufferRegion::FullRegion(read_buffer);
-    }
   }

   // Step 4. Create CacheTouchedInfo
@@ -918,6 +1018,7 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
   // Step 5. Update CacheTouchedInfo
   touched_info.read = true;
   VarCollector collector;
+  Array<PrimExpr> new_shape;
   for (const Range& range : cache_region->region) {
     collector(range->min);
   }
@@ -926,6 +1027,7 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
     IterVar block_var = block->iter_vars[i];
     if (collector.touched.count(block_var->var.get())) {
       touched_info.block_vars.push_back(block_var);
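+      // The cache is indexed by block_var itself, so its extent along this dim must cover dom->min + dom->extent.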
+      new_shape.push_back(block_var->dom->min + block_var->dom->extent);
       touched_info.iter_values.push_back(realize->iter_values[i]);
       collector(touched_info.iter_values.back());
     }
@@ -938,11 +1040,24 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
     }
   }

+  // Create write buffer.
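+  // Build the cache buffer by hand: a fresh data var in the requested storage scope and the reduced shape collected above.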
+  ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*read_buffer.get());
+  ObjectPtr<VarNode> new_var = make_object<VarNode>(*read_buffer->data.get());
+  const auto* ptr_type = TVM_TYPE_AS(ptr_type, read_buffer->data->type_annotation, PointerTypeNode);
+  new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope);
+  new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation);
+  new_buffer->name = read_buffer->name + "_" + storage_scope;
+  new_buffer->shape = new_shape;
+
+  info.write_buffer = Buffer(new_buffer);
+  info.alloc = info.write_buffer;
+
   // Step 6. Making new cache stage block and rewrite readers.
   Block cache_read_stage = MakeReverseCacheStage(/*cache_region=*/cache_region,
                                                  /*touched_info*/ &touched_info, /*info=*/&info,
                                                  /*storage_scope=*/storage_scope);
-  Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
+  Stmt new_scope = ReverseCacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info,
+                                                     /*touched_info=*/&touched_info);

   // Step 7. Replacing and updating flags.
   self->Replace(scope_sref, new_scope, info.block_reuse);
@@ -956,62 +1071,7 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re

 StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
                            const String& storage_scope) {
-  /*!
-   * Check:
-   * - The index is in the array of block reading region
-   * - There is only one block who write the buffer in the scope
-   *
-   * Mutate:
-   * - Allocate new cache buffer under the current scope.
-   * - Find the lowest ancestor of the block and ANY ONE of the producer blocks.
-   * - Copy the buffer with the consumed region.
-   */
-
-  // Step 0. Check the input storage scope.
-  CheckStorageScope(self, storage_scope);
-
-  // Step 1. Checking index, getting the target buffer and the parent scope
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
-  Buffer write_buffer =
-      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, /*is_write=*/true);
-  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
-
-  // Step 2. Creating CacheStageInfo
-  CacheStageInfo info;
-  info.read_buffer = WithScope(write_buffer, storage_scope);
-  // Create the corresponding buffer to be written, i.e. result of cache_write
-  info.write_buffer = write_buffer;
-  // Create the corresponding buffer allocation
-  info.alloc = info.read_buffer;
-  info.annotations = block->annotations;
-
-  // Step 3. Check the only writer block.
-  if (!IsHorizontalFuse(self)) {
-    ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get());
-  }
-
-  // Step 4. Find the producing region and insert position
-  BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value();
-  StmtSRef parent_sref = GetRef<StmtSRef>(block_sref->parent);
-  // Detect insert position
-  CacheLocDetector::Detect(self, block_sref, scope_sref, &info);
-  BufferRegion cache_region =
-      RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);
-
-  // Step 5. Making new cache stage block and rewrite readers.
-  Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
-                                           /*storage_scope=*/storage_scope);
-  Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref,
-                                               /*writer_block_sref=*/block_sref, /*info=*/&info);
-
-  // Step 6. Replacing and updating flags.
-  self->Replace(scope_sref, new_scope, info.block_reuse);
-  StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get());
-  BlockInfo& block_info = self->block_info[result_block_sref];
-  block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
-  block_info.region_cover = true;
-  block_info.scope->stage_pipeline = true;
-  return result_block_sref;
+  LOG(FATAL) << "Not implemented yet.";
 }

 /******** Instruction Registration ********/