Skip to content

Commit 20929fa

Browse files
authored
[Feature] Reverse cache-read/write (Part 1/2) (#43)
* init * upd test * place contiguous read blocks correctly
1 parent 99e3799 commit 20929fa

File tree

2 files changed

+137
-76
lines changed

2 files changed

+137
-76
lines changed

src/tir/schedule/primitive/cache_read_write.cc

Lines changed: 135 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -154,23 +154,27 @@ Block MakeReverseCacheStage(const BufferRegion& cache_region, ReverseCacheTouche
154154
for (const Range& range : cache_region->region) {
155155
if (touched_info->read) {
156156
read_access_indices.push_back(Substitute(range->min, var_map));
157+
read_access_region.push_back(Range::FromMinExtent(read_access_indices.back(), Integer(1)));
157158
} else {
158159
write_access_indices.push_back(Substitute(range->min, var_map));
160+
write_access_region.push_back(Range::FromMinExtent(write_access_indices.back(), Integer(1)));
159161
}
160162
}
161163
for (const IterVar& block_var : block_vars) {
162164
if (touched_info->read) {
163165
write_access_indices.push_back(block_var->var);
166+
write_access_region.push_back(Range::FromMinExtent(write_access_indices.back(), Integer(1)));
164167
} else {
165168
read_access_indices.push_back(block_var->var);
169+
read_access_region.push_back(Range::FromMinExtent(read_access_indices.back(), Integer(1)));
166170
}
167171
}
168172

169173
// Create New Block
170174
Block block(
171175
/*iter_vars*/ std::move(block_vars),
172-
/*reads=*/{},
173-
/*writes=*/{},
176+
/*reads=*/{BufferRegion(info->read_buffer, read_access_region)},
177+
/*writes=*/{BufferRegion(info->write_buffer, write_access_region)},
174178
/*name_hint*/ cache_region->buffer->name + "_" + storage_scope,
175179
/*body=*/
176180
BufferStore(info->write_buffer, BufferLoad(info->read_buffer, read_access_indices),
@@ -187,8 +191,8 @@ Block MakeReverseCacheStage(const BufferRegion& cache_region, ReverseCacheTouche
187191
// Create surrounding loops
188192
for (size_t i = loop_vars.size(); i >= 1; --i) {
189193
body = For(/*loop_var=*/loop_vars[i - 1],
190-
/*min=*/touched_info->loop_ranges[i]->min,
191-
/*extent=*/touched_info->loop_ranges[i]->extent,
194+
/*min=*/touched_info->loop_ranges[i - 1]->min,
195+
/*extent=*/touched_info->loop_ranges[i - 1]->extent,
192196
/*kind=*/ForKind::kSerial,
193197
/*body=*/body);
194198
}
@@ -487,6 +491,112 @@ class CacheLocDetector : public StmtVisitor {
487491
int loc_pos_{-1};
488492
};
489493

494+
/*!
 * \brief Mutator for ReverseCacheRead.
 *
 * Applied to the statement of the scope block, it
 *  - inserts `info_->cache_stage` into the loop or block that `info_->loc_sref` points at,
 *  - appends `info_->alloc` to the scope block's `alloc_buffers`,
 *  - redirects reads of `info_->read_buffer` to `info_->write_buffer`, indexing the latter
 *    with the variables of `touched_info->block_vars`.
 */
495+
class ReverseCacheReadRewriter : public StmtExprMutator {
496+
public:
497+
/*!
498+
* \brief Rewrite the AST and add a reverse cache_read stage with the information provided.
499+
* \param scope_sref The parent scope of this mutation.
500+
* \param info The cache stage information.
501+
* \param touched_info The reverse cache touched information.
502+
* \return The new AST rooting at the original parent scope.
503+
*/
504+
static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info,
505+
ReverseCacheTouchedInfo* touched_info) {
506+
ReverseCacheReadRewriter rewriter(scope_sref, info, touched_info);
507+
// The mutation starts from the statement held by the scope sref.
return rewriter(GetRef<Stmt>(scope_sref->stmt));
508+
}
509+
510+
private:
511+
// `touched_info` is only consulted here to collect the replacement indices; it is not stored.
explicit ReverseCacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info,
512+
ReverseCacheTouchedInfo* touched_info)
513+
: scope_sref_(scope_sref), info_(info) {
514+
for (const IterVar& iter_var : touched_info->block_vars) {
515+
new_indices_.push_back(iter_var->var);
516+
}
517+
}
518+
519+
// Loops: mutate the children first, then insert the cache stage if this loop is the
// recorded insertion location.
Stmt VisitStmt_(const ForNode* loop) final {
520+
Stmt stmt = StmtMutator::VisitStmt_(loop);
521+
// Check the insertion point
522+
if (loop == info_->loc_sref->stmt) {
523+
// Insert cache stage into the loop if it is the right place
524+
ObjectPtr<ForNode> n = make_object<ForNode>(*stmt.as<ForNode>());
525+
n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
526+
stmt = Stmt(n);
527+
}
528+
return stmt;
529+
}
530+
531+
// Blocks: handles stage insertion, buffer allocation on the scope block, and the update of
// read regions on every other block.
Stmt VisitStmt_(const BlockNode* block) final {
532+
Block old_stmt = GetRef<Block>(block);
533+
// A non-scope block that writes `read_buffer` is returned as-is; its body is not visited.
if (block != scope_sref_->stmt &&
534+
GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) {
535+
return std::move(old_stmt);
536+
}
537+
// Mutate the body
538+
Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
539+
// Check the insertion point
540+
if (block == info_->loc_sref->stmt) {
541+
// Insert cache stage into the block if it is the right place
542+
ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
543+
n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
544+
stmt = Block(n);
545+
}
546+
// Check if it is the block corresponding to the parent scope
547+
if (block == scope_sref_->stmt) {
548+
// If so, put buffer allocation on the parent scope
549+
ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
550+
n->alloc_buffers.push_back(info_->alloc);
551+
stmt = Block(n);
552+
} else {
553+
// Otherwise, update read regions and match_buffers
// The new list is built from the original `block->reads`, not from the mutated `stmt`.
554+
Array<BufferRegion> reads;
555+
for (const BufferRegion& buf_region : block->reads) {
556+
if (buf_region->buffer.same_as(info_->read_buffer)) {
557+
// Replace the region by a unit-extent range per replacement index on `write_buffer`.
Array<Range> region;
558+
// NOTE(review): the loop variable is taken by value; `const PrimExpr&` would avoid a copy.
for (const PrimExpr index : new_indices_) {
559+
region.push_back(Range::FromMinExtent(index, Integer(1)));
560+
}
561+
reads.push_back(BufferRegion(info_->write_buffer, region));
562+
} else {
563+
reads.push_back(buf_region);
564+
}
565+
}
566+
567+
// Blocks with match_buffers are rejected here.
CHECK_EQ(block->match_buffers.size(), 0) << "Not supported yet.";
568+
// NOTE(review): `reads` is always a newly built array, so this check presumably always
// passes and the block is always copied -- confirm against Array::same_as semantics.
if (!reads.same_as(block->reads)) {
569+
ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
570+
n->reads = std::move(reads);
571+
stmt = Block(n);
572+
}
573+
}
574+
// Record the old-to-new block mapping in the cache stage info.
info_->block_reuse.Set(old_stmt, stmt);
575+
return std::move(stmt);
576+
}
577+
578+
// Replace the data variable of `read_buffer` by the data variable of `write_buffer`.
PrimExpr VisitExpr_(const VarNode* op) final {
579+
if (op == info_->read_buffer->data.get()) {
580+
return info_->write_buffer->data;
581+
}
582+
return GetRef<PrimExpr>(op);
583+
}
584+
585+
// Redirect loads from `read_buffer` to `write_buffer`. The original indices are dropped
// (and not visited); `new_indices_` is used instead.
PrimExpr VisitExpr_(const BufferLoadNode* load) final {
586+
if (load->buffer.same_as(info_->read_buffer)) {
587+
ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
588+
n->buffer = info_->write_buffer;
589+
n->indices = new_indices_;
590+
return PrimExpr(n);
591+
}
592+
return ExprMutator::VisitExpr_(load);
593+
}
594+
595+
/*! \brief The parent scope of the mutation (held by reference). */
const StmtSRef& scope_sref_;
596+
/*! \brief The cache stage information; `block_reuse` is updated during the rewrite. */
CacheStageInfo* info_;
597+
/*! \brief Variables of the touched block iterators, used as indices into the write buffer. */
Array<PrimExpr> new_indices_;
598+
};
599+
490600
/*! \brief Mutator for CacheRead. */
491601
class CacheReadRewriter : public StmtExprMutator {
492602
public:
@@ -881,35 +991,25 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
881991
// Step 2. Create CacheStageInfo
882992
CacheStageInfo info;
883993
info.read_buffer = read_buffer;
884-
// Create the corresponding buffer to be written, i.e. result of cache_read
885-
info.write_buffer = WithScope(read_buffer, storage_scope);
886-
// Create the corresponding buffer allocation
887-
info.alloc = info.write_buffer;
888994
info.annotations = block->annotations;
889995

890996
// Step 3. Update cache stage info.
891-
BufferRegion cache_region{nullptr};
997+
Optional<BufferRegion> maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer);
998+
ICHECK(maybe_region.defined()) << read_buffer
999+
<< " should appear in the block's read region: " << block->reads;
1000+
BufferRegion cache_region = maybe_region.value();
8921001
if (Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) {
8931002
// Case 1. The buffer is written inside the block.
8941003
StmtSRef write_block_sref = _write_block_sref.value();
8951004
const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block, write_block_sref);
8961005
// Find the producing region
897-
BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value();
8981006
StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
899-
9001007
// Detect insert position
9011008
CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info);
902-
cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref);
9031009
} else {
9041010
// Case 2. The buffer is the input block for the scope.
9051011
info.loc_sref = scope_sref;
9061012
info.loc_pos = 0;
907-
if (Optional<BufferRegion> region =
908-
GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
909-
cache_region = region.value();
910-
} else {
911-
cache_region = BufferRegion::FullRegion(read_buffer);
912-
}
9131013
}
9141014

