@@ -710,6 +710,79 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None:
710710 """
711711 return _ffi_api .ScheduleRFactor (self , loop , factor_axis ) # type: ignore # pylint: disable=no-member
712712
713+ ######## Schedule: Block annotation ########
714+
715+ def storage_align ( # pylint: disable=too-many-arguments
716+ self , block : BlockRV , buffer_index : int , axis : int , factor : int , offset : int
717+ ) -> None :
718+ """Set alignment requirement for specific dimension such that
719+ stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more
720+ friendly memory access pattern. For example, we can set alignment to be factor=2, offset=1
721+ to avoid bank conflict for thread access on higher dimension in GPU shared memory.
722+
723+ Parameters
724+ ----------
725+ block : BlockRV
726+ The producer block of the buffer.
727+ buffer_index : int
728+ The index of the buffer in block's write region.
729+ axis : int
730+ The dimension to be specified for alignment.
731+ factor : int
732+ The factor multiple of alignment.
733+ offset : int
734+ The required offset factor.
735+
736+ Examples
737+ --------
738+
739+ Before storage_align, in TensorIR, the IR is:
740+
741+ .. code-block:: python
742+
743+ @tvm.script.tir
744+ def before_storage_align(a: ty.handle, c: ty.handle) -> None:
745+ A = tir.match_buffer(a, (128, 128))
746+ B = tir.alloc_buffer((128, 128))
747+ C = tir.match_buffer(c, (128, 128))
748+ with tir.block([128, 128], "B") as [vi, vj]:
749+ B[vi, vj] = A[vi, vj] * 2.0
750+ with tir.block([128, 128], "C") as [vi, vj]:
751+ C[vi, vj] = B[vi, vj] + 1.0
752+
753+ Create the schedule and do storage_align:
754+
755+ .. code-block:: python
756+
757+ sch = tir.Schedule(before_storage_align)
758+ sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1)
759+ print(tvm.script.asscript(sch.mod["main"]))
760+
761+ After applying rfactor, the IR becomes:
762+
763+ .. code-block:: python
764+
765+ @tvm.script.tir
766+ def after_storage_align(a: ty.handle, c: ty.handle) -> None:
767+ A = tir.match_buffer(a, (128, 128))
768+ B = tir.alloc_buffer((128, 128))
769+ C = tir.match_buffer(c, (128, 128))
770+ with tir.block([128, 128], "B") as [vi, vj]:
771+ tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]})
772+ B[vi, vj] = A[vi, vj] * 2.0
773+ with tir.block([128, 128], "C") as [vi, vj]:
774+ C[vi, vj] = B[vi, vj] + 1.0
775+
776+ After lowering passes, buffer B will have strides as [129, 1].
777+
778+ Note
779+ ----
780+ Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
781+ """
782+ _ffi_api .ScheduleStorageAlign ( # type: ignore # pylint: disable=no-member
783+ self , block , buffer_index , axis , factor , offset
784+ )
785+
713786 ########## Schedule: Blockize & Tensorize ##########
714787
715788 ########## Schedule: Annotation ##########
0 commit comments