@@ -156,6 +156,54 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
156156 )
157157
158158
159+ @tvm .script .tir
160+ def high_dim_opaque_access (a : ty .handle ) -> None :
161+ A = tir .match_buffer (a , (16 , 32 , 64 ))
162+ for i , j , k in tir .grid (16 , 2 , 4 ):
163+ with tir .block ([]):
164+ As_0 = tir .var ("int32" )
165+ As_1 = tir .var ("int32" )
166+ tir .reads ([])
167+ tir .writes (A [i , j * 16 : j * 16 + 16 , k * 16 : k * 16 + 16 ])
168+ sub_A = tir .match_buffer (
169+ A [i , j * 16 : j * 16 + 16 , k * 16 : k * 16 + 16 ],
170+ (16 , 16 ),
171+ strides = [As_0 , As_1 ],
172+ offset_factor = 1 ,
173+ )
174+ tir .evaluate (
175+ tir .intrin_test (
176+ sub_A .data ,
177+ sub_A .elem_offset ,
178+ sub_A .strides [0 ],
179+ sub_A .strides [1 ],
180+ sub_A .shape [0 ],
181+ sub_A .shape [1 ],
182+ dtype = "handle" ,
183+ )
184+ )
185+
186+
187+ @tvm .script .tir
188+ def transformed_high_dim_opaque_access (a : ty .handle ) -> None :
189+ A = tir .match_buffer (a , (16 , 32 , 64 ))
190+ for i , j , k in tir .grid (16 , 2 , 4 ):
191+ with tir .block ([]):
192+ tir .reads ([])
193+ tir .writes (A [i , j * 16 : j * 16 + 16 , k * 16 : k * 16 + 16 ])
194+ tir .evaluate (
195+ tir .intrin_test (
196+ A .data ,
197+ i * 2048 + j * 1024 + k * 16 ,
198+ 64 ,
199+ 1 ,
200+ 16 ,
201+ 16 ,
202+ dtype = "handle" ,
203+ )
204+ )
205+
206+
159207@tvm .script .tir
160208def recursive_match (a : ty .handle , b : ty .handle ) -> None :
161209 A = tir .match_buffer (a , (64 , 64 , 64 ))
@@ -419,6 +467,10 @@ def test_opaque_access():
419467 _check (opaque_access , transformed_opaque_access )
420468
421469
470+ def test_high_dim_opaque_access ():
471+ _check (high_dim_opaque_access , transformed_high_dim_opaque_access )
472+
473+
422474def test_recursive_match ():
423475 _check (recursive_match , transformed_recursive_match )
424476
@@ -447,6 +499,7 @@ def test_fail_match_func_param():
447499if __name__ == "__main__" :
448500 test_buffer_load_store ()
449501 test_opaque_access ()
502+ test_high_dim_opaque_access ()
450503 test_recursive_match ()
451504 test_symbolic_match ()
452505 test_rank0_buffer ()
0 commit comments