9151015
// Step 4. Create CacheTouchedInfo
@@ -918,6 +1018,7 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
9181018
// Step 5. Update CacheTouchedInfo
9191019
touched_info.read = true;
9201020
VarCollector collector;
1021+
Array<PrimExpr> new_shape;
9211022
for (const Range& range : cache_region->region) {
9221023
collector(range->min);
9231024
}
@@ -926,6 +1027,7 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
9261027
IterVar block_var = block->iter_vars[i];
9271028
if (collector.touched.count(block_var->var.get())) {
9281029
touched_info.block_vars.push_back(block_var);
1030+
new_shape.push_back(block_var->dom->min + block_var->dom->extent);
9291031
touched_info.iter_values.push_back(realize->iter_values[i]);
9301032
collector(touched_info.iter_values.back());
9311033
}
@@ -938,11 +1040,24 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
9381040
}
9391041
}
9401042

1043+
// Create write buffer.
1044+
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*read_buffer.get());
1045+
ObjectPtr<VarNode> new_var = make_object<VarNode>(*read_buffer->data.get());
1046+
const auto* ptr_type = TVM_TYPE_AS(ptr_type, read_buffer->data->type_annotation, PointerTypeNode);
1047+
new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope);
1048+
new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation);
1049+
new_buffer->name = read_buffer->name + "_" + storage_scope;
1050+
new_buffer->shape = new_shape;
1051+
1052+
info.write_buffer = Buffer(new_buffer);
1053+
info.alloc = info.write_buffer;
1054+
9411055
// Step 6. Making new cache stage block and rewrite readers.
9421056
Block cache_read_stage = MakeReverseCacheStage(/*cache_region=*/cache_region,
9431057
/*touched_info*/ &touched_info, /*info=*/&info,
9441058
/*storage_scope=*/storage_scope);
945-
Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
1059+
Stmt new_scope = ReverseCacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info,
1060+
/*touched_info=*/&touched_info);
9461061

