Skip to content

Commit 6cb5b88

Browse files
tqchenSiyuan Feng
andauthored
[TIR] Enhance Substitute, python bindings for Substitute/PostOrderVisit/IRTransform. (#5400)
Substitute now takes a std::function to customize more replacing behaviors. Co-authored-by: Siyuan Feng <hzfengsy@sjtu.edu.cn> Co-authored-by: Siyuan Feng <hzfengsy@sjtu.edu.cn>
1 parent 8c0f779 commit 6cb5b88

34 files changed

+419
-317
lines changed

docs/api/python/tir.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,10 @@ tvm.tir.analysis
3838
:members:
3939
:imported-members:
4040
:autosummary:
41+
42+
43+
tvm.tir.stmt_functor
44+
--------------------
45+
.. automodule:: tvm.tir.stmt_functor
46+
:members:
47+
:autosummary:

include/tvm/runtime/container.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,10 @@ struct PackedFuncValueConverter<::tvm::runtime::String> {
611611
}
612612
};
613613

614+
/*! \brief Helper to represent nullptr for optional. */
615+
struct NullOptType {
616+
};
617+
614618
/*!
615619
* \brief Optional container that to represent to a Nullable variant of T.
616620
* \tparam T The original ObjectRef.
@@ -642,6 +646,8 @@ class Optional : public ObjectRef {
642646
* \param ptr
643647
*/
644648
explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
649+
/*! \brief Nullopt handling */
650+
Optional(NullOptType) {} // NOLINT(*)
645651
// nullptr handling.
646652
// disallow implicit conversion as 0 can be implicitly converted to nullptr_t
647653
explicit Optional(std::nullptr_t) {}
@@ -751,6 +757,7 @@ struct PackedFuncValueConverter<Optional<T>> {
751757
// expose the functions to the root namespace.
752758
using runtime::String;
753759
using runtime::Optional;
760+
constexpr runtime::NullOptType NullOpt{};
754761
} // namespace tvm
755762

756763
namespace std {

include/tvm/tir/ir_pass.h

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -81,40 +81,6 @@ bool ExprUseVar(const PrimExpr& e, const std::unordered_set<const VarNode*>& vse
8181
*/
8282
TVM_DLL Stmt ConvertSSA(Stmt stmt);
8383

84-
/*!
85-
* \brief Substitute the var specified in key->var to be value.
86-
* \param stmt The source statement to be substituted
87-
* \param value_map The map of new values.
88-
* \return The converted form.
89-
*/
90-
Stmt Substitute(Stmt stmt,
91-
const std::unordered_map<const VarNode*, PrimExpr>& value_map);
92-
93-
/*!
94-
* \brief Substitute the var specified in key->var to be value.
95-
* \param expr The source expression to be substituted
96-
* \param value_map The map of new values.
97-
* \return The converted expression.
98-
*/
99-
PrimExpr Substitute(PrimExpr expr,
100-
const std::unordered_map<const VarNode*, PrimExpr>& value_map);
101-
102-
/*!
103-
* \brief Substitute the var specified in key->var to be value.
104-
* \param stmt The source statement to be substituted
105-
* \param value_map The map of new values.
106-
* \return The converted form.
107-
*/
108-
Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);
109-
110-
/*!
111-
* \brief Substitute the var specified in key->var to be value.
112-
* \param expr The source expression to be substituted
113-
* \param value_map The map of new values.
114-
* \return The converted expression.
115-
*/
116-
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);
117-
11884
/*!
11985
* \brief Verify if there is any argument bound to compact buffer.
12086
*

include/tvm/tir/stmt_functor.h

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,20 @@
2020
/*!
2121
* \file tvm/tir/stmt_functor.h
2222
*
23-
* \brief Functors for tir stmts.
23+
* \brief Functors for tir stmts
24+
* utility functions to call common functors.
2425
*/
2526
#ifndef TVM_TIR_STMT_FUNCTOR_H_
2627
#define TVM_TIR_STMT_FUNCTOR_H_
2728

2829
#include <tvm/node/functor.h>
30+
#include <tvm/node/container.h>
2931
#include <tvm/tir/expr.h>
3032
#include <tvm/tir/stmt.h>
3133
#include <tvm/tir/expr_functor.h>
3234

3335
#include <utility>
36+
#include <unordered_map>
3437

