Skip to content

Commit a06863a

Browse files
vinx13Siyuan FengspectrometerHBHMasterJH5574jinhongyii
authored
[TensorIR][M2a] Storage Align (apache#8693)
This PR is part of the TensorIR upstreaming effort (apache#7527), which adds the one schedule primitive storage_align. Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Junru Shao <junrushao1994@gmail.com>
1 parent ccc09fa commit a06863a

File tree

16 files changed

+882
-3
lines changed

16 files changed

+882
-3
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,21 @@ class ScheduleNode : public runtime::Object {
264264
* \return The rfactor block
265265
*/
266266
virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
267+
/******** Schedule: Block annotation ********/
268+
/*!
269+
* \brief Set alignment requirement for specific dimension such that
270+
* stride[axis] == k * factor + offset for some k. This is useful to set memory layout for
271+
* more friendly memory access pattern. For example, we can set alignment to be factor=2,
272+
* offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared
273+
* memory.
274+
* \param block_rv The producer block of the buffer
275+
* \param buffer_index The index of the buffer in block's write region
276+
* \param axis The dimension to be specified for alignment
277+
* \param factor The factor multiple of alignment
278+
* \param offset The required offset factor
279+
*/
280+
virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
281+
int offset) = 0;
267282
/******** Schedule: Blockize & Tensorize ********/
268283
/******** Schedule: Annotation ********/
269284
/******** Schedule: Misc ********/

python/tvm/tir/schedule/schedule.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,79 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None:
710710
"""
711711
return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member
712712

713+
######## Schedule: Block annotation ########
714+
715+
def storage_align( # pylint: disable=too-many-arguments
716+
self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int
717+
) -> None:
718+
"""Set alignment requirement for specific dimension such that
719+
stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more
720+
friendly memory access pattern. For example, we can set alignment to be factor=2, offset=1
721+
to avoid bank conflict for thread access on higher dimension in GPU shared memory.
722+
723+
Parameters
724+
----------
725+
block : BlockRV
726+
The producer block of the buffer.
727+
buffer_index : int
728+
The index of the buffer in block's write region.
729+
axis : int
730+
The dimension to be specified for alignment.
731+
factor : int
732+
The factor multiple of alignment.
733+
offset : int
734+
The required offset factor.
735+
736+
Examples
737+
--------
738+
739+
Before storage_align, in TensorIR, the IR is:
740+
741+
.. code-block:: python
742+
743+
@tvm.script.tir
744+
def before_storage_align(a: ty.handle, c: ty.handle) -> None:
745+
A = tir.match_buffer(a, (128, 128))
746+
B = tir.alloc_buffer((128, 128))
747+
C = tir.match_buffer(c, (128, 128))
748+
with tir.block([128, 128], "B") as [vi, vj]:
749+
B[vi, vj] = A[vi, vj] * 2.0
750+
with tir.block([128, 128], "C") as [vi, vj]:
751+
C[vi, vj] = B[vi, vj] + 1.0
752+
753+
Create the schedule and do storage_align:
754+
755+
.. code-block:: python
756+
757+
sch = tir.Schedule(before_storage_align)
758+
sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1)
759+
print(tvm.script.asscript(sch.mod["main"]))
760+
761+
After applying storage_align, the IR becomes:
762+
763+
.. code-block:: python
764+
765+
@tvm.script.tir
766+
def after_storage_align(a: ty.handle, c: ty.handle) -> None:
767+
A = tir.match_buffer(a, (128, 128))
768+
B = tir.alloc_buffer((128, 128))
769+
C = tir.match_buffer(c, (128, 128))
770+
with tir.block([128, 128], "B") as [vi, vj]:
771+
tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]})
772+
B[vi, vj] = A[vi, vj] * 2.0
773+
with tir.block([128, 128], "C") as [vi, vj]:
774+
C[vi, vj] = B[vi, vj] + 1.0
775+
776+
After lowering passes, buffer B will have strides as [129, 1].
777+
778+
Note
779+
----
780+
Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
781+
"""
782+
_ffi_api.ScheduleStorageAlign( # type: ignore # pylint: disable=no-member
783+
self, block, buffer_index, axis, factor, offset
784+
)
785+
713786
########## Schedule: Blockize & Tensorize ##########
714787

715788
########## Schedule: Annotation ##########

src/tir/schedule/analysis.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,19 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
202202
*/
203203
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref);
204204

205+
/******** Block-buffer relation ********/
206+
207+
/*!
208+
* \brief Get the buffer of the n-th write region of the block, or throw an exception if the
209+
* buffer index `n` is out of bound
210+
* \param self The schedule state
211+
* \param block The queried block
212+
* \param n The index of the queried buffer
213+
* \return The buffer of the n-th write region of the block.
214+
* \throw ScheduleError If the buffer index is out of bound.
215+
*/
216+
Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);
217+
205218
/******** Commutative Reducer ********/
206219

207220
/*!

src/tir/schedule/analysis/analysis.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,45 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
527527
}
528528
}
529529

530+
/******** Block-buffer relation ********/
531+
532+
Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
533+
class WriteBufferIndexOutOfRangeError : public ScheduleError {
534+
public:
535+
explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index)
536+
: mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {}
537+
538+
String FastErrorString() const final {
539+
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
540+
"range [0, num_write_regions) where `num_write_regions` is the number of buffer "
541+
"regions written by the block.";
542+
}
543+
544+
String DetailRenderTemplate() const final {
545+
std::ostringstream os;
546+
size_t num_writes = block_->writes.size();
547+
os << "The block {0} has " << num_writes
548+
<< " write regions, so `buffer_index` is required to be in [0, " << num_writes
549+
<< "). However, the input `buffer_index` is " << buffer_index_
550+
<< ", which is out of the expected range";
551+
return os.str();
552+
}
553+
554+
IRModule mod() const final { return mod_; }
555+
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
556+
557+
private:
558+
IRModule mod_;
559+
Block block_;
560+
int buffer_index_;
561+
};
562+
563+
if (n < 0 || static_cast<size_t>(n) >= block->writes.size()) {
564+
throw WriteBufferIndexOutOfRangeError(self->mod, block, n);
565+
}
566+
return block->writes[n]->buffer;
567+
}
568+
530569
/******** Pattern Matcher ********/
531570

532571
/*!

src/tir/schedule/concrete_schedule.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,16 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
362362
}
363363

364364
/******** Schedule: loop binding/annotation ********/
365+
/******** Schedule: block annotation ********/
366+
367+
void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
368+
int factor, int offset) {
369+
TVM_TIR_SCHEDULE_BEGIN();
370+
tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset);
371+
TVM_TIR_SCHEDULE_END("storage-align", this->error_render_level_);
372+
this->state_->DebugVerify();
373+
}
374+
365375
/******** Schedule: cache read/write ********/
366376
/******** Schedule: reduction ********/
367377

src/tir/schedule/concrete_schedule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ class ConcreteScheduleNode : public ScheduleNode {
8888
void ReverseComputeInline(const BlockRV& block) override;
8989
/******** Schedule: Reduction ********/
9090
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
91+
/******** Schedule: Block annotation ********/
92+
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
93+
int offset) override;
9194
/******** Schedule: Blockize & Tensorize ********/
9295
/******** Schedule: Annotation ********/
9396
/******** Schedule: Misc ********/

src/tir/schedule/primitive.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,26 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref
104104
* \return The sref of the rfactor block
105105
*/
106106
TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis);
107+
/******** Schedule: Block annotation ********/
108+
/*!
109+
* \brief Set alignment requirement for specific dimension such that
110+
* stride[axis] == k * factor + offset for some k. This is useful to set memory layout for
111+
* more friendly memory access pattern. For example, we can set alignment to be factor=2,
112+
* offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared
113+
* memory.
114+
* \param block_sref The producer block of the buffer
115+
* \param buffer_index The index of the buffer in block's write region
116+
* \param axis The dimension to be specified for alignment
117+
* \param factor The factor multiple of alignment
118+
* \param offset The required offset factor
119+
*/
120+
TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
121+
int axis, int factor, int offset);
122+
123+
/******** Annotation types for StorageAlign ********/
124+
using StorageAlignTuple = Array<Integer>; // (buffer_idx, axis, factor, offset)
125+
using StorageAlignAnnotation = Array<StorageAlignTuple>; // unordered array of StorageAlignTuple
126+
107127
/******** Schedule: Blockize & Tensorize ********/
108128
/******** Schedule: Annotation ********/
109129
/******** Schedule: Misc ********/

0 commit comments

Comments
 (0)