@@ -790,6 +790,141 @@ def after_unroll(a: ty.handle, b: ty.handle) -> None:
790790
791791 ########## Schedule: Insert cache stages ##########
792792
793+ def cache_read (self , block : BlockRV , read_buffer_index : int , storage_scope : str ) -> BlockRV :
794+ """Create a block that reads a buffer region into a read cache. It requires:
795+
796+ 1) There is at most one block who write the buffer in the scope.
797+
798+ 2) The scope block have stage-pipeline property.
799+
800+ Parameters
801+ ----------
802+ block : BlockRV
803+ The consumer block of the target buffer.
804+
805+ read_buffer_index: int
806+ The index of the buffer in block's read region.
807+
808+ storage_scope: str
809+ The target storage scope.
810+
811+ Returns
812+ -------
813+ cached_block : BlockRV
814+ The block of the cache stage
815+
816+ Examples
817+ --------
818+ Before cache_read, in TensorIR, the IR is:
819+
820+ .. code-block:: python
821+
822+ @tvm.script.tir
823+ def before_cache_read(a: ty.handle, b: ty.handle) -> None:
824+ A = tir.match_buffer(a, (128, 128))
825+ B = tir.match_buffer(b, (128, 128))
826+ for i, j in tir.grid(128, 128):
827+ with tir.block([128, 128], "B") as [vi, vj]:
828+ B[vi, vj] = A[vi, vj] * 2.0
829+
830+ Create the schedule and cache_read:
831+
832+ .. code-block:: python
833+
834+ sch = tir.Schedule(before_cache_read)
835+ block_b = sch.get_block("B")
836+ sch.cache_read(block_b, 0, "local")
837+ print(tvm.script.asscript(sch.mod["main"]))
838+
839+ After applying cache_read, the IR becomes:
840+
841+ .. code-block:: python
842+
843+ @tvm.script.tir
844+ def after_cache_read(a: ty.handle, b: ty.handle) -> None:
845+ A = tir.match_buffer(a, (128, 128))
846+ B = tir.match_buffer(b, (128, 128))
847+ A_local = tir.alloc_buffer((128, 128), scope="local")
848+ for i, j in tir.grid(128, 128):
849+ with tir.block([128, 128], "A_local") as [vi, vj]:
850+ A_local[vi, vj] = A[vi, vj]
851+ for i, j in tir.grid(128, 128):
852+ with tir.block([128, 128], "B") as [vi, vj]:
853+ B[vi, vj] = A_local[vi, vj] * 2.0
854+
855+ """
856+ return _ffi_api .ScheduleCacheRead ( # type: ignore # pylint: disable=no-member
857+ self , block , read_buffer_index , storage_scope
858+ )
859+
860+ def cache_write (self , block : BlockRV , write_buffer_index : int , storage_scope : str ) -> BlockRV :
861+ """Create a block that reads a buffer region into a write cache. It requires:
862+
863+ 1) There is only one block who write the buffer in the scope.
864+
865+ 2) The scope block have stage-pipeline property.
866+
867+ Parameters
868+ ----------
869+ block : BlockRV
870+ The producer block of the target buffer.
871+
872+ write_buffer_index: int
873+ The index of the buffer in block's write region.
874+
875+ storage_scope: str
876+ The target storage scope.
877+
878+
879+ Returns
880+ -------
881+ cached_block : BlockRV
882+ The block of the cache stage
883+
884+ Examples
885+ --------
886+ Before cache_write, in TensorIR, the IR is:
887+
888+ .. code-block:: python
889+
890+ @tvm.script.tir
891+ def before_cache_write(a: ty.handle, b: ty.handle) -> None:
892+ A = tir.match_buffer(a, (128, 128))
893+ B = tir.match_buffer(b, (128, 128))
894+ for i, j in tir.grid(128, 128):
895+ with tir.block([128, 128], "B") as [vi, vj]:
896+ B[vi, vj] = A[vi, vj] * 2.0
897+
898+ Create the schedule and cache_write:
899+
900+ .. code-block:: python
901+
902+ sch = tir.Schedule(before_cache_write)
903+ block_b = sch.get_block("B")
904+ sch.cache_write(block_b, 0, "local")
905+ print(tvm.script.asscript(sch.mod["main"]))
906+
907+ After applying cache_write, the IR becomes:
908+
909+ .. code-block:: python
910+
911+ @tvm.script.tir
912+ def after_cache_write(a: ty.handle, b: ty.handle) -> None:
913+ A = tir.match_buffer(a, (128, 128))
914+ B = tir.match_buffer(b, (128, 128))
915+ B_local = tir.alloc_buffer((128, 128), scope="local")
916+ for i, j in tir.grid(128, 128):
917+ with tir.block([128, 128], "A_local") as [vi, vj]:
918+ B_local[vi, vj] = A[vi, vj] * 2.0
919+ for i, j in tir.grid(128, 128):
920+ with tir.block([128, 128], "B") as [vi, vj]:
921+ B[vi, vj] = B_local[vi, vj]
922+
923+ """
924+ return _ffi_api .ScheduleCacheWrite ( # type: ignore # pylint: disable=no-member
925+ self , block , write_buffer_index , storage_scope
926+ )
927+
793928 ########## Schedule: Compute location ##########
794929
795930 def compute_inline (self , block : BlockRV ) -> None :
0 commit comments