9471062
// Step 7. Replacing and updating flags.
9481063
self->Replace(scope_sref, new_scope, info.block_reuse);
@@ -956,62 +1071,7 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
9561071

9571072
StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
9581073
const String& storage_scope) {
959-
/*!
960-
* Check:
961-
* - The index is in the array of block reading region
962-
* - There is only one block who write the buffer in the scope
963-
*
964-
* Mutate:
965-
* - Allocate new cache buffer under the current scope.
966-
* - Find the lowest ancestor of the block and ANY ONE of the producer blocks.
967-
* - Copy the buffer with the consumed region.
968-
*/
969-
970-
// Step 0. Check the input storage scope.
971-
CheckStorageScope(self, storage_scope);
972-
973-
// Step 1. Checking index, getting the target buffer and the parent scope
974-
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
975-
Buffer write_buffer =
976-
GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, /*is_write=*/true);
977-
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
978-
979-
// Step 2. Creating CacheStageInfo
980-
CacheStageInfo info;
981-
info.read_buffer = WithScope(write_buffer, storage_scope);
982-
// Create the corresponding buffer to be written, i.e. result of cache_write
983-
info.write_buffer = write_buffer;
984-
// Create the corresponding buffer allocation
985-
info.alloc = info.read_buffer;
986-
info.annotations = block->annotations;
987-
988-
// Step 3. Check the only writer block.
989-
if (!IsHorizontalFuse(self)) {
990-
ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get());
991-
}
992-
993-
// Step 4. Find the producing region and insert position
994-
BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value();
995-
StmtSRef parent_sref = GetRef<StmtSRef>(block_sref->parent);
996-
// Detect insert position
997-
CacheLocDetector::Detect(self, block_sref, scope_sref, &info);
998-
BufferRegion cache_region =
999-
RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);
1000-
1001-
// Step 5. Making new cache stage block and rewrite readers.
1002-
Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
1003-
/*storage_scope=*/storage_scope);
1004-
Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref,
1005-
/*writer_block_sref=*/block_sref, /*info=*/&info);
1006-
1007-
// Step 6. Replacing and updating flags.
1008-
self->Replace(scope_sref, new_scope, info.block_reuse);
1009-
StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get());
1010-
BlockInfo& block_info = self->block_info[result_block_sref];
1011-
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
1012-
block_info.region_cover = true;
1013-
block_info.scope->stage_pipeline = true;
1014-
return result_block_sref;
1074+
LOG(FATAL) << "Not implemented yet.";
10151075
}
10161076

10171077
/******** Instruction Registration ********/

tests/python/sparsetir/test_reverse_cache_read_write.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def test_tc_spmm_cache_read():
6969
fo, fi = sch.split(f, [None, 16])
7070
sch.reorder(fo, ii, ji, fi)
7171
new_blk = sch.blockize(ii)
72-
B_local = sch.reverse_cache_read(blk_inner, 2, "shared")
72+
B_shared = sch.reverse_cache_read(blk_inner, 2, "shared")
73+
B_warp = sch.reverse_cache_read(blk_inner, 2, "wmma.matrix_b")
7374
print(sch.mod["main"].script())
7475

7576

0 commit comments

Comments
 (0)