Skip to content

Commit 4627a68

Browse files
committed
[PTYTHON] Migrate VTA TIR passes to the new pass manager.
1 parent 72f2aea commit 4627a68

File tree

13 files changed

+1048
-1053
lines changed

13 files changed

+1048
-1053
lines changed

include/tvm/target/target.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/support/with.h>
2828
#include <tvm/node/container.h>
2929
#include <tvm/ir/expr.h>
30+
#include <tvm/ir/transform.h>
3031

3132
#include <string>
3233
#include <vector>
@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
225226
/*! \brief Whether to partition const loop */
226227
bool partition_const_loop = false;
227228

228-
/*! \brief Whether to dump the IR of each pass (only when building from python) */
229-
std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
229+
/*! \brief List of passes to be injected into the low-level pipeline. */
230+
std::vector<std::pair<int, transform::Pass>> add_lower_pass;
230231

231232
/*! \brief Whether to dump the IR of each pass (only when building from python) */
232233
bool dump_pass_ir = false;

python/tvm/autotvm/measure/measure_methods.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
615615
"""Verify the validity of a gpu kernel.
616616
This pass will check memory usage and number of threads per block.
617617
"""
618-
def verify_pass(stmt):
619-
valid = ir_pass.VerifyGPUCode(stmt, kwargs)
618+
def verify_pass(f, *_):
619+
valid = ir_pass.VerifyGPUCode(f.body, kwargs)
620620
if not valid:
621621
raise InstantiationError("Skipped because of invalid gpu kernel")
622-
return stmt
623-
return verify_pass
622+
return f
623+
return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)

python/tvm/driver/build_module.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,15 @@ def lower(sch,
190190
else:
191191
mod = sch
192192

193+
pass_list = lower_phase0
193194
# Phase 1
194-
pass_list = [
195-
_wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
195+
pass_list += [
196196
tvm.tir.transform.InjectPrefetch(),
197197
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
198198
tvm.tir.transform.NarrowDataType(32),
199199
tvm.tir.transform.Simplify(),
200-
_wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
201200
]
201+
pass_list += lower_phase1
202202

203203
# Phase 2
204204
if not simple_mode:
@@ -214,8 +214,8 @@ def lower(sch,
214214
cfg.auto_unroll_max_depth,
215215
cfg.auto_unroll_max_extent,
216216
cfg.unroll_explicit),
217-
_wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
218217
]
218+
pass_list += lower_phase2
219219

220220
# Phase 3
221221
pass_list += [
@@ -225,7 +225,7 @@ def lower(sch,
225225

226226
if not cfg.disable_select_rewriting:
227227
pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
228-
pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")]
228+
pass_list += lower_phase3
229229

230230
# Instrument BoundCheckers
231231
if cfg.instrument_bound_checkers:

python/tvm/tir/function.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,19 @@ def __init__(self,
6767

6868
self.__init_handle_by_constructor__(
6969
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
70+
71+
def with_body(self, new_body):
72+
"""Create a new PrimFunc with the same set signatures but a new body.
73+
74+
Parameters
75+
----------
76+
new_body : Stmt
77+
The new body.
78+
79+
Returns
80+
-------
81+
new_func : PrimFunc
82+
The created new function.
83+
"""
84+
return PrimFunc(
85+
self.params, new_body, self.ret_type, self.buffer_map, self.attrs)

src/target/target.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
434434
TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
435435
.set_body([](TVMArgs args, TVMRetValue* ret) {
436436
BuildConfig cfg = args[0];
437-
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
437+
std::vector<std::pair<int, transform::Pass>> add_lower_pass;
438438
CHECK_EQ(args.size() % 2, 1);
439439
for (int i = 1; i < args.size(); i += 2) {
440440
add_lower_pass.push_back(std::make_pair(
441441
args[i].operator int(),
442-
args[i + 1].operator tvm::runtime::PackedFunc()));
442+
args[i + 1].operator transform::Pass()));
443443
}
444444
cfg->add_lower_pass = add_lower_pass;
445445
});

tests/python/relay/test_pass_fold_constant.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,13 @@ def expected():
5151
z = relay.add(y, relay.const(c_data))
5252
return relay.Function([x], z)
5353

54-
def fail(x):
55-
raise RuntimeError()
54+
def FailPass():
55+
def _transform(m, *args):
56+
raise RuntimeError()
57+
return tvm.transform.module_pass(_transform, opt_level=0)
5658

5759
# the fold constant should work on any context.
58-
with tvm.target.build_config(add_lower_pass=[(0, fail)]):
60+
with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
5961
with tvm.target.create("cuda"):
6062
zz = run_opt_pass(before(), transform.FoldConstant())
6163
zexpected = run_opt_pass(expected(), transform.InferType())

tests/python/unittest/test_target_codegen_cuda.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def test_cuda_shuffle():
182182
sch[c].bind(xo, thrx)
183183
sch[c].vectorize(xi)
184184

185-
def my_vectorize(stmt):
185+
def MyVectorize():
186186
def vectorizer(op):
187187
if op.for_type == tvm.tir.For.Vectorized:
188188
four = tvm.tir.const(4, 'int32')
@@ -198,9 +198,13 @@ def vectorizer(op):
198198
new_b = tvm.tir.Shuffle(bs, ids)
199199
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
200200
return None
201-
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
202201

203-
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
202+
def _transform(f, *_):
203+
return f.with_body(
204+
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
205+
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
206+
207+
with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
204208
module = tvm.build(sch, [a, b, c], target='cuda')
205209
a_ = np.array(list(range(64)), dtype='int32')
206210
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')

tests/python/unittest/test_target_codegen_llvm.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -671,8 +671,7 @@ def test_llvm_shuffle():
671671
c = te.compute((8, ), lambda x: a[x] + b[7-x])
672672
sch = te.create_schedule(c.op)
673673

674-
def my_vectorize(stmt):
675-
674+
def my_vectorize():
676675
def vectorizer(op):
677676
store = op.body
678677
idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8)
@@ -684,9 +683,13 @@ def vectorizer(op):
684683
value = new_a + new_b
685684
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
686685

687-
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
686+
def _transform(f, *_):
687+
return f.with_body(
688+
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
689+
690+
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
688691

689-
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
692+
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
690693
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
691694
module = tvm.build(sch, [a, b, c])
692695
a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))

tests/python/unittest/test_tir_pass_verify_gpu_code.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
from tvm import te
2020

2121
def get_verify_pass(valid, **kwargs):
22-
def verify_pass(stmt):
23-
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs)
24-
return stmt
25-
return verify_pass
22+
def _fverify(f, *_):
23+
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs)
24+
return f
25+
return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)
2626

2727
def test_shared_memory():
2828
def check_shared_memory(dtype):

tutorials/dev/low_level_custom_pass.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,20 @@ def vectorize8(op):
117117
return body
118118
return None
119119

120-
def vectorize(stmt):
120+
@tvm.tir.transform.prim_func_pass(opt_level=0)
121+
def vectorize(f, mod, ctx):
121122
global loops
122123

123-
tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8)
124+
tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)
124125

125126
if not loops:
126-
return stmt
127+
return sf
127128

128129
# The last list arugment indicates what kinds of nodes will be transformed.
129130
# Thus, in this case only `For` nodes will call `vectorize8`
130-
stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For'])
131+
return f.with_body(
132+
tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))
131133

132-
return stmt
133134

134135
#####################################################################
135136
# Glue to Lowering

0 commit comments

Comments
 (0)