
Commit 53fd712

MasterJH5574 authored and tqchen committed
[MERGE-FIX] Update the code to fix merge issues
Fix FuseOps to adapt #15137
Fix TIR TVMScript to adapt #15214
1 parent 23edbff · commit 53fd712

20 files changed: +315 lines, -315 lines

src/relax/transform/fuse_ops.cc

Lines changed: 1 addition & 1 deletion
@@ -998,7 +998,7 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) {
 
   // Step 2. Partition the graph by applying the fusion algorithm.
   std::vector<GraphPartitioner::Group*> groups =
-      GraphPartitioner(&arena, opt_level, max_fuse_depth).Partition(graph);
+      GraphPartitioner(&arena, opt_level, max_fuse_depth, /*max_function_args=*/0).Partition(graph);
 
   // Step 3. Transform the IRModule by fusing the operators in accordance with the graph partition
   // results.
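
For context, a minimal sketch (not part of this commit) of driving the patched partitioning logic from Python. It assumes an existing Relax IRModule `mod` built elsewhere; `relax.transform.FuseOps` is the public pass that wraps the C++ FuseOps routine shown above.

import tvm
from tvm import relax

def run_fuse_ops(mod: tvm.IRModule) -> tvm.IRModule:
    # Partition the dataflow graph and fuse operators group by group,
    # following the FuseOps code path patched in this file.
    return relax.transform.FuseOps()(mod)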

tests/python/relax/test_analysis_suggest_layout_transforms.py

Lines changed: 35 additions & 35 deletions
@@ -42,7 +42,7 @@ def apply_transformations(func, suggested_transfoms, print_transformation=False)
 
 
 def test_nested_blocks():
-    @T.prim_func
+    @T.prim_func(private=True)
     def nested_block(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         relu: T.Buffer((32, 64, 224, 224), "float32"),
@@ -67,7 +67,7 @@ def nested_block(
 
 
 def test_mismatch_transformations_and_num_params():
-    @T.prim_func
+    @T.prim_func(private=True)
     def elemwise(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         relu: T.Buffer((32, 64, 224, 224), "float32"),
@@ -91,7 +91,7 @@ def elemwise(
 
 
 def test_empty_write_transformations():
-    @T.prim_func
+    @T.prim_func(private=True)
     def elemwise(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         relu: T.Buffer((32, 64, 224, 224), "float32"),
@@ -110,7 +110,7 @@ def elemwise(
 
 
 def test_non_bijective_block_transform():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64), "float32"),
         output: T.Buffer((32, 64), "float32"),
@@ -129,7 +129,7 @@ def before(
 
 
 def test_non_affine_access():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64), "float32"),
         output: T.Buffer((32 * 64, 10), "float32"),
@@ -148,7 +148,7 @@ def before(
 
 
 def test_unsupported_write_spatial_layout():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((4, 4), "float32"),
         output: T.Buffer((16), "float32"),
@@ -167,7 +167,7 @@ def before(
 
 
 def test_unpacked_iter_used_in_read_access():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((8, 4), "float32"),
         output: T.Buffer((4, 8), "float32"),
@@ -179,7 +179,7 @@ def before(
                 T.writes(output[v_ax0, v_ax1])
                 output[v_ax0, v_ax1] = arg[v_ax1, v_ax2]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((8, 4), "float32"),
         output: T.Buffer((32), "float32"),
@@ -199,7 +199,7 @@ def expected(
 
 
 def test_invalid_index_map():
-    @T.prim_func
+    @T.prim_func(private=True)
     def elemwise(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         relu: T.Buffer((32, 64, 224, 224), "float32"),
@@ -220,7 +220,7 @@ def elemwise(
 
 
 def test_SRSR_block():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 224, 64, 224), "float32"),
         sum: T.Buffer((32, 64), "float32"),
@@ -234,7 +234,7 @@ def before(
                     sum[v_ax0, v_ax1] = T.float32(0)
                 sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_k2, v_ax1, v_k3]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 224, 16, 224, 4), "float32"),
         sum: T.Buffer((32, 16, 4), "float32"),
@@ -256,7 +256,7 @@ def expected(
 
 
 def test_op_elemwise_symbolic():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(arg: T.handle, relu: T.handle):
         N = T.int64()
         C = T.int64()
@@ -271,7 +271,7 @@ def before(arg: T.handle, relu: T.handle):
                 T.writes(Relu[v_i0, v_i1, v_i2, v_i3])
                 Relu[v_i0, v_i1, v_i2, v_i3] = T.max(Arg[v_i0, v_i1, v_i2, v_i3], T.float32(0))
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(arg: T.handle, relu: T.handle):
         N = T.int64()
         C = T.int64()
@@ -295,7 +295,7 @@ def expected(arg: T.handle, relu: T.handle):
 
 
 def test_op_elemwise():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         relu: T.Buffer((32, 64, 224, 224), "float32"),
@@ -307,7 +307,7 @@ def before(
                 T.writes(relu[v_i0, v_i1, v_i2, v_i3])
                 relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0))
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 224, 224, 64), "float32"),
         relu: T.Buffer((32, 224, 224, 64), "float32"),
@@ -327,7 +327,7 @@ def expected(
 
 
 def test_op_pool_nchw_nhwc():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         pool_max: T.Buffer((32, 64, 111, 223), "float32"),
@@ -359,7 +359,7 @@ def before(
                     ],
                 )
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 224, 224, 64), "float32"),
         pool_max: T.Buffer((32, 111, 223, 64), "float32"),
@@ -387,7 +387,7 @@ def expected(
 
 
 def test_op_pool_nchw16c_nhwc():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer(
            (32, 4, 224, 224, 16),
@@ -413,7 +413,7 @@ def before(
                     arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4],
                 )
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 224, 224, 64), "float32"),
         pool_max: T.Buffer((32, 110, 220, 64), "float32"),
@@ -440,7 +440,7 @@ def expected(
 
 
 def test_op_reduce():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         sum: T.Buffer((32, 64), "float32"),
@@ -454,7 +454,7 @@ def before(
                     sum[v_ax0, v_ax1] = T.float32(0)
                 sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_ax1, v_k2, v_k3]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 4, 224, 224, 16), "float32"),
         sum: T.Buffer((32, 4, 16), "float32"),
@@ -477,7 +477,7 @@ def expected(
 
 def test_op_upsampling():
     # relay materializes the layout if H, W or D dimensions are moved or tiled.
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         resize: T.Buffer((32, 64, 202, 246), "float32"),
@@ -518,7 +518,7 @@ def before(
                 ),
             ]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         resize: T.Buffer((32, 202, 246, 64), "float32"),
@@ -568,7 +568,7 @@ def expected(
 
 
 def test_op_strided_slice():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         T_strided_slice_with_axes: T.Buffer((32, 64, 10, 8), "float32"),
@@ -592,7 +592,7 @@ def before(
                     v_ax3 * 7 + 4,
                 ]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
         T_strided_slice_with_axes: T.Buffer((32, 10, 8, 16, 4), "float32"),
@@ -615,7 +615,7 @@ def expected(
 
 
 def test_op_binary_broadcast():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg0: T.Buffer((32, 64, 224, 224), "float32"),
         arg1: T.Buffer((64, 224, 224), "float32"),
@@ -635,7 +635,7 @@ def before(
                     arg0[v_ax0, v_ax1, v_ax2, v_ax3] + arg1[v_ax1, v_ax2, v_ax3]
                 )
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg0: T.Buffer((32, 224, 224, 16, 4), "float32"),
         arg1: T.Buffer((224, 224, 16, 4), "float32"),
@@ -658,7 +658,7 @@ def expected(
 
 
 def test_op_transpose():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         T_transpose: T.Buffer((32, 224, 224, 64), "float32"),
@@ -670,7 +670,7 @@ def before(
                 T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                 T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax3, v_ax1, v_ax2]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         T_transpose: T.Buffer((32, 224, 64, 224), "float32"),
@@ -690,7 +690,7 @@ def expected(
 
 
 def test_op_pad():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         PadInput: T.Buffer((32, 64, 230, 230), "float32"),
@@ -706,7 +706,7 @@ def before(
                     T.float32(2),
                 )
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
         PadInput: T.Buffer((32, 230, 230, 16, 4), "float32"),
@@ -730,7 +730,7 @@ def expected(
 
 
 def test_op_split():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         split0: T.Buffer((32, 32, 224, 224), "float32"),
@@ -749,7 +749,7 @@ def before(
                 T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3])
                 split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 224, 224, 64), "float32"),
         split0: T.Buffer((32, 224, 224, 32), "float32"),
@@ -778,7 +778,7 @@ def expected(
 
 @pytest.mark.skip("temp disable, due to minor arith regression")
 def test_op_split_tiling_split_dim():
-    @T.prim_func
+    @T.prim_func(private=True)
     def before(
         arg: T.Buffer((32, 64, 224, 224), "float32"),
         split0: T.Buffer((32, 32, 224, 224), "float32"),
@@ -797,7 +797,7 @@ def before(
                 T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3])
                 split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def expected(
         arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
         split0: T.Buffer((32, 224, 224, 8, 4), "float32"),
tests/python/relax/test_backend_transform_shape_lower.py

Lines changed: 1 addition & 1 deletion
@@ -189,7 +189,7 @@ def main(
 
     @tvm.script.ir_module
     class Expected:
-        @T.prim_func
+        @T.prim_func(private=True)
         def shape_func(H: T.Buffer(T.int64(4), "int64")):
             # generated compute function
             T.func_attr({"tir.is_host_func": 1})

tests/python/relax/test_blockbuilder_emit_te.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def te_func(A, offset):
 
     @I.ir_module
     class Expected:
-        @T.prim_func
+        @T.prim_func(private=True)
         def te_func(
            A: T.Buffer((T.int64(10),), "float32"),
            B: T.Buffer((T.int64(10),), "float32"),

tests/python/relax/test_frontend_nn_op.py

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ def test(self, x: Tensor):
     # fmt: off
     @I.ir_module
     class Expected:
-        @T.prim_func
+        @T.prim_func(private=True)
         def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add: T.Buffer((T.int64(10), T.int64(10)), "float32")):
             T.func_attr({"tir.noalias": T.bool(True)})
             # with T.block("root"):

tests/python/relax/test_meta_schedule_relax_integration.py

Lines changed: 3 additions & 3 deletions
@@ -54,7 +54,7 @@ def main(data: R.Tensor((1, 8, 8, 4), dtype="int32")) -> R.Tensor((1, 8, 8, 4),
 # fmt: off
 @I.ir_module
 class Module:
-    @T.prim_func
+    @T.prim_func(private=True)
     def conv2d(rxplaceholder: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), DepthwiseConv2d: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")):
         T.func_attr({"op_pattern": 4, "tir.noalias": True})
         # with T.block("root"):
@@ -76,7 +76,7 @@ def conv2d(rxplaceholder: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(
                     DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0
                 DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b, v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant_1[v_di, v_dj, v_c, T.int64(0)]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def conv2d0(rxplaceholder0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), DepthwiseConv2d0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")):
         T.func_attr({"op_pattern": 4, "tir.noalias": True})
         # with T.block("root"):
@@ -98,7 +98,7 @@ def conv2d0(rxplaceholder0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int6
                     DepthwiseConv2d0[v_b, v_i, v_j, v_c] = 0
                 DepthwiseConv2d0[v_b, v_i, v_j, v_c] = DepthwiseConv2d0[v_b, v_i, v_j, v_c] + PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant0_1[v_di, v_dj, v_c, T.int64(0)]
 
-    @T.prim_func
+    @T.prim_func(private=True)
     def fused_conv2d_add(data: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), T_add: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")):
         T.func_attr({"tir.noalias": True})
         # with T.block("root"):
