Skip to content

Commit

Permalink
[TensorIR][M2a] CacheRead/Write (apache#8863)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
  • Loading branch information
6 people authored and ylc committed Jan 13, 2022
1 parent 05961ec commit 65cbd5d
Show file tree
Hide file tree
Showing 18 changed files with 1,840 additions and 23 deletions.
22 changes: 22 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,28 @@ class ScheduleNode : public runtime::Object {
*/
virtual void Unroll(const LoopRV& loop_rv) = 0;
/******** Schedule: Insert cache stages ********/
/*!
 * \brief Create a block that reads a buffer region into a read cache. It requires:
 * 1) There is at most one block that writes the buffer in the scope.
 * 2) The scope block has the stage-pipeline property.
 * \param block_rv The consumer block of the target buffer.
 * \param read_buffer_index The index of the buffer in the block's read region.
 * \param storage_scope The target storage scope.
 * \return The cache stage block.
 */
virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) = 0;
/*!
 * \brief Create a block that writes a buffer region into a write cache. It requires:
 * 1) There is only one block that writes the target buffer.
 * 2) The scope block has the stage-pipeline property.
 * \param block_rv The producer block of the target buffer.
 * \param write_buffer_index The index of the buffer in the block's write region.
 * \param storage_scope The target storage scope.
 * \return The cache stage block.
 */
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) = 0;
/******** Schedule: Compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ class ScheduleStateNode : public Object {
*/
TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
const Map<Block, Block>& block_sref_reuse);
/*!
 * \brief Recalculate the `affine_binding` flag of the scope block info.
 * \param scope_sref The sref of the scope block whose flag is to be recalculated.
 */
TVM_DLL void UpdateAffineFlag(const StmtSRef& scope_sref);
/*!
* \brief Trigger the verification according to the `debug_mask` bitmask.
* 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
Expand Down
135 changes: 135 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,141 @@ def after_unroll(a: ty.handle, b: ty.handle) -> None:

########## Schedule: Insert cache stages ##########

def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV:
    """Insert a cache stage that reads a buffer region into a read cache.

    The primitive requires that:

    1) At most one block in the scope writes the buffer.
    2) The scope block has the stage-pipeline property.

    Parameters
    ----------
    block : BlockRV
        The consumer block of the target buffer.

    read_buffer_index : int
        The index of the buffer in the block's read region.

    storage_scope : str
        The target storage scope.

    Returns
    -------
    cached_block : BlockRV
        The block of the cache stage.

    Examples
    --------
    Before cache_read, in TensorIR, the IR is:

    .. code-block:: python

        @tvm.script.tir
        def before_cache_read(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    B[vi, vj] = A[vi, vj] * 2.0

    Create the schedule and cache_read:

    .. code-block:: python

        sch = tir.Schedule(before_cache_read)
        block_b = sch.get_block("B")
        sch.cache_read(block_b, 0, "local")
        print(tvm.script.asscript(sch.mod["main"]))

    After applying cache_read, the IR becomes:

    .. code-block:: python

        @tvm.script.tir
        def after_cache_read(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            A_local = tir.alloc_buffer((128, 128), scope="local")
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "A_local") as [vi, vj]:
                    A_local[vi, vj] = A[vi, vj]
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    B[vi, vj] = A_local[vi, vj] * 2.0

    """
    # Forward the request unchanged to the FFI entry point and hand back its result.
    cache_stage = _ffi_api.ScheduleCacheRead(  # type: ignore # pylint: disable=no-member
        self, block, read_buffer_index, storage_scope
    )
    return cache_stage

def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV:
    """Insert a cache stage that writes a buffer region into a write cache.

    The primitive requires that:

    1) Only one block in the scope writes the buffer.
    2) The scope block has the stage-pipeline property.

    Parameters
    ----------
    block : BlockRV
        The producer block of the target buffer.

    write_buffer_index : int
        The index of the buffer in the block's write region.

    storage_scope : str
        The target storage scope.

    Returns
    -------
    cached_block : BlockRV
        The block of the cache stage.

    Examples
    --------
    Before cache_write, in TensorIR, the IR is:

    .. code-block:: python

        @tvm.script.tir
        def before_cache_write(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    B[vi, vj] = A[vi, vj] * 2.0

    Create the schedule and cache_write:

    .. code-block:: python

        sch = tir.Schedule(before_cache_write)
        block_b = sch.get_block("B")
        sch.cache_write(block_b, 0, "local")
        print(tvm.script.asscript(sch.mod["main"]))

    After applying cache_write, the IR becomes:

    .. code-block:: python

        @tvm.script.tir
        def after_cache_write(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            B_local = tir.alloc_buffer((128, 128), scope="local")
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "A_local") as [vi, vj]:
                    B_local[vi, vj] = A[vi, vj] * 2.0
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    B[vi, vj] = B_local[vi, vj]

    """
    # NOTE(review): the "after" example labels the first block "A_local" although it
    # fills B_local -- confirm the block name the primitive really generates.
    # Forward the request unchanged to the FFI entry point and hand back its result.
    cache_stage = _ffi_api.ScheduleCacheWrite(  # type: ignore # pylint: disable=no-member
        self, block, write_buffer_index, storage_scope
    )
    return cache_stage

########## Schedule: Compute location ##########

def compute_inline(self, block: BlockRV) -> None:
Expand Down
21 changes: 14 additions & 7 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ void VerifyCachedFlags(const ScheduleState& self);
const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
GlobalVar* result_g_var);

/*!
 * \brief Get the root node of the sref tree, which is the root block of the PrimFunc.
 * \param sref The given sref, i.e. any node of the sref tree of interest.
 * \return The root node of the sref tree that contains the given node.
 */