3538
namespace tvm {
3639
namespace tir {
@@ -318,33 +321,86 @@ class StmtExprMutator :
318321
};
319322

320323
/*!
321-
* \brief recursively visit the ir in post DFS order node, and transform it
324+
* \brief recursively visit the ir nodes in post DFS order, and transform it
322325
*
323-
* \param node The ir to be transformed.
326+
* \param stmt The ir to be transformed.
324327
* \param preorder The function called in before recursive mutation
325328
* If preorder returns None, then the transform will proceed to recursive call.
326329
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
327330
* won't do further recursion.
328331
* \param postorder The function called after recursive mutation.
329332
* The recursive mutation result is passed to postorder for further mutation.
330333
* \param only_enable List of runtime::String.
331-
* If it is empty, all IRNode will call preorder/postorder
332-
* If it is not empty, preorder/postorder will only be called
334+
* If it is null, all IRNode will call preorder/postorder
335+
* If it is not null, preorder/postorder will only be called
333336
* when the IRNode's type key is in the list.
334337
*/
335-
TVM_DLL Stmt IRTransform(Stmt node,
338+
TVM_DLL Stmt IRTransform(Stmt stmt,
336339
const runtime::PackedFunc& preorder,
337340
const runtime::PackedFunc& postorder,
338-
const Array<runtime::String>& only_enable = {});
341+
Optional<Array<String>> only_enable = NullOpt);
339342

340343
/*!
341-
* \brief recursively visit the ir in post DFS order node, apply fvisit
344+
* \brief Recursively visit the ir in post DFS order node, apply fvisit
342345
* Each node is guaranteed to be visited only once.
343346
* \param node The ir to be visited.
344347
* \param fvisit The visitor function to be applied.
345348
*/
346349
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
347350

351+
/*!
352+
* \brief Substitute the var specified by vmap.
353+
* \param stmt The source statement to be substituted
354+
* \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr.
355+
* \return The converted form.
356+
*/
357+
TVM_DLL Stmt Substitute(Stmt stmt,
358+
std::function<Optional<PrimExpr>(const Var& var)> vmap);
359+
360+
/*!
361+
* \brief Substitute the var specified by vmap.
362+
* \param expr The source statement to be substituted
363+
* \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr.
364+
* \return The result.
365+
*/
366+
TVM_DLL PrimExpr Substitute(PrimExpr expr,
367+
std::function<Optional<PrimExpr>(const Var& var)> vmap);
368+
369+
/*!
370+
* \brief Sugar for substitute via a given map.
371+
* \param input The input to be updated.
372+
* \param value_map The map of new values.
373+
* \return The result.
374+
* \tparam T the input type, can be PrimExpr or Stmt.
375+
*/
376+
template<typename T>
377+
inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) {
378+
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
379+
auto it = value_map.find(var);
380+
if (it != value_map.end()) return (*it).second;
381+
return Optional<PrimExpr>(nullptr);
382+
};
383+
return Substitute(std::move(input), vmap);
384+
}
385+
386+
/*!
387+
* \brief Sugar for substitute via a given map.
388+
* \param input The input to be updated.
389+
* \param value_map The map of new values.
390+
* \return The result.
391+
* \tparam T the input type, can be PrimExpr or Stmt.
392+
*/
393+
template<typename T>
394+
inline T Substitute(T input,
395+
const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
396+
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
397+
auto it = value_map.find(var.get());
398+
if (it != value_map.end()) return (*it).second;
399+
return Optional<PrimExpr>(nullptr);
400+
};
401+
return Substitute(std::move(input), vmap);
402+
}
403+
348404
} // namespace tir
349405
} // namespace tvm
350406

python/tvm/te/hybrid/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _pruned_source(func):
7272
def replace_io(body, rmap):
7373
"""Replacing tensors usage according to the dict given"""
7474
# pylint: disable=import-outside-toplevel
75-
from tvm.tir import ir_pass
75+
from tvm.tir import stmt_functor
7676

7777
def replace(op):
7878
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
@@ -84,7 +84,7 @@ def replace(op):
8484
_expr.Call.Halide, buf.op, buf.value_index)
8585
return None
8686

87-
return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
87+
return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call'])
8888

8989

9090
def _is_tvm_arg_types(args):

python/tvm/tir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@
4848
from . import ir_pass
4949
from . import transform
5050
from . import analysis
51+
from . import stmt_functor

