@@ -2138,6 +2138,7 @@ def transform_layout(
2138
2138
--------
2139
2139
Before transform_layout, in TensorIR, the IR is:
2140
2140
.. code-block:: python
2141
+
2141
2142
@T.prim_func
2142
2143
def before_transform_layout(a: T.handle, c: T.handle) -> None:
2143
2144
A = T.match_buffer(a, (128, 128), "float32")
@@ -2151,14 +2152,18 @@ def before_transform_layout(a: T.handle, c: T.handle) -> None:
2151
2152
with T.block("C"):
2152
2153
vi, vj = T.axis.remap("SS", [i, j])
2153
2154
C[vi, vj] = B[vi, vj] + 1.0
2155
+
2154
2156
Create the schedule and do transform_layout:
2155
2157
.. code-block:: python
2158
+
2156
2159
sch = tir.Schedule(before_storage_align)
2157
2160
sch.transform_layout(sch.get_block("B"), buffer_index=0, is_write_index=True,
2158
2161
index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
2159
2162
print(sch.mod["main"].script())
2163
+
2160
2164
After applying transform_layout, the IR becomes:
2161
2165
.. code-block:: python
2166
+
2162
2167
@T.prim_func
2163
2168
def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None:
2164
2169
A = T.match_buffer(a, (128, 128), "float32")
@@ -2172,6 +2177,7 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
2172
2177
with T.block("C"):
2173
2178
vi, vj = T.axis.remap("SS", [i, j])
2174
2179
C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0
2180
+
2175
2181
"""
2176
2182
if callable (index_map ):
2177
2183
index_map = IndexMap .from_func (index_map )
0 commit comments