Skip to content

Commit

Permalink
[TIR][Primitive] Support rolling_buffer schedule primitive in TensorIR (
Browse files Browse the repository at this point in the history
#13033)

* [TIR][Primitive] Support rolling_buffer schedule primitive in TensorIR

* Address review comments

* Add dependency checks
  • Loading branch information
liangW-intellif authored Nov 1, 2022
1 parent 9cdc97f commit 2c1fecd
Show file tree
Hide file tree
Showing 10 changed files with 1,219 additions and 0 deletions.
17 changes: 17 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,23 @@ class ScheduleNode : public runtime::Object {
*/
virtual void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) = 0;

/******** Schedule: Buffer transformation ********/
/*!
 * \brief Compute the target buffer via rolling buffering.
 * \details This primitive selects the outermost rollable axis with a positive bound overlap that
 * appears in the block's ancestor loops as the `rolling axis`, folds and circularizes the buffer
 * along the rolling dimension, and appends a block predicate to avoid recomputing overlapping
 * elements. It requires:
 * 1) The buffer to be an intermediate buffer defined via `alloc_buffer`.
 * 2) The LCA of the producer and consumer of the buffer is a for loop; typically,
 * the producer and consumer of the buffer are cascaded through compute_at.
 * 3) The access region of the buffer has at least one dimension that contains
 * a positive bound overlap.
 * \param block_rv The producer block of the buffer.
 * \param write_buffer_index The index of the buffer in the block's write region.
 */
virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
Expand Down
108 changes: 108 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3128,6 +3128,114 @@ def after_pad_einsum(
self, block, padding
)

######## Schedule: Buffer transformation ########

@type_checked
def rolling_buffer(
    self,
    block: Union[BlockRV, str],
    write_buffer_index: int,
) -> None:
    """Compute the target buffer via rolling buffering.

    This primitive selects the outermost rollable axis with a positive bound overlap
    that appears in the block's ancestor loops as the `rolling axis`, folds and
    circularizes the buffer along the rolling dimension, and appends a block predicate
    to avoid recomputing overlapping elements. It requires:

    1) The block is not an output block and has only RAW dependencies.

    2) The buffer to be an intermediate buffer defined via `alloc_buffer`.

    3) The LCA of the producer and consumer of the buffer is a for loop; typically,
    the producer and consumer of the buffer are cascaded through compute_at.

    4) The access region of the buffer has at least one dimension that contains
    a positive bound overlap.

    Parameters
    ----------
    block : Union[BlockRV, str]
        The producer block of the buffer.
    write_buffer_index : int
        The index of the buffer in block's write region.

    Examples
    --------
    Before rolling_buffer, in TensorIR, the IR is:

    .. code-block:: python

        @T.prim_func
        def before_rolling_buffer(
            A: T.Buffer[(12, 12), "int8"], C: T.Buffer[(8, 8), "int8"]
        ) -> None:
            # body
            # with T.block("root")
            B = T.alloc_buffer([10, 10], dtype="int8")
            for i0, i1 in T.grid(2, 2):
                for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
                    with T.block("B"):
                        ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
                        ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
                        rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                        B[ax0_1, ax1_1] = T.max(
                            B[ax0_1, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1]
                        )
                for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
                    with T.block("C"):
                        ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
                        ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
                        rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                        C[ax0_1, ax1_1] = T.max(
                            C[ax0_1, ax1_1], B[ax0_1 + rv0, ax1_1 + rv1]
                        )

    Create the schedule and do rolling_buffer:

    .. code-block:: python

        sch = tir.Schedule(before_rolling_buffer)
        sch.rolling_buffer(sch.get_block("B"), write_buffer_index=0)
        print(sch.mod["main"].script())

    After applying rolling_buffer, the IR becomes:

    .. code-block:: python

        @T.prim_func
        def after_rolling_buffer(
            A: T.Buffer[(12, 12), "int8"],
            C: T.Buffer[(8, 8), "int8"]
        ) -> None:
            # body
            # with T.block("root")
            B = T.alloc_buffer([6, 10], dtype="int8")
            for i0, i1 in T.grid(2, 2):
                for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
                    with T.block("B"):
                        # predicate skips elements already computed by a previous
                        # iteration of the rolling axis
                        T.where((i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1))
                        ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
                        ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
                        rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                        B[ax0_1 % 6, ax1_1] = T.max(
                            B[ax0_1 % 6, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1]
                        )
                for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
                    with T.block("C"):
                        ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
                        ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
                        rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                        C[ax0_1, ax1_1] = T.max(
                            C[ax0_1, ax1_1], B[ax0_1 % 6 + rv0, ax1_1 + rv1]
                        )

    Note
    ----
    The region_cover property of the consumer block of the target buffer will become false.
    """
    # Accept either a BlockRV or a block name; resolve to a BlockRV.
    block = self._normalize_block_arg(block)
    # Delegate to the C++ implementation of the primitive.
    return _ffi_api.ScheduleRollingBuffer(self, block, write_buffer_index)  # type: ignore # pylint: disable=no-member

########## Schedule: Misc ##########

@type_checked
Expand Down
12 changes: 12 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,8 @@ void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_
this->state_->DebugVerify();
}

/******** Schedule: Padding ********/

BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
Expand All @@ -829,6 +831,16 @@ void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Intege
TVM_TIR_SCHEDULE_END("pad-einsum", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: Buffer Transformation ********/

void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) {
  // Thin wrapper: delegate to the tir-level primitive. Any ScheduleError raised inside
  // the BEGIN/END pair is rendered according to `error_render_level_`.
  TVM_TIR_SCHEDULE_BEGIN();
  tir::RollingBuffer(state_, this->GetSRef(block_rv), write_buffer_index);
  TVM_TIR_SCHEDULE_END("rolling-buffer", this->error_render_level_);
  // Sanity-check the schedule state after the transformation.
  this->state_->DebugVerify();
}

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ class ConcreteScheduleNode : public ScheduleNode {
const Array<IntImm>& axis_separators) override;
/******** Schedule: Padding decomposition ********/
BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override;
/******** Schedule: Buffer transformation ********/
void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) override;
/******** Schedule: Misc ********/
void EnterPostproc() override {}

Expand Down
16 changes: 16 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,22 @@ TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref
TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref,
const Array<Integer>& padding);

/******** Schedule: Buffer transformation ********/
/*!
 * \brief Compute the target buffer via rolling buffering.
 * \details This primitive selects the outermost rollable axis with a positive bound overlap that
 * appears in the block's ancestor loops as the `rolling axis`, folds and circularizes the buffer
 * along the rolling dimension, and appends a block predicate to avoid recomputing overlapping
 * elements. It requires:
 * 1) The buffer to be an intermediate buffer defined via `alloc_buffer`.
 * 2) The LCA of the producer and consumer of the buffer is a for loop; typically,
 * the producer and consumer of the buffer are cascaded through compute_at.
 * 3) The access region of the buffer has at least one dimension that contains
 * a positive bound overlap.
 * \param self The state of the schedule.
 * \param block_sref The sref of the producer block of the buffer.
 * \param write_buffer_index The index of the buffer in the block's write region.
 */
TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index);
/******** Schedule: Misc ********/

} // namespace tir
Expand Down
Loading

0 comments on commit 2c1fecd

Please sign in to comment.