StmtSRef GetSRefTreeRoot(const StmtSRef& sref);

/******** Scope ********/
/*!
* \brief Checks if scope the specified sref is in is a stage-pipeline and return it
Expand Down Expand Up @@ -228,15 +235,15 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
/******** Block-buffer relation ********/

/*!
* \brief Get the BlockRealize of the single child block of the block or loop specified by
* `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks
* \param self The schedule state
* \param block The queried block
* \param n The index of the queried buffer
* \return The buffer of the n-th write region of the block.
* \brief Get the n-th read or write buffer of the given block.
* \param self The schedule state.
* \param block The queried block.
* \param n The index of the queried buffer.
* \param is_write A boolean flag to indicate querying write buffer or read buffer.
* \return The buffer of the n-th read/write region of the block.
* \throw ScheduleError If the buffer index is out of bound.
*/
Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);

/******** Commutative Reducer ********/

Expand Down
50 changes: 36 additions & 14 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -588,25 +588,37 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr

/******** Block-buffer relation ********/

Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
class WriteBufferIndexOutOfRangeError : public ScheduleError {
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write) {
class BufferIndexOutOfRangeError : public ScheduleError {
public:
explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index)
: mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {}
explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index, bool is_write)
: mod_(std::move(mod)),
block_(std::move(block)),
buffer_index_(buffer_index),
is_write_(is_write) {}

String FastErrorString() const final {
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
"range [0, num_write_regions) where `num_write_regions` is the number of buffer "
"regions written by the block.";
if (is_write_) {
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
"range "
"[0, num_write_regions) where `num_write_regions` is the number of buffer regions "
"written by the block.";
} else {
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
"range "
"[0, num_read_regions) where `num_read_regions` is the number of buffer regions "
"read by the block.";
}
}

String DetailRenderTemplate() const final {
std::ostringstream os;
size_t num_writes = block_->writes.size();
os << "The block {0} has " << num_writes
<< " write regions, so `buffer_index` is required to be in [0, " << num_writes
size_t num = is_write_ ? block_->writes.size() : block_->reads.size();
std::string access_type = is_write_ ? "write" : "read";
os << "The block {0} has " << num << " " << access_type
<< " regions, so `buffer_index` is required to be in [0, " << num
<< "). However, the input `buffer_index` is " << buffer_index_
<< ", which is out of the expected range";
<< ", which is out of the expected range.";
return os.str();
}

Expand All @@ -617,12 +629,15 @@ Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
IRModule mod_;
Block block_;
int buffer_index_;
bool is_write_;
};

if (n < 0 || static_cast<size_t>(n) >= block->writes.size()) {
throw WriteBufferIndexOutOfRangeError(self->mod, block, n);
const Array<BufferRegion>& access_region = is_write ? block->writes : block->reads;

if (n < 0 || static_cast<int>(access_region.size()) <= n) {
throw BufferIndexOutOfRangeError(self->mod, block, n, is_write);
}
return block->writes[n]->buffer;
return access_region[n]->buffer;
}

