@@ -152,6 +152,39 @@ def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
152152 C [vi , vj ] = C [vi , vj ] + (A [vi , vk ] * B [vj , vk ])
153153
154154
155+ @T .prim_func
156+ def matmul_with_annotation (a : T .handle , b : T .handle , c : T .handle ) -> None :
157+ A = T .match_buffer (a , [128 , 128 ])
158+ B = T .match_buffer (b , [128 , 128 ])
159+ C = T .match_buffer (c , [128 , 128 ])
160+ for i , j , k in T .grid (128 , 128 , 128 ):
161+ with T .block ("update" ):
162+ T .block_attr ({"test_annotation" : 1 })
163+ vi , vj , vk = T .axis .remap ("SSR" , [i , j , k ])
164+ with T .init ():
165+ C [vi , vj ] = 0.0
166+ C [vi , vj ] = C [vi , vj ] + A [vi , vk ] * B [vj , vk ]
167+
168+
169+ @T .prim_func
170+ def matmul_decompose_with_annotation (a : T .handle , b : T .handle , c : T .handle ) -> None :
171+ A = T .match_buffer (a , [128 , 128 ])
172+ B = T .match_buffer (b , [128 , 128 ])
173+ C = T .match_buffer (c , [128 , 128 ])
174+
175+ for i , j in T .grid (128 , 128 ):
176+ with T .block ("init" ):
177+ T .block_attr ({"test_annotation" : 1 })
178+ vi , vj = T .axis .remap ("SS" , [i , j ])
179+ C [vi , vj ] = 0.0
180+
181+ for i , j , k in T .grid (128 , 128 , 128 ):
182+ with T .block ("update" ):
183+ T .block_attr ({"test_annotation" : 1 })
184+ vi , vj , vk = T .axis .remap ("SSR" , [i , j , k ])
185+ C [vi , vj ] = C [vi , vj ] + A [vi , vk ] * B [vj , vk ]
186+
187+
155188# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
156189
157190
@@ -201,5 +234,14 @@ def test_reduction_decompose4():
201234 verify_trace_roundtrip (s , mod = matmul )
202235
203236
237+ def test_reduction_decompose_with_annotation ():
238+ s = tir .Schedule (matmul_with_annotation , debug_mask = "all" )
239+ C = s .get_block ("update" )
240+ i , j , k = s .get_loops (C )
241+ s .decompose_reduction (C , i )
242+ tvm .ir .assert_structural_equal (matmul_decompose_with_annotation , s .mod ["main" ])
243+ verify_trace_roundtrip (s , mod = matmul_with_annotation )
244+
245+
204246if __name__ == "__main__" :
205247 sys .exit (pytest .main ([__file__ ] + sys .argv [1 :]))
0 commit comments