@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
123123 return tvm .IRModule ({name : func })
124124
125125
126- def _wrap_as_prim_func_pass (flist , name ):
127- """Wrap flist as a function pass.
128-
129- This is an temporary adapter before we fully
130- migrate to the new pass manager.
131- """
132- def _transform (func , * _ ):
133- stmt = func .body
134- for f in flist :
135- stmt = f (stmt )
136- # create a new function with updated body.
137- return tvm .tir .PrimFunc (func .params ,
138- stmt ,
139- func .ret_type ,
140- func .buffer_map ,
141- func .attrs )
142- return tvm .tir .transform .prim_func_pass (_transform , opt_level = 0 , name = name )
143-
144-
145126def lower (sch ,
146127 args ,
147128 name = "main" ,
@@ -190,15 +171,15 @@ def lower(sch,
190171 else :
191172 mod = sch
192173
174+ pass_list = lower_phase0
193175 # Phase 1
194- pass_list = [
195- _wrap_as_prim_func_pass (lower_phase0 , "Custom-Phase0" ),
176+ pass_list += [
196177 tvm .tir .transform .InjectPrefetch (),
197178 tvm .tir .transform .StorageFlatten (64 , cfg .instrument_bound_checkers ),
198179 tvm .tir .transform .NarrowDataType (32 ),
199180 tvm .tir .transform .Simplify (),
200- _wrap_as_prim_func_pass (lower_phase1 , "Custom-Phase1" ),
201181 ]
182+ pass_list += lower_phase1
202183
203184 # Phase 2
204185 if not simple_mode :
@@ -214,8 +195,8 @@ def lower(sch,
214195 cfg .auto_unroll_max_depth ,
215196 cfg .auto_unroll_max_extent ,
216197 cfg .unroll_explicit ),
217- _wrap_as_prim_func_pass (lower_phase2 , "Custom-Phase2" ),
218198 ]
199+ pass_list += lower_phase2
219200
220201 # Phase 3
221202 pass_list += [
@@ -225,7 +206,7 @@ def lower(sch,
225206
226207 if not cfg .disable_select_rewriting :
227208 pass_list += [tvm .tir .transform .RewriteUnsafeSelect ()]
228- pass_list += [ _wrap_as_prim_func_pass ( lower_phase3 , "Custom-Phase3" )]
209+ pass_list += lower_phase3
229210
230211 # Instrument BoundCheckers
231212 if cfg .instrument_bound_checkers :
0 commit comments