/******** Pattern Matcher ********/
Expand Down Expand Up @@ -941,5 +956,12 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
return false;
}

/******** SRef Tree Related ********/
/*!
 * \brief Find the root of the sref tree that contains the given sref.
 * \param sref The sref to start from. It is dereferenced without a null check.
 * \return The ancestor of `sref` (possibly `sref` itself) whose `parent` is null.
 */
StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
  const StmtSRefNode* node = sref.get();
  // Climb the parent links until reaching the node that has no parent.
  while (node->parent != nullptr) {
    node = node->parent;
  }
  return GetRef<StmtSRef>(node);
}
} // namespace tir
} // namespace tvm
21 changes: 21 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,27 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) {
}

/******** Schedule: Insert cache stages ********/

/*!
 * \brief Apply the cache-read primitive to the schedule state and wrap the new cache
 *  stage block in a random variable.
 */
BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
                                        const String& storage_scope) {
  StmtSRef cache_stage_sref{nullptr};
  TVM_TIR_SCHEDULE_BEGIN();
  // Resolve the random variable to its sref inside the guarded region, then delegate.
  const StmtSRef consumer_sref = this->GetSRef(block_rv);
  cache_stage_sref = tir::CacheRead(state_, consumer_sref, read_buffer_index, storage_scope);
  TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_);
  this->state_->DebugVerify();
  return CreateRV<BlockRV>(cache_stage_sref);
}

/*!
 * \brief Apply the cache-write primitive to the schedule state and wrap the new cache
 *  stage block in a random variable.
 */
BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                                         const String& storage_scope) {
  StmtSRef cache_stage_sref{nullptr};
  TVM_TIR_SCHEDULE_BEGIN();
  // Resolve the random variable to its sref inside the guarded region, then delegate.
  const StmtSRef producer_sref = this->GetSRef(block_rv);
  cache_stage_sref = tir::CacheWrite(state_, producer_sref, write_buffer_index, storage_scope);
  TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_);
  this->state_->DebugVerify();
  return CreateRV<BlockRV>(cache_stage_sref);
}

/******** Schedule: Compute location ********/

void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {
Expand Down
4 changes: 4 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ class ConcreteScheduleNode : public ScheduleNode {
void Bind(const LoopRV& loop_rv, const String& thread_axis) override;
void Unroll(const LoopRV& loop_rv) override;
/******** Schedule: Insert cache stages ********/
/*! \brief Override of ScheduleNode::CacheRead; returns the cache stage block. */
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
/*! \brief Override of ScheduleNode::CacheWrite; returns the cache stage block. */
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
/******** Schedule: Compute location ********/
void ComputeInline(const BlockRV& block) override;
void ReverseComputeInline(const BlockRV& block) override;
Expand Down
24 changes: 24 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,30 @@ TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar&
*/
TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref);
/******** Schedule: Insert cache stages ********/
/*!
 * \brief Create a block that reads a buffer region into a read cache. It requires:
 * 1) There is at most one block that writes the buffer in the scope.
 * 2) The scope block has the stage-pipeline property.
 * \param self The state of the schedule.
 * \param block_sref The consumer block of the target buffer.
 * \param read_buffer_index The index of the buffer in the block's read region.
 * \param storage_scope The target storage scope.
 * \return The cache stage block.
 */
TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
const String& storage_scope);
/*!
 * \brief Create a block that writes a buffer region into a write cache. It requires:
 * 1) There is only one block that writes the target buffer.
 * 2) The scope block has the stage-pipeline property.
 * \param self The state of the schedule.
 * \param block_sref The producer block of the target buffer.
 * \param write_buffer_index The index of the buffer in the block's write region.
 * \param storage_scope The target storage scope.
 * \return The cache stage block.
 */
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
const String& storage_scope);
/******** Schedule: Compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive/block_annotate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
#include "../transform.h"
#include "../utils.h"

namespace tvm {
Expand Down Expand Up @@ -237,7 +236,8 @@ class StorageAlignInvalidAnnotationError : public ScheduleError {
void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis,
int factor, int offset) {
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
Buffer buffer = GetNthWriteBuffer(self, GetRef<Block>(block_ptr), buffer_index);
Buffer buffer =
GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, /*is_write=*/true);
StorageAlignInvalidFactorError::Check(self->mod, factor);
axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis);
NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer);
Expand Down
Loading

0 comments on commit 65cbd5d

Please sign in to comment.