Skip to content

Commit fbc7cd9

Browse files
tqchentrevor-m
authored andcommitted
[TIR][REFACTOR] Migrate low-level passes in tvm.lower to the Unified IR pass manager. (apache#5364)
- Migrate BoundCheckers and Simplify - Migrate RewriteUnsafeSelect and RemoveNoOp - Migrate UnrollLoop and StorageRewrite - Migrate InjectDoubleBuffer and InjectVirtualThread - Migrate LoopPartition and Vectorize - Migrate CoProcSync, LiftAttrScope, InjectCopyIntrin We still keep ir_pass registerations for now. Need a separate PR to refactor the parts before the StorageFlatten.
1 parent 22748ba commit fbc7cd9

37 files changed

+1026
-458
lines changed

include/tvm/tir/analysis.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ struct ExprDeepEqual {
5353
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
5454
};
5555

56-
5756
/*!
5857
* \brief Find undefined vars in the statment.
5958
* \param stmt The function to be checked.

include/tvm/tir/ir_pass.h

Lines changed: 0 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -202,144 +202,13 @@ Stmt RewriteForTensorCore(Stmt stmt,
202202
*/
203203
bool VerifyCompactBuffer(Stmt stmt);
204204

205-
/*!
206-
* \brief Remove No Op from the Stmt.
207-
* \param stmt The stmt to be trasnformed
208-
* \return Transformed stmt.
209-
*/
210-
Stmt RemoveNoOp(Stmt stmt);
211-
212-
/*!
213-
* \brief unroll the constant loop marked by unroll.
214-
* This pass also automatically attach pragma unroll tag to loops which meets the standard.
215-
*
216-
* \param stmt The statment to be unrolled.
217-
* \param auto_max_step The maximum step before stop attach automatic unroll
218-
* \param auto_max_depth The maximum depth before stop attach automatic unroll
219-
* \param auto_max_extent The maximum extent of the loop we can unroll,
220-
* this is an legacy option that do not take the loop total steps into account.
221-
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
222-
* \return Transformed stmt.
223-
*/
224-
Stmt UnrollLoop(Stmt stmt,
225-
int auto_max_step,
226-
int auto_max_depth,
227-
int auto_max_extent,
228-
bool explicit_unroll);
229-
230-
/*!
231-
* \brief vectorize the constant loops
232-
* \param stmt The statement to be vectorized.
233-
* \return Transformed stmt.
234-
*/
235-
Stmt VectorizeLoop(Stmt stmt);
236-
237-
/*!
238-
* \brief convert vectorized loops into serialized loops
239-
* \param stmt The statement to skip vectorization on.
240-
* \return Transformed stmt.
241-
*/
242-
Stmt SkipVectorize(Stmt stmt);
243-
244-
/*!
245-
* \brief instruments bound checkers.
246-
* \param stmt The statement to be instrumented.
247-
* \return Instrumented stmt.
248-
*/
249-
Stmt InstrumentBoundCheckers(Stmt stmt);
250-
251-
/*!
252-
* \brief Inject virtual thread loops into stmt.
253-
* \param stmt The statement to be transformed.
254-
* \return Transformed stmt.
255-
*/
256-
Stmt InjectVirtualThread(Stmt stmt);
257-
258205
/*!
259206
* \brief Inject prefetch instructions into stmt.
260207
* \param stmt The statement to be transformed.
261208
* \return Transformed stmt.
262209
*/
263210
Stmt InjectPrefetch(Stmt stmt);
264211

265-
/*!
266-
* \brief Inject double buffer into stmt.
267-
* \param stmt The statement to be transformed.
268-
* \param split_loop Loop splitting factor.
269-
* \return Transformed stmt.
270-
*/
271-
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
272-
273-
/*!
274-
* \brief Inject copy intrinsics with optional pad.
275-
*
276-
* \param stmt The statement to be transformed.
277-
* \param pragma_key The pragma key for hint of copy.
278-
* \param fintrin The function with signature
279-
*
280-
* Stmt fintrin(Buffer src,
281-
* Buffer dst,
282-
* Array<Expr> pad_before,
283-
* Array<Expr> pad_after,
284-
* Expr pad_value)
285-
* \return Transformed stmt.
286-
*/
287-
Stmt InjectCopyIntrin(Stmt stmt,
288-
const std::string& pragma_key,
289-
const runtime::PackedFunc& fintrin);
290-
291-
/*!
292-
* \brief Rewrite storage allocation pattern.
293-
* Moves the allocation to outer most possible scope.
294-
* Trying to share space between allocations to make
295-
* a static allocation plan when possible.
296-
*
297-
* \param stmt The stmt to be transformed
298-
* \return Transformed stmt.
299-
*/
300-
Stmt StorageRewrite(Stmt stmt);
301-
302-
/*!
303-
* \brief partition loops in the stmt
304-
* \param stmt The stmt to do loop partition
305-
* \param split_const_loop flag to enable partition for const loop
306-
* \return Transformed stmt.
307-
*/
308-
Stmt LoopPartition(Stmt stmt, bool split_const_loop);
309-
310-
/*!
311-
* \brief Detect and insert sync points to co-processor.
312-
*
313-
* \param stmt The stmt to be transformed
314-
* \return Transformed stmt.
315-
*/
316-
Stmt CoProcSync(Stmt stmt);
317-
318-
/*!
319-
* \brief Lift common attrs with attr_key to outer scope.
320-
*
321-
* \param stmt The stmt to be transformed
322-
* \param attr_key The attribute key to be checked.
323-
* \return Transformed stmt.
324-
*/
325-
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
326-
327-
/*!
328-
* \brief Detect and rewrite unsafe select that contains memory access.
329-
* \param stmt The statement to be rewritten.
330-
* \return Transformed stmt.
331-
*/
332-
Stmt RewriteUnsafeSelect(Stmt stmt);
333-
334-
/*!
335-
* \brief Lower attached storage access information.
336-
* Do this pass after all storage access analysis finish.
337-
*
338-
* \param stmt The stmt to be transformed
339-
* \return Transformed stmt.
340-
*/
341-
Stmt LowerStorageAccessInfo(Stmt stmt);
342-
343212
/*!
344213
* \brief Decorate the stmt with a device scope, this is helpful for
345214
* hardware accelerator without thread blocks.
@@ -356,15 +225,6 @@ Stmt DecorateDeviceScope(Stmt stmt);
356225
*/
357226
Stmt HoistIfThenElse(Stmt stmt);
358227

359-
/*!
360-
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
361-
* \note Run this pass after StorageFlatten.
362-
* \param stmt The stmt to do datatype rewrite
363-
* \param target_bits the bit of target datatype
364-
* \return Transformed stmt.
365-
*/
366-
Stmt NarrowDataType(Stmt stmt, int target_bits);
367-
368228
/*!
369229
* \brief Rewrite the pointer content type of arguments,
370230
* as well as Alloc internal to the function to use

include/tvm/tir/transform.h

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,124 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
5858
const std::string& name,
5959
const tvm::Array<runtime::String>& required);
6060

61+
/*!
62+
* \brief Inject copy intrinsics with optional pad.
63+
*
64+
* \param pragma_key The pragma key for hint of copy.
65+
* \param fintrin The function with signature
66+
*
67+
* Stmt fintrin(Buffer src,
68+
* Buffer dst,
69+
* Array<Expr> pad_before,
70+
* Array<Expr> pad_after,
71+
* Expr pad_value)
72+
* \return The pass.
73+
*/
74+
TVM_DLL Pass InjectCopyIntrin(std::string pragma_key,
75+
runtime::PackedFunc fintrin);
76+
77+
/*!
78+
* \brief Detect and insert sync points to co-processor.
79+
*
80+
* \return The pass.
81+
*/
82+
TVM_DLL Pass CoProcSync();
83+
84+
/*!
85+
* \brief Lift common attrs with attr_key to outer scope.
86+
*
87+
* \param attr_key The attribute key to be checked.
88+
* \return The pass.
89+
*/
90+
TVM_DLL Pass LiftAttrScope(std::string attr_key);
91+
92+
/*!
93+
* \brief partition loops in the stmt.
94+
*
95+
* \param split_const_loop flag to enable partition for const loop
96+
*
97+
* \return The pass.
98+
*/
99+
TVM_DLL Pass LoopPartition(bool split_const_loop);
100+
101+
/*!
102+
* \brief Lower vectorization loops.
103+
*
104+
* \param enable_vectorize Whether vectorization is enabled.
105+
*
106+
* \return The pass.
107+
*/
108+
TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
109+
110+
/*!
111+
* \brief Inject virtual thread loops.
112+
*
113+
* \return The pass.
114+
*/
115+
TVM_DLL Pass InjectVirtualThread();
116+
117+
/*!
118+
* \brief Inject double buffer statements.
119+
*
120+
* \param split_loop_factor Loop splitting factor.
121+
* \return The pass.
122+
*/
123+
TVM_DLL Pass InjectDoubleBuffer(int split_loop_factor);
124+
125+
/*!
126+
* \brief Rewrite storage allocation pattern.
127+
* Moves the allocation to outer most possible scope.
128+
* Trying to share space between allocations to make
129+
* a static allocation plan when possible.
130+
*
131+
* \return The pass.
132+
*/
133+
TVM_DLL Pass StorageRewrite();
134+
135+
/*!
136+
* \brief unroll the constant loop marked by unroll.
137+
* This pass also automatically attach pragma unroll tag to loops which meets the standard.
138+
*
139+
* \param auto_max_step The maximum step before stop attach automatic unroll
140+
* \param auto_max_depth The maximum depth before stop attach automatic unroll
141+
* \param auto_max_extent The maximum extent of the loop we can unroll,
142+
* this is an legacy option that do not take the loop total steps into account.
143+
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
144+
* \return The pass.
145+
*/
146+
TVM_DLL Pass UnrollLoop(int auto_max_step,
147+
int auto_max_depth,
148+
int auto_max_extent,
149+
bool explicit_unroll);
150+
151+
/*!
152+
* \brief Remove No Op from the Stmt.
153+
*
154+
* \return The pass.
155+
*/
156+
TVM_DLL Pass RemoveNoOp();
157+
158+
/*!
159+
* \brief Detect and rewrite unsafe select that contains memory access.
160+
*
161+
* \return The pass.
162+
*/
163+
TVM_DLL Pass RewriteUnsafeSelect();
164+
165+
/*!
166+
* \brief Run arithmetic simplifications on the statements and expressions.
167+
*
168+
* \return The pass.
169+
*/
170+
TVM_DLL Pass Simplify();
171+
172+
/*!
173+
* \brief Instruments bound checkers.
174+
*
175+
* \return The pass.
176+
*/
177+
TVM_DLL Pass InstrumentBoundCheckers();
178+
61179
/*!
62180
* \brief Transform the high-level PrimFunc to a low-level version
63181
* that can be used as an API function.

python/tvm/driver/build_module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def lower(sch,
179179
cfg.auto_unroll_max_depth,
180180
cfg.auto_unroll_max_extent,
181181
cfg.unroll_explicit)
182+
182183
for f in lower_phase2:
183184
stmt = f(stmt)
184185

@@ -187,11 +188,14 @@ def lower(sch,
187188
stmt = ir_pass.RemoveNoOp(stmt)
188189
if not cfg.disable_select_rewriting:
189190
stmt = ir_pass.RewriteUnsafeSelect(stmt)
191+
190192
for f in lower_phase3:
191193
stmt = f(stmt)
194+
192195
# Instrument BoundCheckers
193196
if cfg.instrument_bound_checkers:
194197
stmt = ir_pass.InstrumentBoundCheckers(stmt)
198+
195199
if simple_mode:
196200
return stmt
197201

0 commit comments

Comments
 (0)