python/tvm/tir/stmt_functor.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Statement functor utilities for IR transformations"""
18+
from . import _ffi_api
19+
20+
21+
def ir_transform(stmt, preorder, postorder, only_enable=None):
22+
"""Recursively visit and transform ir nodes in post DFS order.
23+
24+
Parameters
25+
----------
26+
stmt : Stmt
27+
The input to be transformed.
28+
29+
preorder: function
30+
The function called in before recursive mutation
31+
If preorder returns None, then the transform will proceed to recursive call.
32+
If preorder returns a not None Stmt/Expr, the transformer will simply return it and
33+
won't do further recursion.
34+
35+
postorder : function
36+
The function called after recursive mutation.
37+
38+
only_enable : Optional[List[str]]
39+
List of types that we only enable.
40+
41+
Returns
42+
-------
43+
result : Stmt
44+
The result.
45+
"""
46+
return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable)
47+
48+
49+
def post_order_visit(stmt, fvisit):
50+
"""Recursively visit the ir in post DFS order node, apply fvisit
51+
Each node is guaranteed to be visited only once.
52+
53+
Parameters
54+
----------
55+
fvisit: function
56+
The visitor function.
57+
"""
58+
return _ffi_api.PostOrderVisit(stmt, fvisit)
59+
60+
61+
def substitute(node, vmap):
62+
""" Substitute the var specified by vmap.
63+
64+
Parameters
65+
----------
66+
node: ObjectRef
67+
The input.
68+
69+
vmap : Dict[Var, PrimExpr]
70+
The variable mapping.
71+
72+
Returns
73+
-------
74+
result : Stmt
75+
The result.
76+
"""
77+
return _ffi_api.Substitute(node, vmap)

src/arith/solve_linear_equation.cc

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
#include <tvm/arith/analyzer.h>
2727
#include <tvm/arith/int_solver.h>
2828
#include <tvm/arith/util.h>
29-
#include <tvm/tir/op.h>
3029
#include <tvm/arith/pattern.h>
31-
#include <tvm/tir/ir_pass.h>
30+
31+
#include <tvm/tir/op.h>
32+
#include <tvm/tir/stmt_functor.h>
3233
#include <tvm/runtime/data_type.h>
3334

3435
namespace tvm {
@@ -130,10 +131,10 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
130131
(*S)[i][j] = new_i_j;
131132
}
132133
// We have to do the same with rhs
133-
PrimExpr ea = te::make_const((*y)[index].dtype(), a);
134-
PrimExpr eb = te::make_const((*y)[i].dtype(), b);
135-
PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g);
136-
PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g);
134+
PrimExpr ea = tir::make_const((*y)[index].dtype(), a);
135+
PrimExpr eb = tir::make_const((*y)[i].dtype(), b);
136+
PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g);
137+
PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g);
137138
PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
138139
PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
139140
(*y)[index] = new_index_rhs;
@@ -190,10 +191,10 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
190191
(*V)[i][j] = new_i_j;
191192
}
192193
// And apply reverse transformations to new_to_old.
193-
PrimExpr ea = te::make_const((*x)[j].dtype(), a);
194-
PrimExpr eb = te::make_const((*x)[index].dtype(), b);
195-
PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g);
196-
PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g);
194+
PrimExpr ea = tir::make_const((*x)[j].dtype(), a);
195+
PrimExpr eb = tir::make_const((*x)[index].dtype(), b);
196+
PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g);
197+
PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g);
197198
PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
198199
PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
199200
(*x)[index] = new_index;
@@ -369,7 +370,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
369370
IntConstraints(
370371
/*variables=*/{},
371372
/*ranges=*/{},
372-
/*relations=*/{te::make_zero(DataType::Bool())}),
373+
/*relations=*/{tir::make_zero(DataType::Bool())}),
373374
{}, {});
374375
} else if (!tir::is_const_int(new_relation, 1)) {
375376
new_relations.push_back(new_relation);
@@ -403,13 +404,13 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
403404
// The j-th variable is just a single value, don't create a tvm variable
404405
// S^{-1}_{nxm} Uy_{mxn}
405406
if (S[j][j] >= 0) {
406-
PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]);
407+
PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]);
407408
solution_for_V_inv_x.push_back(
408409
analyzer_problem.Simplify(floordiv(Uy[j], a)));
409410
} else {
410411
// This is required because some simplifiers
411412
// have problems with dividing by negative numbers
412-
PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]);
413+
PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]);
413414
solution_for_V_inv_x.push_back(
414415
analyzer_problem.Simplify(floordiv(-Uy[j], a)));
415416
}
@@ -418,9 +419,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
418419

419420
// V V^{-1} x = x
420421
for (size_t i = 0; i < num_vars; ++i) {
421-
PrimExpr e = te::make_zero(system_to_solve->variables[i].dtype());
422+
PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype());
422423
for (size_t j = 0; j < num_vars; ++j) {
423-
e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
424+
e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
424425
}
425426
e = analyzer_problem.Simplify(e);
426427
old_to_new_map.Set(system_to_solve->variables[i], e);

src/te/autodiff/ad_util.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
* \brief Utility for tensor-level auto-differentiation.
2323
*/
2424
#include <tvm/tir/expr.h>
25-
#include <tvm/tir/ir_pass.h>
25+
#include <tvm/tir/stmt_functor.h>
2626
#include <string>
2727
#include "ad_util.h"
2828

src/te/operation/hybrid_op.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include <tvm/arith/analyzer.h>
2727
#include <tvm/tir/expr.h>
2828
#include <tvm/tir/stmt_functor.h>
29-
#include <tvm/tir/ir_pass.h>
3029
#include <tvm/tir/analysis.h>
3130
#include <tvm/tir/op.h>
3231
#include <unordered_set>

0 commit comments

Comments
 (0)