@@ -42,7 +42,7 @@ def apply_transformations(func, suggested_transfoms, print_transformation=False)
4242
4343
4444def test_nested_blocks ():
45- @T .prim_func
45+ @T .prim_func ( private = True )
4646 def nested_block (
4747 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
4848 relu : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
@@ -67,7 +67,7 @@ def nested_block(
6767
6868
6969def test_mismatch_transformations_and_num_params ():
70- @T .prim_func
70+ @T .prim_func ( private = True )
7171 def elemwise (
7272 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
7373 relu : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
@@ -91,7 +91,7 @@ def elemwise(
9191
9292
9393def test_empty_write_transformations ():
94- @T .prim_func
94+ @T .prim_func ( private = True )
9595 def elemwise (
9696 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
9797 relu : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
@@ -110,7 +110,7 @@ def elemwise(
110110
111111
112112def test_non_bijective_block_transform ():
113- @T .prim_func
113+ @T .prim_func ( private = True )
114114 def before (
115115 arg : T .Buffer ((32 , 64 ), "float32" ),
116116 output : T .Buffer ((32 , 64 ), "float32" ),
@@ -129,7 +129,7 @@ def before(
129129
130130
131131def test_non_affine_access ():
132- @T .prim_func
132+ @T .prim_func ( private = True )
133133 def before (
134134 arg : T .Buffer ((32 , 64 ), "float32" ),
135135 output : T .Buffer ((32 * 64 , 10 ), "float32" ),
@@ -148,7 +148,7 @@ def before(
148148
149149
150150def test_unsupported_write_spatial_layout ():
151- @T .prim_func
151+ @T .prim_func ( private = True )
152152 def before (
153153 arg : T .Buffer ((4 , 4 ), "float32" ),
154154 output : T .Buffer ((16 ), "float32" ),
@@ -167,7 +167,7 @@ def before(
167167
168168
169169def test_unpacked_iter_used_in_read_access ():
170- @T .prim_func
170+ @T .prim_func ( private = True )
171171 def before (
172172 arg : T .Buffer ((8 , 4 ), "float32" ),
173173 output : T .Buffer ((4 , 8 ), "float32" ),
@@ -179,7 +179,7 @@ def before(
179179 T .writes (output [v_ax0 , v_ax1 ])
180180 output [v_ax0 , v_ax1 ] = arg [v_ax1 , v_ax2 ]
181181
182- @T .prim_func
182+ @T .prim_func ( private = True )
183183 def expected (
184184 arg : T .Buffer ((8 , 4 ), "float32" ),
185185 output : T .Buffer ((32 ), "float32" ),
@@ -199,7 +199,7 @@ def expected(
199199
200200
201201def test_invalid_index_map ():
202- @T .prim_func
202+ @T .prim_func ( private = True )
203203 def elemwise (
204204 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
205205 relu : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
@@ -220,7 +220,7 @@ def elemwise(
220220
221221
222222def test_SRSR_block ():
223- @T .prim_func
223+ @T .prim_func ( private = True )
224224 def before (
225225 arg : T .Buffer ((32 , 224 , 64 , 224 ), "float32" ),
226226 sum : T .Buffer ((32 , 64 ), "float32" ),
@@ -234,7 +234,7 @@ def before(
234234 sum [v_ax0 , v_ax1 ] = T .float32 (0 )
235235 sum [v_ax0 , v_ax1 ] = sum [v_ax0 , v_ax1 ] + arg [v_ax0 , v_k2 , v_ax1 , v_k3 ]
236236
237- @T .prim_func
237+ @T .prim_func ( private = True )
238238 def expected (
239239 arg : T .Buffer ((32 , 224 , 16 , 224 , 4 ), "float32" ),
240240 sum : T .Buffer ((32 , 16 , 4 ), "float32" ),
@@ -256,7 +256,7 @@ def expected(
256256
257257
258258def test_op_elemwise_symbolic ():
259- @T .prim_func
259+ @T .prim_func ( private = True )
260260 def before (arg : T .handle , relu : T .handle ):
261261 N = T .int64 ()
262262 C = T .int64 ()
@@ -271,7 +271,7 @@ def before(arg: T.handle, relu: T.handle):
271271 T .writes (Relu [v_i0 , v_i1 , v_i2 , v_i3 ])
272272 Relu [v_i0 , v_i1 , v_i2 , v_i3 ] = T .max (Arg [v_i0 , v_i1 , v_i2 , v_i3 ], T .float32 (0 ))
273273
274- @T .prim_func
274+ @T .prim_func ( private = True )
275275 def expected (arg : T .handle , relu : T .handle ):
276276 N = T .int64 ()
277277 C = T .int64 ()
@@ -295,7 +295,7 @@ def expected(arg: T.handle, relu: T.handle):
295295
296296
297297def test_op_elemwise ():
298- @T .prim_func
298+ @T .prim_func ( private = True )
299299 def before (
300300 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
301301 relu : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
@@ -307,7 +307,7 @@ def before(
307307 T .writes (relu [v_i0 , v_i1 , v_i2 , v_i3 ])
308308 relu [v_i0 , v_i1 , v_i2 , v_i3 ] = T .max (arg [v_i0 , v_i1 , v_i2 , v_i3 ], T .float32 (0 ))
309309
310- @T .prim_func
310+ @T .prim_func ( private = True )
311311 def expected (
312312 arg : T .Buffer ((32 , 224 , 224 , 64 ), "float32" ),
313313 relu : T .Buffer ((32 , 224 , 224 , 64 ), "float32" ),
@@ -327,7 +327,7 @@ def expected(
327327
328328
329329def test_op_pool_nchw_nhwc ():
330- @T .prim_func
330+ @T .prim_func ( private = True )
331331 def before (
332332 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
333333 pool_max : T .Buffer ((32 , 64 , 111 , 223 ), "float32" ),
@@ -359,7 +359,7 @@ def before(
359359 ],
360360 )
361361
362- @T .prim_func
362+ @T .prim_func ( private = True )
363363 def expected (
364364 arg : T .Buffer ((32 , 224 , 224 , 64 ), "float32" ),
365365 pool_max : T .Buffer ((32 , 111 , 223 , 64 ), "float32" ),
@@ -387,7 +387,7 @@ def expected(
387387
388388
389389def test_op_pool_nchw16c_nhwc ():
390- @T .prim_func
390+ @T .prim_func ( private = True )
391391 def before (
392392 arg : T .Buffer (
393393 (32 , 4 , 224 , 224 , 16 ),
@@ -413,7 +413,7 @@ def before(
413413 arg [v_ax0 , v_ax1 , v_ax2 * 2 + v_rv0 , v_ax3 + v_rv1 , v_ax4 ],
414414 )
415415
416- @T .prim_func
416+ @T .prim_func ( private = True )
417417 def expected (
418418 arg : T .Buffer ((32 , 224 , 224 , 64 ), "float32" ),
419419 pool_max : T .Buffer ((32 , 110 , 220 , 64 ), "float32" ),
@@ -440,7 +440,7 @@ def expected(
440440
441441
442442def test_op_reduce ():
443- @T .prim_func
443+ @T .prim_func ( private = True )
444444 def before (
445445 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
446446 sum : T .Buffer ((32 , 64 ), "float32" ),
@@ -454,7 +454,7 @@ def before(
454454 sum [v_ax0 , v_ax1 ] = T .float32 (0 )
455455 sum [v_ax0 , v_ax1 ] = sum [v_ax0 , v_ax1 ] + arg [v_ax0 , v_ax1 , v_k2 , v_k3 ]
456456
457- @T .prim_func
457+ @T .prim_func ( private = True )
458458 def expected (
459459 arg : T .Buffer ((32 , 4 , 224 , 224 , 16 ), "float32" ),
460460 sum : T .Buffer ((32 , 4 , 16 ), "float32" ),
@@ -477,7 +477,7 @@ def expected(
477477
478478def test_op_upsampling ():
479479 # relay materializes the layout if H, W or D dimensions are moved or tiled.
480- @T .prim_func
480+ @T .prim_func ( private = True )
481481 def before (
482482 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
483483 resize : T .Buffer ((32 , 64 , 202 , 246 ), "float32" ),
@@ -518,7 +518,7 @@ def before(
518518 ),
519519 ]
520520
521- @T .prim_func
521+ @T .prim_func ( private = True )
522522 def expected (
523523 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
524524 resize : T .Buffer ((32 , 202 , 246 , 64 ), "float32" ),
@@ -568,7 +568,7 @@ def expected(
568568
569569
570570def test_op_strided_slice ():
571- @T .prim_func
571+ @T .prim_func ( private = True )
572572 def before (
573573 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
574574 T_strided_slice_with_axes : T .Buffer ((32 , 64 , 10 , 8 ), "float32" ),
@@ -592,7 +592,7 @@ def before(
592592 v_ax3 * 7 + 4 ,
593593 ]
594594
595- @T .prim_func
595+ @T .prim_func ( private = True )
596596 def expected (
597597 arg : T .Buffer ((32 , 224 , 224 , 16 , 4 ), "float32" ),
598598 T_strided_slice_with_axes : T .Buffer ((32 , 10 , 8 , 16 , 4 ), "float32" ),
@@ -615,7 +615,7 @@ def expected(
615615
616616
617617def test_op_binary_broadcast ():
618- @T .prim_func
618+ @T .prim_func ( private = True )
619619 def before (
620620 arg0 : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
621621 arg1 : T .Buffer ((64 , 224 , 224 ), "float32" ),
@@ -635,7 +635,7 @@ def before(
635635 arg0 [v_ax0 , v_ax1 , v_ax2 , v_ax3 ] + arg1 [v_ax1 , v_ax2 , v_ax3 ]
636636 )
637637
638- @T .prim_func
638+ @T .prim_func ( private = True )
639639 def expected (
640640 arg0 : T .Buffer ((32 , 224 , 224 , 16 , 4 ), "float32" ),
641641 arg1 : T .Buffer ((224 , 224 , 16 , 4 ), "float32" ),
@@ -658,7 +658,7 @@ def expected(
658658
659659
660660def test_op_transpose ():
661- @T .prim_func
661+ @T .prim_func ( private = True )
662662 def before (
663663 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
664664 T_transpose : T .Buffer ((32 , 224 , 224 , 64 ), "float32" ),
@@ -670,7 +670,7 @@ def before(
670670 T .writes (T_transpose [v_ax0 , v_ax1 , v_ax2 , v_ax3 ])
671671 T_transpose [v_ax0 , v_ax1 , v_ax2 , v_ax3 ] = arg [v_ax0 , v_ax3 , v_ax1 , v_ax2 ]
672672
673- @T .prim_func
673+ @T .prim_func ( private = True )
674674 def expected (
675675 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
676676 T_transpose : T .Buffer ((32 , 224 , 64 , 224 ), "float32" ),
@@ -690,7 +690,7 @@ def expected(
690690
691691
692692def test_op_pad ():
693- @T .prim_func
693+ @T .prim_func ( private = True )
694694 def before (
695695 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
696696 PadInput : T .Buffer ((32 , 64 , 230 , 230 ), "float32" ),
@@ -706,7 +706,7 @@ def before(
706706 T .float32 (2 ),
707707 )
708708
709- @T .prim_func
709+ @T .prim_func ( private = True )
710710 def expected (
711711 arg : T .Buffer ((32 , 224 , 224 , 16 , 4 ), "float32" ),
712712 PadInput : T .Buffer ((32 , 230 , 230 , 16 , 4 ), "float32" ),
@@ -730,7 +730,7 @@ def expected(
730730
731731
732732def test_op_split ():
733- @T .prim_func
733+ @T .prim_func ( private = True )
734734 def before (
735735 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
736736 split0 : T .Buffer ((32 , 32 , 224 , 224 ), "float32" ),
@@ -749,7 +749,7 @@ def before(
749749 T .writes (split1 [v_ax0 , v_ax1 , v_ax2 , v_ax3 ])
750750 split1 [v_ax0 , v_ax1 , v_ax2 , v_ax3 ] = arg [v_ax0 , v_ax1 + 32 , v_ax2 , v_ax3 ]
751751
752- @T .prim_func
752+ @T .prim_func ( private = True )
753753 def expected (
754754 arg : T .Buffer ((32 , 224 , 224 , 64 ), "float32" ),
755755 split0 : T .Buffer ((32 , 224 , 224 , 32 ), "float32" ),
@@ -778,7 +778,7 @@ def expected(
778778
779779@pytest .mark .skip ("temp disable, due to minor arith regression" )
780780def test_op_split_tiling_split_dim ():
781- @T .prim_func
781+ @T .prim_func ( private = True )
782782 def before (
783783 arg : T .Buffer ((32 , 64 , 224 , 224 ), "float32" ),
784784 split0 : T .Buffer ((32 , 32 , 224 , 224 ), "float32" ),
@@ -797,7 +797,7 @@ def before(
797797 T .writes (split1 [v_ax0 , v_ax1 , v_ax2 , v_ax3 ])
798798 split1 [v_ax0 , v_ax1 , v_ax2 , v_ax3 ] = arg [v_ax0 , v_ax1 + 32 , v_ax2 , v_ax3 ]
799799
800- @T .prim_func
800+ @T .prim_func ( private = True )
801801 def expected (
802802 arg : T .Buffer ((32 , 224 , 224 , 16 , 4 ), "float32" ),
803803 split0 : T .Buffer ((32 , 224 , 224 , 8 , 4 ), "float32" ),
0 commit comments