From 65cbd5d9fd62923de01562108f921bde1067487b Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 1 Sep 2021 03:59:53 +0800 Subject: [PATCH] [TensorIR][M2a] CacheRead/Write (#8863) Co-authored-by: Junru Shao Co-authored-by: Wuwei Lin Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> --- include/tvm/tir/schedule/schedule.h | 22 + include/tvm/tir/schedule/state.h | 5 + python/tvm/tir/schedule/schedule.py | 135 +++ src/tir/schedule/analysis.h | 21 +- src/tir/schedule/analysis/analysis.cc | 50 +- src/tir/schedule/concrete_schedule.cc | 21 + src/tir/schedule/concrete_schedule.h | 4 + src/tir/schedule/primitive.h | 24 + src/tir/schedule/primitive/block_annotate.cc | 4 +- .../schedule/primitive/cache_read_write.cc | 781 ++++++++++++++++++ src/tir/schedule/schedule.cc | 4 + src/tir/schedule/state.cc | 18 + src/tir/schedule/traced_schedule.cc | 23 + src/tir/schedule/traced_schedule.h | 4 + src/tir/schedule/transform.cc | 40 + src/tir/schedule/transform.h | 29 + src/tir/schedule/utils.h | 1 + .../test_tir_schedule_cache_read_write.py | 677 +++++++++++++++ 18 files changed, 1840 insertions(+), 23 deletions(-) create mode 100644 src/tir/schedule/primitive/cache_read_write.cc create mode 100644 tests/python/unittest/test_tir_schedule_cache_read_write.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 79fed09c3e36..33776cbe1985 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -282,6 +282,28 @@ class ScheduleNode : public runtime::Object { */ virtual void Unroll(const LoopRV& loop_rv) = 0; /******** Schedule: Insert cache stages ********/ + /*! + * \brief Create a block that reads a buffer region into a read cache. It requires: + * 1) There is at most one block who writes the buffer in the scope. + * 2) The scope block have stage-pipeline property. + * \param block_rv The consumer block of the target buffer. + * \param read_buffer_index The index of the buffer in block's read region. + * \param storage_scope The target storage scope. + * \return The cache stage block. + */ + virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) = 0; + /*! + * \brief Create a block that writes a buffer region into a write cache. It requires: + * 1) There is only one block who writes the target buffer. + * 2) The scope block have stage-pipeline property. + * \param block_rv The producer of the buffer + * \param write_buffer_index The index of the buffer in block's write region + * \param storage_scope The target storage scope + * \return The cache stage block. + */ + virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 7cd1b00c15ef..35299a3fa84b 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -128,6 +128,11 @@ class ScheduleStateNode : public Object { */ TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, const Map& block_sref_reuse); + /*! + * \brief Recalculate the `affine_binding` flag of the scope block info. + * \param scope_sref The sref to the interested scope block. + */ + TVM_DLL void UpdateAffineFlag(const StmtSRef& scope_sref); /*! 
* \brief Trigger the verification according to the `debug_mask` bitmask. * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 9433d019f9a5..ac09bdbb264d 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -790,6 +790,141 @@ def after_unroll(a: ty.handle, b: ty.handle) -> None: ########## Schedule: Insert cache stages ########## + def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV: + """Create a block that reads a buffer region into a read cache. It requires: + + 1) There is at most one block who write the buffer in the scope. + + 2) The scope block have stage-pipeline property. + + Parameters + ---------- + block : BlockRV + The consumer block of the target buffer. + + read_buffer_index: int + The index of the buffer in block's read region. + + storage_scope: str + The target storage scope. + + Returns + ------- + cached_block : BlockRV + The block of the cache stage + + Examples + -------- + Before cache_read, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_cache_read(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and cache_read: + + .. code-block:: python + + sch = tir.Schedule(before_cache_read) + block_b = sch.get_block("B") + sch.cache_read(block_b, 0, "local") + print(tvm.script.asscript(sch.mod["main"])) + + After applying cache_read, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_cache_read(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + A_local = tir.alloc_buffer((128, 128), scope="local") + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "A_local") as [vi, vj]: + A_local[vi, vj] = A[vi, vj] + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A_local[vi, vj] * 2.0 + + """ + return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member + self, block, read_buffer_index, storage_scope + ) + + def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV: + """Create a block that reads a buffer region into a write cache. It requires: + + 1) There is only one block who write the buffer in the scope. + + 2) The scope block have stage-pipeline property. + + Parameters + ---------- + block : BlockRV + The producer block of the target buffer. + + write_buffer_index: int + The index of the buffer in block's write region. + + storage_scope: str + The target storage scope. + + + Returns + ------- + cached_block : BlockRV + The block of the cache stage + + Examples + -------- + Before cache_write, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_cache_write(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and cache_write: + + .. 
code-block:: python + + sch = tir.Schedule(before_cache_write) + block_b = sch.get_block("B") + sch.cache_write(block_b, 0, "local") + print(tvm.script.asscript(sch.mod["main"])) + + After applying cache_write, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_cache_write(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + B_local = tir.alloc_buffer((128, 128), scope="local") + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "A_local") as [vi, vj]: + B_local[vi, vj] = A[vi, vj] * 2.0 + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = B_local[vi, vj] + + """ + return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member + self, block, write_buffer_index, storage_scope + ) + ########## Schedule: Compute location ########## def compute_inline(self, block: BlockRV) -> None: diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 3fa0c63b2e2f..d4e4728abfe0 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -56,6 +56,13 @@ void VerifyCachedFlags(const ScheduleState& self); const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, GlobalVar* result_g_var); +/*! + * \brief Get the root node of the sref tree, which is the root block of the PrimFunc. + * \param sref The given sref. + * \return The root node of the sref tree which contains the given node. + */ +StmtSRef GetSRefTreeRoot(const StmtSRef& sref); + /******** Scope ********/ /*! * \brief Checks if scope the specified sref is in is a stage-pipeline and return it @@ -228,15 +235,15 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr /******** Block-buffer relation ********/ /*! - * \brief Get the BlockRealize of the single child block of the block or loop specified by - * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks - * \param self The schedule state - * \param block The queried block - * \param n The index of the queried buffer - * \return The buffer of the n-th write region of the block. + * \brief Get the n-th read or write buffer of the given block. + * \param self The schedule state. + * \param block The queried block. + * \param n The index of the queried buffer. + * \param is_write A boolean flag to indicate querying write buffer or read buffer. + * \return The buffer of the n-th read/write region of the block. * \throw ScheduleError If the buffer index is out of bound. 
*/ -Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n); +Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write); /******** Commutative Reducer ********/ diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index c9f8ff4c7e75..3865781c5870 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -588,25 +588,37 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr /******** Block-buffer relation ********/ -Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) { - class WriteBufferIndexOutOfRangeError : public ScheduleError { +Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write) { + class BufferIndexOutOfRangeError : public ScheduleError { public: - explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index) - : mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {} + explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index, bool is_write) + : mod_(std::move(mod)), + block_(std::move(block)), + buffer_index_(buffer_index), + is_write_(is_write) {} String FastErrorString() const final { - return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " - "range [0, num_write_regions) where `num_write_regions` is the number of buffer " - "regions written by the block."; + if (is_write_) { + return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " + "range " + "[0, num_write_regions) where `num_write_regions` is the number of buffer regions " + "written by the block."; + } else { + return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " + "range " + "[0, num_read_regions) where `num_read_regions` is the number of buffer regions " + "read by the block."; + } } String DetailRenderTemplate() const final { std::ostringstream os; - size_t num_writes = block_->writes.size(); - os << "The block {0} has " << num_writes - << " write regions, so `buffer_index` is required to be in [0, " << num_writes + size_t num = is_write_ ? block_->writes.size() : block_->reads.size(); + std::string access_type = is_write_ ? "write" : "read"; + os << "The block {0} has " << num << " " << access_type + << " regions, so `buffer_index` is required to be in [0, " << num << "). However, the input `buffer_index` is " << buffer_index_ - << ", which is out of the expected range"; + << ", which is out of the expected range."; return os.str(); } @@ -617,12 +629,15 @@ Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) { IRModule mod_; Block block_; int buffer_index_; + bool is_write_; }; - if (n < 0 || static_cast(n) >= block->writes.size()) { - throw WriteBufferIndexOutOfRangeError(self->mod, block, n); + const Array& access_region = is_write ? 
block->writes : block->reads; + + if (n < 0 || static_cast(access_region.size()) <= n) { + throw BufferIndexOutOfRangeError(self->mod, block, n, is_write); } - return block->writes[n]->buffer; + return access_region[n]->buffer; } /******** Pattern Matcher ********/ @@ -941,5 +956,12 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, return false; } +/******** SRef Tree Related ********/ +StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { + const StmtSRefNode* p = sref.get(); + for (; p->parent != nullptr; p = p->parent) { + } + return GetRef(p); +} } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index cd9aad8ae512..86223e11c196 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -416,6 +416,27 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) { } /******** Schedule: Insert cache stages ********/ + +BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope); + TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope); + TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 0bd902d183bf..e756f9da41b2 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -103,6 +103,10 @@ class ConcreteScheduleNode : public ScheduleNode { void Bind(const LoopRV& loop_rv, const String& thread_axis) override; void Unroll(const LoopRV& loop_rv) override; /******** Schedule: Insert cache stages ********/ + BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) override; + BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) override; /******** Schedule: Compute location ********/ void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index be33c2acca10..412611adf76d 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -135,6 +135,30 @@ TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& */ TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref); /******** Schedule: Insert cache stages ********/ +/*! + * \brief Create a block that reads a buffer region into a read cache. It requires: + * 1) There is at most one block who writes the buffer in the scope. + * 2) The scope block have stage-pipeline property. + * \param self The state of the schedule + * \param block_sref The consumer block of the target buffer. + * \param read_buffer_index The index of the buffer in block's read region. 
+ * \param storage_scope The target storage scope. + * \return The cache stage block. + */ +TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, + const String& storage_scope); +/*! + * \brief Create a block that writes a buffer region into a write cache. It requires: + * 1) There is only one block that writes the target buffer. + * 2) The scope block have stage-pipeline property. + * \param self The state of the schedule + * \param block_sref The producer of the buffer + * \param write_buffer_index The index of the buffer in block's write region + * \param storage_scope The target storage scope + * \return The cache stage block. + */ +TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, + const String& storage_scope); /******** Schedule: Compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 937bc7c3802f..06f7ac3c1bc2 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include "../transform.h" #include "../utils.h" namespace tvm { @@ -237,7 +236,8 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); - Buffer buffer = GetNthWriteBuffer(self, GetRef(block_ptr), buffer_index); + Buffer buffer = + GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, /*is_write=*/true); StorageAlignInvalidFactorError::Check(self->mod, factor); axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer); diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc new file mode 100644 index 000000000000..df54c9652ece --- /dev/null +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -0,0 +1,781 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "../utils.h" + +namespace tvm { +namespace tir { + +/******** Error Classes ********/ + +class NotSingleWriteBlock : public ScheduleError { + public: + explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, Array write_blocks) + : mod_(std::move(mod)), buffer_(std::move(buffer)) { + ICHECK_GT(write_blocks.size(), 1); + write_blocks_.reserve(write_blocks.size()); + for (const StmtSRef& block_sref : write_blocks) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + write_blocks_.push_back(GetRef(block)); + } + } + + String FastErrorString() const final { + return "ScheduleError: The buffer is allowed to be written by single block."; + } + + String DetailRenderTemplate() const final { + size_t k = write_blocks_.size(); + return "The buffer " + buffer_->name + " is expected to be written by single block, but got " + + std::to_string(k) + " blocks who write it."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { + return {write_blocks_.begin(), write_blocks_.end()}; + } + + private: + IRModule mod_; + Buffer buffer_; + Array write_blocks_; +}; + +/******** Helper Functions/Classes ********/ + +/*! \brief The auxiliary info used for the insertion point and content of the cache stage. */ +struct CacheStageInfo { + /*! \brief The buffer to be read. */ + Buffer read_buffer; + /*! \brief The buffer to be written. */ + Buffer write_buffer; + /*! \brief The buffer allocation to be inserted into the block signature. */ + Buffer alloc; + /*! \brief The AST node whose body is where the cache stage should be inserted. */ + StmtSRef loc_sref; + /*! \brief The index to insert the cache_read/cache_write stage. */ + size_t loc_pos; + /*! \brief The cache_read/cache_write stage to be inserted. */ + Stmt cache_stage; + /*! \brief The map used for ScheduleStateNode::Replace. */ + Map block_reuse; +}; + +/*! \brief Return the buffer region realted with the buffer */ +Optional GetBufferRegionFromBuffer(const Array& buffer_regions, + const Buffer& buffer) { + Optional res = NullOpt; + for (const auto& region : buffer_regions) { + if (region->buffer.same_as(buffer)) { + ICHECK(!res.defined()); + res = region; + } + } + return res; +} + +/*! + * \brief Create a loop nest that represents cache copy (cache_read / cache_write) from read buffer + * to write buffer. + * \note This function will store the stmt with loop nesting to the CacheStageInfo, but only return + * the inside block. + * \param cache_region The cached copy region. + * \param info The cache stage information, which will be updated in the function. + * \param storage_scope The storage scope of the cached buffer (only used in naming here) + * \returns A block indicating the body of the loop nesting. 
+ */ +Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, + const String& storage_scope) { + // loop variables + std::vector loop_vars; + // bindings in block realize + std::vector iter_values; + // Create loop vars and block vars' binding_value + for (const Range& axis_range : cache_region->region) { + Var loop_var("ax" + std::to_string(loop_vars.size())); + loop_vars.push_back(loop_var); + iter_values.push_back(axis_range->min + loop_var); + } + // block variables + Array block_vars; + // block access region for read/write buffers + Region access_region; + // indices used in block body + Array access_indices; + // Create block vars, block's accessed region and accessing indices + for (const PrimExpr& dim : cache_region->buffer->shape) { + Var var("v" + std::to_string(access_indices.size())); + block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(0, dim), + /*var=*/var, + /*IterVarType=*/kDataPar)); + access_indices.push_back(var); + access_region.push_back(Range::FromMinExtent(var, 1)); + } + + // Create the body block: + // reads = [read_buffer[access_region]] + // writes = [write_buffer[access_region]] + // write_buffer[access_indices] = read_buffer[access_indices] + Block block( + /*iter_vars=*/std::move(block_vars), + /*reads=*/{BufferRegion(info->read_buffer, access_region)}, + /*writes=*/{BufferRegion(info->write_buffer, access_region)}, + /*name_hint=*/cache_region->buffer->name + "_" + storage_scope, + /*body=*/ + BufferStore(info->write_buffer, BufferLoad(info->read_buffer, access_indices), + access_indices), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/{}); + // Create the block realize node + Stmt body = BlockRealize(/*values=*/iter_values, + /*predicate=*/Bool(true), + /*block=*/block); + // Create surrounding loops + for (size_t i = loop_vars.size(); i >= 1; --i) { + body = For(/*loop_var=*/loop_vars[i - 1], + /*min=*/0, + /*extent=*/cache_region->region[i - 1]->extent, + /*kind=*/ForKind::kSerial, + /*body=*/body); + } + info->cache_stage = std::move(body); + return block; +} + +/*! + * \brief Insert the cache_read/cache_write stage into the specific position + * \param stmt A sequence of statements or a single statement that the new stage is inserted in + * \param pos The position where the cache stage is inserted + * \param stage The stage to be inserted + * \return A SeqStmt, the result after insertion + */ +SeqStmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { + if (const auto* seq_stmt = stmt.as()) { + ObjectPtr result = make_object(*seq_stmt); + result->seq.insert(result->seq.begin() + pos, stage); + return SeqStmt(result); + } + if (pos == 0) { + return SeqStmt({stage, stmt}); + } + ICHECK_EQ(pos, 1); + return SeqStmt({stmt, stage}); +} + +/*! + * \brief Get the only writer block of the input buffer in a given scope block. + * \param self The state of the schedule + * \param scope_sref The scope block where the write is considered + * \param buffer The queried buffer + * \return The sref of the only writer of the input buffer in the given scope, + * or `NullOpt` if no block writes it in the scope. + * \throw NotSingleWriteBlock if there are more than one intrested block. 
+ */ +Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref, + const Buffer& buffer) { + BlockScope scope = self->GetBlockScope(scope_sref); + auto it = scope->buffer_writers.find(buffer); + if (it == scope->buffer_writers.end()) { + return NullOpt; + } else { + const Array& block_srefs = it->second; + ICHECK(!block_srefs.empty()); + if (block_srefs.size() > 1) { + throw NotSingleWriteBlock(self->mod, buffer, block_srefs); + } + return block_srefs[0]; + } +} + +/*! + * \brief Get the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive) + * \param self The state of the schedule. + * \param buffer_region The buffer region to be analyzed. + * \param block_sref The sref of the block related to the region. + * \param dom_low_inclusive The lowest node in the sref tree path. + * \param dom_high_exclusive The highest node in the sref tree path. + * \return The relaxed buffer region. + */ +BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_region, + const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive, + const StmtSRef& dom_high_exclusive) { + BlockRealize realize = GetBlockRealize(self, block_sref); + Map binding = GetBindings(realize); + const Buffer& buffer = buffer_region->buffer; + Array int_sets = + arith::EvalSet(Substitute(buffer_region->region, binding), + AsIntSet(LoopDomainOfSRefTreePath( + /*low_inclusive=*/dom_low_inclusive, + /*high_exclusive=*/dom_high_exclusive, + /*extra_relax_scope=*/runtime::StorageScope::Create(buffer.scope())))); + ICHECK_EQ(buffer_region->region.size(), int_sets.size()); + + Region region; + region.reserve(int_sets.size()); + for (size_t i = 0; i < int_sets.size(); ++i) { + region.push_back(int_sets[i].CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); + } + return BufferRegion(buffer, region); +} + +/*! \brief Detect the insertion position of the new cache stage */ +class CacheLocDetector : public StmtVisitor { + public: + /*! + * \brief Detect the insertion position of the cache stage, and write the position into the + * CacheStageInfo \param self The state of the schedule \param block_sref The sref of the unique + * writer block of the buffer being applied cache_read or cache_write \param scope_sref The sref + * of the scope block of the cached block \param info The cache stage info. + */ + static void Detect(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_sref, CacheStageInfo* info) { + std::vector related_blocks; + for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) { + if (def->kind == DepKind::kRAW) { + related_blocks.push_back(def->dst); + } + } + if (!related_blocks.empty()) { + CacheLocDetector detector(self, block_sref, scope_sref, related_blocks); + detector(GetRef(scope_sref->stmt)); + info->loc_sref = detector.loc_sref_; + info->loc_pos = detector.loc_pos_; + } else { + info->loc_sref = scope_sref; + const auto* body = scope_sref->StmtAs()->body.as(); + info->loc_pos = body == nullptr ? 1 : body->size(); + } + } + + private: + /*! 
+ * \brief Constructor + * \param self The state of the schedule + * \param block_sref The sref of the unique writer block of the buffer being applied cache_read or + * cache_write \param scope_sref The sref of the scope block of the cached block \param + * related_blocks Producer blocks for cache_write, or consumer blocks for cache_read + */ + CacheLocDetector(const ScheduleState self, const StmtSRef& block_sref, const StmtSRef& scope_sref, + const std::vector& related_blocks) + : self_(self), + block_sref_(block_sref), + scope_sref_(scope_sref), + related_blocks_(related_blocks) {} + + void VisitStmt_(const SeqStmtNode* seq_stmt) final { + bool previous_visited_block = visited_block_; + bool previous_visited_related = visited_related_; + visited_block_ = visited_related_ = false; + + int pos = -1; + for (size_t i = 0; i < seq_stmt->size(); ++i) { + if (loc_pos_ != -1) { + break; + } + VisitStmt(seq_stmt->seq[i]); + // `pos` can be assigned only once when we visited `block_sref` + if (visited_block_ && visited_related_ && pos == -1) { + // The offset of insert position from the block + pos = i; + } + } + visited_block_ = visited_block_ || previous_visited_block; + visited_related_ = visited_related_ || previous_visited_related; + // Only we visited the writing block and any one of the related blocks + // That means that we have found the lowest ancestor + // of the block and any one of the related ones + if (visited_block_ && visited_related_ && loc_pos_ == -1) { + loc_pos_ = pos; + } + } + + void VisitStmt_(const BlockNode* block) final { + // Only visit the current scope under buffer writer's parent block + if (block == scope_sref_->stmt) { + // The block vistied is the current parent scope + StmtVisitor::VisitStmt_(block); + // Handling cache_read for input buffer + if (visited_block_ && visited_related_ && !loc_sref_.defined()) { + loc_sref_ = self_->stmt2ref.at(block); + if (loc_pos_ == -1) { + loc_pos_ = 1; + } + } + return; + } + // Update `visited_block` + if (block_sref_->stmt == block) { + visited_block_ = true; + return; + } + // Update `visited_related` + for (const StmtSRef& related_block : related_blocks_) { + if (related_block->stmt == block) { + visited_related_ = true; + return; + } + } + } + + void VisitStmt_(const ForNode* loop) final { + StmtVisitor::VisitStmt_(loop); + if (visited_block_ && visited_related_ && !loc_sref_.defined() && loc_pos_ != -1) { + loc_sref_ = self_->stmt2ref.at(loop); + } + } + + private: + /*! \brief The schedule class */ + const ScheduleState self_; + /*! \brief The dominate block which write the buffer */ + const StmtSRef& block_sref_; + /*! \brief The parent scope of the dominate block */ + const StmtSRef& scope_sref_; + /*! \brief Producer blocks for cache_write and consumer blocks for cache_read */ + const std::vector& related_blocks_; + /*! \brief The flag whether we have visited the dominate block */ + bool visited_block_{false}; + /*! \brief The flag whether we have visited at least one related blocks */ + bool visited_related_{false}; + /*! \brief The AST node whose body is where the cache stage should be inserted */ + StmtSRef loc_sref_{nullptr}; + /*! \brief The index to insert the cache_read/cache_write stage */ + int loc_pos_{-1}; +}; + +/*! \brief Mutator for CacheRead. */ +class CacheReadRewriter : public StmtExprMutator { + public: + /*! 
+ * \brief Rewrite the AST and add a cache_read stage with the information provided + * \param scope_sref The parent scope of this mutation + * \param info The cache stage information + * \return The new AST rooting at the original parent scope + */ + static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info) { + CacheReadRewriter rewriter(scope_sref, info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info) + : scope_sref_(scope_sref), info_(info) {} + + Stmt VisitStmt_(const ForNode* loop) final { + Stmt stmt = StmtMutator::VisitStmt_(loop); + // Check the insertion point + if (loop == info_->loc_sref->stmt) { + // Insert cache stage into the loop if it is the right place + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Stmt(n); + } + return stmt; + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef(block); + // We don't mutate the block which generates info->read_buffer + if (block != scope_sref_->stmt && + GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) { + return std::move(old_stmt); + } + // Mutate the body + Block stmt = Downcast(StmtMutator::VisitStmt_(block)); + // Check the insertion point + if (block == info_->loc_sref->stmt) { + // Insert cache stage into the block if it is the right place + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Block(n); + } + // Check if it is the block corresponding to the parent scope + if (block == scope_sref_->stmt) { + // If so, put buffer allocation on the parent scope + ObjectPtr n = make_object(*stmt.as()); + n->alloc_buffers.push_back(info_->alloc); + stmt = Block(n); + } else { + // Otherwise, update read regions and match_buffers + Array reads = + ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer); + Array match_buffers = + ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer); + if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { + ObjectPtr n = make_object(*stmt.as()); + n->reads = std::move(reads); + n->match_buffers = std::move(match_buffers); + stmt = Block(n); + } + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.same_as(info_->read_buffer)) { + ObjectPtr n = make_object(*load); + n->buffer = info_->write_buffer; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + PrimExpr VisitExpr_(const LoadNode* load) final { + if (load->buffer_var.same_as(info_->read_buffer->data)) { + ObjectPtr n = make_object(*load); + n->buffer_var = info_->write_buffer->data; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + if (op == info_->read_buffer->data.get()) { + return info_->write_buffer->data; + } + return GetRef(op); + } + + private: + /*! \brief The parent scope of the insertion */ + const StmtSRef& scope_sref_; + /*! \brief The info for inserting cache stage */ + CacheStageInfo* info_; +}; + +/*! \brief Mutator for CacheWrite */ +class CacheWriteRewriter : public StmtExprMutator { + public: + /*! + * \brief Rewrite the AST and add a cache_write stage with the information provided. + * \param scope_sref The parent scope of this mutation. 
+ * \param writer_block_sref The only writer block in the scope. + * \param info The cache stage information. + * \return The new AST rooting at the original parent scope. + */ + static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, + CacheStageInfo* info) { + CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, + CacheStageInfo* info) + : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) {} + + Stmt VisitStmt_(const ForNode* loop) final { + Stmt stmt = StmtMutator::VisitStmt_(loop); + // Check the insertion point + if (loop == info_->loc_sref->stmt) { + // Insert cache stage into the loop if it is the right place + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Stmt(n); + } + return stmt; + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef(block); + // We only mutate the block which generates info->write_buffer + if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) { + return std::move(old_stmt); + } + + // Mutate the body + bool under_scope = under_writer_block_ || block == writer_block_sref_->stmt; + std::swap(under_scope, under_writer_block_); + Block stmt = Downcast(StmtMutator::VisitStmt_(block)); + std::swap(under_scope, under_writer_block_); + + // Find the insertion point + if (block == info_->loc_sref->stmt) { + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Block(n); + } + // Put buffer allocation on the parent scope + if (block == scope_sref_->stmt) { + ObjectPtr n = make_object(*stmt.as()); + n->alloc_buffers.push_back(info_->alloc); + stmt = Block(n); + } else { + // Since cache_write changes the block, we need to update the buffer it writes + auto writes = ReplaceBuffer(block->writes, info_->write_buffer, info_->read_buffer); + auto reads = ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); + auto match_buffers = + ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); + if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || + !match_buffers.same_as(block->match_buffers)) { + ObjectPtr n = make_object(*stmt.as()); + n->writes = std::move(writes); + n->reads = std::move(reads); + n->match_buffers = std::move(match_buffers); + stmt = Block(n); + } + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + Stmt VisitStmt_(const BufferStoreNode* store) final { + BufferStore stmt = Downcast(StmtMutator::VisitStmt_(store)); + if (stmt->buffer.same_as(info_->write_buffer)) { + auto n = CopyOnWrite(stmt.get()); + n->buffer = info_->read_buffer; + return Stmt(n); + } else { + return std::move(stmt); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.same_as(info_->write_buffer)) { + ObjectPtr n = make_object(*load); + n->buffer = info_->read_buffer; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + PrimExpr VisitExpr_(const LoadNode* load) final { + if (load->buffer_var.same_as(info_->write_buffer->data)) { + ObjectPtr n = make_object(*load); + n->buffer_var = info_->read_buffer->data; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + Stmt VisitStmt_(const StoreNode* store) final { + if 
(store->buffer_var.same_as(info_->write_buffer->data)) { + ObjectPtr n = make_object(*store); + n->buffer_var = info_->read_buffer->data; + return Stmt(n); + } + return StmtMutator::VisitStmt_(store); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + if (op == info_->write_buffer->data.get()) { + return info_->read_buffer->data; + } + return GetRef(op); + } + + private: + /*! \brief The parent scope of the insertion. */ + const StmtSRef& scope_sref_; + /*! \brief The parent scope of the insertion. */ + const StmtSRef& writer_block_sref_; + /*! \brief The info for inserting cache stage. */ + CacheStageInfo* info_; + /*! \brief Whether the current node is under the given block. */ + bool under_writer_block_{false}; +}; + +/******** Implementation ********/ + +StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, + const String& storage_scope) { + /*! + * Check: + * - The index is in the array of block reading region + * - There is at most one block who write the buffer in the scope + * + * Mutate: + * - Allocate new cache buffer under the current scope. + * - Find the lowest ancestor of the block and ANY ONE of the consumers blocks. + * - Copy the buffer with the consumed region. + */ + + // Step 1. Check index, getting the target buffer and the parent scope + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Buffer read_buffer = + GetNthAccessBuffer(self, GetRef(block), read_buffer_index, /*is_write=*/false); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + + // Step 2. Creat CacheStageInfo + CacheStageInfo info; + info.read_buffer = read_buffer; + // Create the corresponding buffer to be written, i.e. result of cache_read + info.write_buffer = WithScope(read_buffer, storage_scope); + // Create the corresponding buffer allocation + info.alloc = info.write_buffer; + + // Step 3. Update cache stage info. + BufferRegion cache_region{nullptr}; + if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { + // Case 1. The buffer is written inside the block. + StmtSRef write_block_sref = _write_block_sref.value(); + const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block, write_block_sref); + // Find the producing region + BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value(); + StmtSRef parent_sref = GetRef(write_block_sref->parent); + + // Detect insert position + CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); + cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref); + } else { + // Case 2. The buffer is the input block for the scope. + info.loc_sref = scope_sref; + info.loc_pos = 0; + if (Optional region = + GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) { + cache_region = region.value(); + } else { + cache_region = BufferRegion::FullRegion(read_buffer); + } + } + + // Step 4. Making new cache stage block and rewrite readers. + Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); + + // Step 5. Replacing and updating flags. 
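+  // `Replace` swaps in the rewritten scope and records which blocks were reused, so the sref of
+  // the newly created cache block has to be re-resolved through `stmt2ref` afterwards. Its
+  // affine_binding flag is recomputed by `UpdateAffineFlag`, while region_cover and the scope's
+  // stage_pipeline flag are set to true directly.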
+ self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); + self->UpdateAffineFlag(result_block_sref); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + +StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, + const String& storage_scope) { + /*! + * Check: + * - The index is in the array of block reading region + * - There is only one block who write the buffer in the scope + * + * Mutate: + * - Allocate new cache buffer under the current scope. + * - Find the lowest ancestor of the block and ANY ONE of the producer blocks. + * - Copy the buffer with the consumed region. + */ + // Step 1. Checking index, getting the target buffer and the parent scope + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Buffer write_buffer = + GetNthAccessBuffer(self, GetRef(block), write_buffer_index, /*is_write=*/true); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + + // Step 2. Creating CacheStageInfo + CacheStageInfo info; + info.read_buffer = WithScope(write_buffer, storage_scope); + // Create the corresponding buffer to be written, i.e. result of cache_write + info.write_buffer = write_buffer; + // Create the corresponding buffer allocation + info.alloc = info.read_buffer; + + // Step 3. Check the only writer block. + ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); + + // Step 4. Find the producing region and insert position + BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value(); + StmtSRef parent_sref = GetRef(block_sref->parent); + // Detect insert position + CacheLocDetector::Detect(self, block_sref, scope_sref, &info); + BufferRegion cache_region = + RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref); + + // Step 5. Making new cache stage block and rewrite readers. + Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref, + /*writer_block_sref=*/block_sref, /*info=*/&info); + + // Step 6. Replacing and updating flags. 
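+  // Same bookkeeping as in CacheRead: re-resolve the sref of the new cache stage after `Replace`
+  // and refresh its cached flags.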
+ self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get()); + self->UpdateAffineFlag(result_block_sref); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + +/******** Instruction Registration ********/ + +struct CacheReadTraits : public UnpackedInstTraits { + static constexpr const char* kName = "CacheRead"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer read_buffer_index, + String storage_scope) { + return sch->CacheRead(block, read_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String block, Integer read_buffer_index, + String storage_scope) { + PythonAPICall py("cache_read"); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct CacheWriteTraits : public UnpackedInstTraits { + static constexpr const char* kName = "CacheWrite"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index, + String storage_scope) { + return sch->CacheWrite(block, write_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index, + String storage_scope) { + PythonAPICall py("cache_write"); + py.Input("block", block); + py.Input("write_buffer_index", write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); +TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index d24cdc625912..fd30b02fc9dd 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -141,6 +141,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method(&ScheduleNode::Bind); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method(&ScheduleNode::Unroll); /******** (FFI) Insert cache stages ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") + .set_body_method(&ScheduleNode::CacheRead); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") + .set_body_method(&ScheduleNode::CacheWrite); /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") .set_body_method(&ScheduleNode::ComputeInline); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 9a9b97497e04..799806bef7b5 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -1029,6 +1029,24 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl Bool(info.scope->stage_pipeline)}; } +TVM_DLL void ScheduleStateNode::UpdateAffineFlag(const StmtSRef& scope_sref) { + auto 
it = this->block_info.find(scope_sref); + ICHECK(it != this->block_info.end()) << "Cannot find the block info of the given block."; + BlockInfo& info = it->second; + + bool is_root_block = scope_sref->parent == nullptr; + if (is_root_block) { + info.affine_binding = true; + } else { + BlockRealize realize = GetBlockRealize(GetRef(this), scope_sref); + arith::Analyzer analyzer; + StmtSRef parent_sref = GetRef(scope_sref->parent); + info.affine_binding = IsAffineBinding(/*realize=*/realize, + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), + /*analyzer=*/&analyzer); + } +} + /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ScheduleStateNode); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index af4a6588f064..f429a917858b 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -166,6 +166,29 @@ void TracedScheduleNode::Unroll(const LoopRV& loop_rv) { } /******** Schedule: Insert cache stages ********/ +BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) { + BlockRV result = ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("CacheRead"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) { + BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("CacheWrite"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} /******** Schedule: Compute location ********/ diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 48dadbc03b3b..a6b5251a96a3 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -71,6 +71,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { void Bind(const LoopRV& loop_rv, const String& thread_axis) final; void Unroll(const LoopRV& loop_rv) final; /******** Schedule: Insert cache stages ********/ + BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) final; + BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) final; /******** Schedule: Compute location ********/ void ComputeInline(const BlockRV& block_rv) final; void ReverseComputeInline(const BlockRV& block_rv) final; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index f27e0f6d62eb..da376fdde90f 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -19,6 +19,8 @@ #include "./transform.h" +#include "./utils.h" + namespace tvm { namespace tir { @@ -31,5 +33,43 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec return Block(new_block); } +/******** Buffer Related ********/ +Buffer WithScope(const Buffer& buffer, const String& scope) { + ObjectPtr new_buffer = make_object(*buffer.get()); + ObjectPtr new_var = make_object(*buffer->data.get()); + const auto* ptr_type = TVM_TYPE_AS(ptr_type, buffer->data->type_annotation, PointerTypeNode); + 
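+  // Give the new buffer its own data var: keep the pointee element type but attach the new
+  // storage scope, and suffix both the var name and the buffer name with the scope.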
new_var->type_annotation = PointerType(ptr_type->element_type, scope); + new_buffer->data = Var(new_var->name_hint + "_" + scope, new_var->type_annotation); + new_buffer->name = buffer->name + "_" + scope; + return Buffer(new_buffer); +} + +Array ReplaceBuffer(Array regions, const Buffer& source, + const Buffer& target) { + regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion { + if (region->buffer.same_as(source)) { + ObjectPtr n = make_object(*region.get()); + n->buffer = target; + return BufferRegion(n); + } + return region; + }); + return regions; +} + +Array ReplaceBuffer(Array match_buffers, const Buffer& source, + const Buffer& target) { + match_buffers.MutateByApply([&source, + &target](MatchBufferRegion match_buffer) -> MatchBufferRegion { + if (match_buffer->source->buffer.same_as(source)) { + ObjectPtr n = make_object(*match_buffer.get()); + n->source = BufferRegion(target, n->source->region); + return MatchBufferRegion(n); + } + return match_buffer; + }); + return match_buffers; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 53483829a303..85cce9da216e 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -35,6 +35,35 @@ namespace tir { */ Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); +/******** Buffer Related ********/ + +/*! + * \brief Create a new buffer by changing the storage scope. + * \param buffer The given buffer. + * \param scope The target storage scope. + * \return The new buffer with target storage scope. + */ +Buffer WithScope(const Buffer& buffer, const String& scope); + +/*! + * \brief Replaces the buffer within the specific sequence of regions + * \param regions The regions whose buffers are to be replaced + * \param source The buffer to be replaced + * \param target The buffer to be replaced to + * \return The new sequence of regions after replacement + */ +Array ReplaceBuffer(Array regions, const Buffer& source, + const Buffer& target); + +/*! + * \brief Replaces the buffer within the specific sequence of match_buffers + * \param match_buffers The match_buffers whose buffers are to be replaced + * \param source The buffer to be replaced + * \param target The buffer to be replaced to + * \return The new sequence of match_buffers after replacement + */ +Array ReplaceBuffer(Array match_buffers, const Buffer& source, + const Buffer& target); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 8ccf8da731b5..c2f430181664 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -42,6 +42,7 @@ #include "./error.h" #include "./instruction_traits.h" #include "./primitive.h" +#include "./transform.h" namespace tvm { namespace tir { diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py new file mode 100644 index 000000000000..d7eb8d864135 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -0,0 +1,677 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable + +########## Function before schedule ########## + + +@tvm.script.tir +def elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def access_under_scope(b: ty.handle, c: ty.handle) -> None: + A = tir.alloc_buffer((128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([8, 8], "scope") as [i, j]: + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "A") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + A[vi, vj] = 1.0 + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + B[vi, vj] = A[vi, vj] + 1.0 + + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), dtype="float16") + B = tir.match_buffer(b, (128, 128), dtype="float16") + C = tir.match_buffer(c, (128, 128), dtype="float16") + D = tir.match_buffer(d, (128, 128), dtype="float16") + + with tir.block([128, 128], "load_store") as [vi, vj]: + tir.reads(A[vi, vj]) + tir.writes(D[vi, vj]) + D.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj) + with tir.block([8, 8], "opaque") as [vi, vj]: + tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.evaluate( + tir.tvm_load_matrix_sync( + B.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + with tir.block([8, 8], "match_buffer") as [vi, vj]: + tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = tir.match_buffer( + A[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + C0 = tir.match_buffer( + C[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + + +@tvm.script.tir +def func_multi_consumer() -> None: + A = tir.alloc_buffer((128)) + B = tir.alloc_buffer((128)) + C = 
tir.alloc_buffer((128)) + for i in tir.grid(8): + for j in tir.grid(16): + with tir.block([128], "A") as [vi]: + tir.bind(vi, i * 16 + j) + A[vi] = 1.0 + for j in tir.grid(16): + with tir.block([128], "B") as [vi]: + tir.bind(vi, i * 16 + j) + B[vi] = A[vi] + 1.0 + for i in tir.grid(128): + with tir.block([128], "C") as [vi]: + C[vi] = A[vi] + + +@tvm.script.tir +def func_multi_producer() -> None: + A = tir.alloc_buffer((128)) + B = tir.alloc_buffer((128)) + with tir.block([128], "A0") as [vi]: + A[vi] = 1.0 + with tir.block([128], "A1") as [vi]: + A[vi] = 2.0 + with tir.block([128], "B") as [vi]: + B[vi] = A[vi] + + +########## Expected function after cache_read ########## + + +@tvm.script.tir +def cache_read_elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + A_global = tir.alloc_buffer((128, 128)) + B_local = tir.alloc_buffer((128, 128), scope="local") + with tir.block([128, 128], "A_global") as [vi, vj]: + A_global[vi, vj] = A[vi, vj] + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A_global[vi, vj] * 2.0 + with tir.block([128, 128], "B_local") as [vi, vj]: + B_local[vi, vj] = B[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B_local[vi, vj] + 1.0 + + +@tvm.script.tir +def cache_read_under_scope(b: ty.handle, c: ty.handle) -> None: + A = tir.alloc_buffer((128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + A_global = tir.alloc_buffer((128, 128)) + + with tir.block([8, 8], "scope") as [i, j]: + A_local = tir.alloc_buffer((128, 128), scope="local") + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "A") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + A[vi, vj] = 1.0 + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "A_local") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + A_local[vi, vj] = A[vi, vj] + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + B[vi, vj] = A_local[vi, vj] + 1.0 + with tir.block([128, 128], "A_global") as [vi, vj]: + A_global[vi, vj] = A[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A_global[vi, vj] * 2.0 + + +@tvm.script.tir +def cache_read_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), dtype="float16") + B = tir.match_buffer(b, (128, 128), dtype="float16") + C = tir.match_buffer(c, (128, 128), dtype="float16") + D = tir.match_buffer(d, (128, 128), dtype="float16") + A_global = tir.alloc_buffer((128, 128), dtype="float16") + + with tir.block([128, 128], "A_global") as [vi, vj]: + A_global[vi, vj] = A[vi, vj] + with tir.block([128, 128], "load_store") as [vi, vj]: + tir.reads(A_global[vi, vj]) + tir.writes(D[vi, vj]) + D.data[vi * 128 + vj] = tir.load("float16", A_global.data, vi * 128 + vj) + with tir.block([8, 8], "opaque") as [vi, vj]: + tir.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.evaluate( + tir.tvm_load_matrix_sync( + B.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A_global.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + with tir.block([8, 8], "match_buffer") as [vi, vj]: + tir.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 
16]) + tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = tir.match_buffer( + A_global[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + C0 = tir.match_buffer( + C[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + + +@tvm.script.tir +def cache_read_multi_consumer() -> None: + A = tir.alloc_buffer((128)) + B = tir.alloc_buffer((128)) + C = tir.alloc_buffer((128)) + A_global = tir.alloc_buffer((128)) + for i in tir.grid(8): + for j in tir.grid(16): + with tir.block([128], "A") as [vi]: + tir.bind(vi, i * 16 + j) + A[vi] = 1.0 + for j in tir.grid(16): + with tir.block([128], "A") as [vi]: + tir.bind(vi, i * 16 + j) + A_global[vi] = A[vi] + for j in tir.grid(16): + with tir.block([128], "B") as [vi]: + tir.bind(vi, i * 16 + j) + B[vi] = A_global[vi] + 1.0 + + for i in tir.grid(128): + with tir.block([128], "C") as [vi]: + C[vi] = A_global[vi] + + +@tvm.script.tir +def continuous_cache_read(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + B_shared = tir.alloc_buffer((128, 128), scope="shared") + B_local = tir.alloc_buffer((128, 128), scope="local") + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "B_shared") as [vi, vj]: + B_shared[vi, vj] = B[vi, vj] + with tir.block([128, 128], "B_local") as [vi, vj]: + B_local[vi, vj] = B_shared[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B_local[vi, vj] + 1.0 + + +########## Expected function after cache_write ########## + + +@tvm.script.tir +def cache_write_elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + B_global = tir.alloc_buffer((128, 128), scope="local") + C_local = tir.alloc_buffer((128, 128)) + with tir.block([128, 128], "B_global") as [vi, vj]: + B_global[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = B_global[vi, vj] + with tir.block([128, 128], "C_local") as [vi, vj]: + C_local[vi, vj] = B[vi, vj] + 1.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = C_local[vi, vj] + + +@tvm.script.tir +def cache_write_under_scope(b: ty.handle, c: ty.handle) -> None: + A = tir.alloc_buffer((128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + A_global = tir.alloc_buffer((128, 128)) + + with tir.block([8, 8], "scope") as [i, j]: + A_local = tir.alloc_buffer((128, 128), scope="local") + B_global = tir.alloc_buffer((128, 128)) + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "A_local") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + A_local[vi, vj] = 1.0 + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "A") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + A_global[vi, vj] = A_local[vi, vj] + for x, y in tir.grid(16, 16): + with tir.block([128, 128], "B_global") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + B_global[vi, vj] = A_global[vi, vj] + 1.0 + for x, y in 
tir.grid(16, 16): + with tir.block([128, 128], "B_global") as [vi, vj]: + tir.bind(vi, i * 16 + x) + tir.bind(vj, j * 16 + y) + B[vi, vj] = B_global[vi, vj] + with tir.block([128, 128], "A_global") as [vi, vj]: + A[vi, vj] = A_global[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def cache_write_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), dtype="float16") + B = tir.match_buffer(b, (128, 128), dtype="float16") + C = tir.match_buffer(c, (128, 128), dtype="float16") + D = tir.match_buffer(d, (128, 128), dtype="float16") + D_global = tir.alloc_buffer((128, 128), dtype="float16") + B_global = tir.alloc_buffer((128, 128), dtype="float16") + C_global = tir.alloc_buffer((128, 128), dtype="float16") + + with tir.block([128, 128], "load_store") as [vi, vj]: + tir.reads(A[vi, vj]) + tir.writes(D_global[vi, vj]) + D_global.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj) + with tir.block([8, 8], "opaque") as [vi, vj]: + tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.evaluate( + tir.tvm_load_matrix_sync( + B_global.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A.data, + vi * 2048 + vj * 16, + 128, + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + with tir.block([8, 8], "match_buffer") as [vi, vj]: + tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + tir.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = tir.match_buffer( + A[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + C0 = tir.match_buffer( + C_global[ + vi * 16 : vi * 16 + 16, + vj * 16 : vj * 16 + 16, + ], + (16, 16), + "float16", + strides=[128, 1], + offset_factor=1, + ) + tir.evaluate( + tir.tvm_load_matrix_sync( + C0.data, + 16, + 16, + 16, + vi * 8 + vj, + tir.tvm_access_ptr( + tir.type_annotation(dtype="float16"), + A0.data, + A0.elem_offset, + A0.strides[0], + 1, + dtype="handle", + ), + 128, + "row_major", + dtype="handle", + ) + ) + + with tir.block([128, 128], "D") as [vi, vj]: + D[vi, vj] = D_global[vi, vj] + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = B_global[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = C_global[vi, vj] + + +@tvm.script.tir +def cache_write_multi_consumer() -> None: + A = tir.alloc_buffer((128)) + B = tir.alloc_buffer((128)) + C = tir.alloc_buffer((128)) + A_global = tir.alloc_buffer((128)) + for i in tir.grid(8): + for j in tir.grid(16): + with tir.block([128], "A_global") as [vi]: + tir.bind(vi, i * 16 + j) + A_global[vi] = 1.0 + for j in tir.grid(16): + with tir.block([128], "A") as [vi]: + tir.bind(vi, i * 16 + j) + A[vi] = A_global[vi] + for j in tir.grid(16): + with tir.block([128], "B") as [vi]: + tir.bind(vi, i * 16 + j) + B[vi] = A[vi] + 1.0 + + for i in tir.grid(128): + with tir.block([128], "C") as [vi]: + C[vi] = A[vi] + + +@tvm.script.tir +def continuous_cache_write(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + B_shared = tir.alloc_buffer((128, 128), scope="shared") + B_local = tir.alloc_buffer((128, 128), scope="local") + with tir.block([128, 128], "B") as [vi, vj]: + B_local[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], 
"B") as [vi, vj]: + B_shared[vi, vj] = B_local[vi, vj] + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = B_shared[vi, vj] + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +########## Testcases for cache_read ########## + + +def test_cache_read_elementwise(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + block_c = sch.get_block("C") + cached_a = sch.cache_read(block_b, 0, "global") + cached_b = sch.cache_read(block_c, 0, "local") + assert sch.get(cached_a) == sch.get(sch.get_block("A_global")) + assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) + assert sch.get(block_b) == sch.get(sch.get_block("B")) + assert sch.get(block_c) == sch.get(sch.get_block("C")) + tvm.ir.assert_structural_equal(cache_read_elementwise, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_cache_read_under_scope(): + sch = tir.Schedule(access_under_scope, debug_mask="all") + block_b = sch.get_block("B") + block_c = sch.get_block("C") + sch.cache_read(block_b, 0, "local") + sch.cache_read(block_c, 0, "global") + tvm.ir.assert_structural_equal(cache_read_under_scope, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=access_under_scope) + + +def test_cache_read_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mask="all") + block = sch.get_block("load_store") + sch.cache_read(block, 0, "global") + tvm.ir.assert_structural_equal(cache_read_opaque_access, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) + + +def test_cache_read_location(): + sch = tir.Schedule(func_multi_consumer, debug_mask="all") + block_b = sch.get_block("B") + sch.cache_read(block_b, 0, "global") + tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) + + +def test_continuous_cache_read(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_c = sch.get_block("C") + sch.cache_read(block_c, 0, "shared") + sch.cache_read(block_c, 0, "local") + tvm.ir.assert_structural_equal(continuous_cache_read, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_cache_read_fail_multi_producer(): + sch = tir.Schedule(func_multi_producer, debug_mask="all") + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_read(block_b, 0, "global") + + +def test_cache_read_fail_index_out_of_bound(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_read(block_b, 1, "global") + + +########## Testcases for cache_write ########## + + +def test_cache_write_elementwise(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + block_c = sch.get_block("C") + cached_b = sch.cache_write(block_b, 0, "local") + cached_c = sch.cache_write(block_c, 0, "global") + assert sch.get(cached_b) == sch.get(sch.get_block("B_local")) + assert sch.get(cached_c) == sch.get(sch.get_block("C_global")) + assert sch.get(block_b) == sch.get(sch.get_block("B")) + assert sch.get(block_c) == sch.get(sch.get_block("C")) + tvm.ir.assert_structural_equal(cache_write_elementwise, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_cache_write_under_scope(): + sch = tir.Schedule(access_under_scope, debug_mask="all") + block_a = sch.get_block("A") + block_b = sch.get_block("B") + block_scope = sch.get_block("scope") + sch.cache_write(block_a, 0, "local") + 
sch.cache_write(block_b, 0, "global") + sch.cache_write(block_scope, 0, "global") + tvm.ir.assert_structural_equal(cache_write_under_scope, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=access_under_scope) + + +def test_cache_write_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mask="all") + block_store = sch.get_block("load_store") + block_opaque = sch.get_block("opaque") + block_match_buffer = sch.get_block("match_buffer") + sch.cache_write(block_store, 0, "global") + sch.cache_write(block_opaque, 0, "global") + sch.cache_write(block_match_buffer, 0, "global") + tvm.ir.assert_structural_equal(cache_write_opaque_access, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) + + +def test_cache_write_location(): + sch = tir.Schedule(func_multi_consumer, debug_mask="all") + block_a = sch.get_block("A") + sch.cache_write(block_a, 0, "global") + tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=func_multi_consumer) + + +def test_continuous_cache_write(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + sch.cache_write(block_b, 0, "shared") + sch.cache_write(block_b, 0, "local") + tvm.ir.assert_structural_equal(continuous_cache_write, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_cache_write_fail_multi_producer(): + sch = tir.Schedule(func_multi_producer, debug_mask="all") + block_a0 = sch.get_block("A0") + block_a1 = sch.get_block("A1") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_write(block_a0, 0, "global") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_write(block_a1, 0, "global") + + +def test_cache_write_fail_index_out_of_bound(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_write(block_b, 1, "global") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))
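
For quick reference, here is a minimal sketch of the Python-side usage that the unit tests above exercise. The `elementwise` workload, the block names "B"/"C", the buffer index 0, and the resulting "A_global"/"C_global" cache-stage block names are all taken directly from those tests; nothing else should be read into it.

import tvm
from tvm import tir

# Build a schedule over the `elementwise` workload defined in the test file.
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")  # reads A, writes B
block_c = sch.get_block("C")  # reads B, writes C

# Stage read buffer #0 of block "B" (i.e. A) through a global-scope read cache,
# and write buffer #0 of block "C" (i.e. C) through a global-scope write cache.
cached_a = sch.cache_read(block_b, 0, "global")
cached_c = sch.cache_write(block_c, 0, "global")

# Both primitives return the newly created cache-stage block.
assert sch.get(cached_a) == sch.get(sch.get_block("A_global"))
assert sch.get(cached_c) == sch.get(sch.get_block("C_global"))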