Skip to content

Commit 2428ae2

Browse files
yongwwwjunrushao
authored andcommitted
[Pass] Lambda Lifting (#99)
1 parent 96467a8 commit 2428ae2

File tree

22 files changed

+987
-80
lines changed

22 files changed

+987
-80
lines changed

include/tvm/relax/analysis.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include <tvm/ir/diagnostic.h>
2828
#include <tvm/ir/module.h>
29+
#include <tvm/relax/expr.h>
2930
#include <tvm/relay/op_attr_types.h>
3031
#include <tvm/tir/function.h>
3132

@@ -53,6 +54,60 @@ TVM_DLL bool WellFormed(const IRModule& m,
5354
*/
5455
TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func);
5556

57+
/*!
58+
* \brief Get all bound variables from expression expr.
59+
*
60+
* Bound variables are all variables that are declared in the expr.
61+
* They only have meaning inside that expr, and can only be used in it.
62+
*
63+
* \param expr the expression.
64+
*
65+
* \return List of bound vars, in the PostDFS order in the expression.
66+
*/
67+
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
68+
69+
/*!
70+
* \brief Get free type parameters from expression expr.
71+
*
72+
* Free variables are variables that are not bound by a
73+
* varbinding or a function parameter in the context.
74+
*
75+
* \param expr the expression.
76+
*
77+
* \return List of free vars, in the PostDFS order in the expression.
78+
*/
79+
TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
80+
81+
/*!
82+
* \brief Get all variables from expression expr.
83+
*
84+
* \param expr the expression.
85+
*
86+
* \return List of all vars, in the PostDFS order in the expression.
87+
*/
88+
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
89+
90+
/*!
91+
* \brief Get all glabal variables for recursive call from expression expr.
92+
*
93+
* \param expr the expression.
94+
*
95+
* \return List of all global variables for recursive call.
96+
*/
97+
TVM_DLL tvm::Array<GlobalVar> RecGlobalVars(const Expr& expr);
98+
99+
/*!
100+
* \brief Get all glabal variables from expression expr.
101+
*
102+
* AllVars is a superset of BoundVars and FreeVars.
103+
* The union of BoundVars and FreeVars is Allvars.
104+
*
105+
* \param expr the expression.
106+
*
107+
* \return List of all global variables, in the PostDFS order in the expression.
108+
*/
109+
TVM_DLL tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr);
110+
56111
} // namespace relax
57112
} // namespace tvm
58113

include/tvm/relax/transform.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ TVM_DLL Pass FailTestRewrite();
8181
*/
8282
TVM_DLL Pass FMARewrite();
8383

84+
/*!
85+
* \brief Perform lambda lifting to lift functions from nested into global.
86+
*
87+
* \return The Pass.
88+
*/
89+
TVM_DLL Pass LambdaLift();
90+
8491
/*!
8592
* \brief Transform all dataflow structure to non-dataflow version.
8693
*

include/tvm/relax/utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,22 @@ class NameTable {
5757
std::unordered_map<std::string, uint32_t> alloc_map_;
5858
};
5959

60+
/*!
61+
* \brief Bind the variables to a Relax expression. This is a helper
62+
* function usually called by other pass functions to help optimizations.
63+
* If any free variables are introduced into a function, those are added
64+
* to the function parameters.
65+
* Additionally this may change the order of parameters if you map a variable
66+
* to a variable.
67+
*
68+
* \param expr The input expression.
69+
* \param binds The variable to expression map that will be used to help the
70+
* binding.
71+
*
72+
* \return The updated expression.
73+
*/
74+
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
75+
6076
} // namespace relax
6177
} // namespace tvm
6278

python/tvm/relax/expr.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,16 @@ def create_unchecked(
204204
"""Construct a relax.Function but without type checking."""
205205
return _ffi_api.Function_CreateUnchecked(params, body, ret_type, attrs, span)
206206

207+
def __call__(self, *args):
208+
"""Invoke the global function.
209+
210+
Parameters
211+
----------
212+
args: List[relax.Expr]
213+
Arguments.
214+
"""
215+
return Call(self, args, None, None)
216+
207217

208218
@tvm._ffi.register_object("relax.expr.ExternFunc")
209219
class ExternFunc(BaseFunc):

python/tvm/relax/transform/transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ def FuseFMA() -> tvm.ir.transform.Pass:
7070
return _ffi_api.FuseFMA()
7171

7272

73+
def LambdaLift():
74+
"""
75+
Lift local functions into global.
76+
77+
Returns
78+
-------
79+
ret : tvm.ir.transform.Pass
80+
"""
81+
return _ffi_api.LambdaLift()
82+
83+
7384
def ToNonDataflow() -> tvm.ir.transform.Pass:
7485
"""Transform all dataflow structure to non-dataflow version.
7586

python/tvm/script/relax/parser.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -920,8 +920,7 @@ def transform_stmt(
920920

921921
elif isinstance(stmt, ast.Function):
922922
func = self.transform_function(stmt)
923-
func_var = self.decl_var(stmt.name, None, None, stmt.span)
924-
return relax.VarBinding(func_var, func, self.to_tvm_span(stmt.span))
923+
return func
925924

926925
else:
927926
self.report_error(
@@ -1559,8 +1558,15 @@ def transform_block(self, block: ast.Block) -> relax.SeqExpr:
15591558
blocks.append(relax.BindingBlock(current_block, self.to_tvm_span(stmt.span)))
15601559
current_block = []
15611560
blocks.append(parsed_stmt)
1561+
elif isinstance(parsed_stmt, (relax.Function, tir.PrimFunc)):
1562+
func_var = self.decl_var(stmt.name, None, None, stmt.span)
1563+
current_block.append(
1564+
relax.VarBinding(func_var, parsed_stmt, self.to_tvm_span(stmt.span))
1565+
)
15621566
else:
1563-
assert isinstance(parsed_stmt, relax.Binding)
1567+
assert isinstance(
1568+
parsed_stmt, relax.Binding
1569+
), "Expected relax.Binding, but got " + str(type(parsed_stmt))
15641570
current_block.append(parsed_stmt)
15651571
if len(current_block) > 0:
15661572
blocks.append(relax.BindingBlock(current_block, self.to_tvm_span(block.stmts[-1].span)))
@@ -1573,6 +1579,19 @@ def transform_block(self, block: ast.Block) -> relax.SeqExpr:
15731579
)
15741580
ret_expr = self.transform_stmt(ret_stmt)
15751581

1582+
# only a call node in the function body
1583+
if isinstance(ret_expr, relax.Call) and len(blocks) == 0:
1584+
return ret_expr
1585+
1586+
# return a defined inner function
1587+
if (
1588+
len(blocks) > 0
1589+
and isinstance(blocks[-1].bindings[-1].value, relax.Function)
1590+
and hasattr(ret_expr, "name_hint")
1591+
and ret_expr.name_hint == blocks[-1].bindings[-1].var.name_hint
1592+
):
1593+
return blocks[-1].bindings[-1].value
1594+
15761595
return relax.SeqExpr(blocks, ret_expr, self.to_tvm_span(block.span))
15771596

15781597

src/relax/analysis/analysis.cc

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
*
22+
* \file analysis.cc
23+
*
24+
* \brief Analysis functions for Relax.
25+
*/
26+
27+
#include <tvm/relax/analysis.h>
28+
#include <tvm/relax/expr_functor.h>
29+
30+
namespace tvm {
31+
namespace relax {
32+
33+
template <typename T>
34+
struct InsertionSet {
35+
std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual> set;
36+
std::vector<T> data;
37+
void Insert(const T& t) {
38+
if (set.count(t) == 0) {
39+
set.insert(t);
40+
data.push_back(t);
41+
}
42+
}
43+
};
44+
45+
class VarVisitor : protected ExprVisitor {
46+
public:
47+
Array<Var> Free(const Expr& expr) {
48+
this->VisitExpr(expr);
49+
Array<Var> ret;
50+
for (const auto& v : vars_.data) {
51+
if (bound_vars_.set.count(v) == 0) {
52+
ret.push_back(v);
53+
}
54+
}
55+
return ret;
56+
}
57+
58+
Array<Var> Collect() {
59+
Array<Var> ret;
60+
for (const auto& v : bound_vars_.data) {
61+
ret.push_back(v);
62+
}
63+
return ret;
64+
}
65+
66+
Array<Var> Bound(const Expr& expr) {
67+
this->VisitExpr(expr);
68+
return Collect();
69+
}
70+
71+
Array<Var> All(const Expr& expr) {
72+
this->VisitExpr(expr);
73+
Array<Var> ret;
74+
for (const auto& v : vars_.data) {
75+
ret.push_back(v);
76+
}
77+
return ret;
78+
}
79+
80+
Array<GlobalVar> AllGlobalVars(const Expr& expr) {
81+
this->VisitExpr(expr);
82+
Array<GlobalVar> ret;
83+
for (const auto& v : global_vars_.data) {
84+
ret.push_back(v);
85+
}
86+
return ret;
87+
}
88+
89+
Array<GlobalVar> RecGlobalVars(const Expr& expr) {
90+
this->VisitExpr(expr);
91+
Array<GlobalVar> ret;
92+
for (const auto& v : rec_global_vars_.data) {
93+
ret.push_back(v);
94+
}
95+
return ret;
96+
}
97+
98+
void MarkBounded(const Var& v) {
99+
bound_vars_.Insert(v);
100+
vars_.Insert(v);
101+
}
102+
103+
void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
104+
105+
void VisitExpr_(const FunctionNode* op) final {
106+
for (const auto& param : op->params) {
107+
MarkBounded(param);
108+
}
109+
VisitExpr(op->body);
110+
}
111+
void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef<GlobalVar>(op)); }
112+
113+
void VisitExpr_(const CallNode* call_node) final {
114+
VisitSpan(call_node->span);
115+
VisitExpr(call_node->op);
116+
117+
for (Type ty_arg : call_node->type_args) {
118+
VisitType(ty_arg);
119+
}
120+
121+
for (Expr arg : call_node->args) {
122+
VisitExpr(arg);
123+
}
124+
125+
if (call_node->shape_) {
126+
VisitExpr(Downcast<Expr>(call_node->shape_.value()));
127+
}
128+
129+
if (const GlobalVarNode* global_var_node = call_node->op.as<GlobalVarNode>()) {
130+
rec_global_vars_.Insert(GetRef<GlobalVar>(global_var_node));
131+
}
132+
}
133+
134+
void VisitBinding_(const VarBindingNode* binding) final {
135+
MarkBounded(binding->var);
136+
VisitExpr(binding->value);
137+
VisitVarDef(binding->var);
138+
}
139+
140+
private:
141+
InsertionSet<Var> vars_;
142+
InsertionSet<Var> bound_vars_;
143+
InsertionSet<GlobalVar> global_vars_;
144+
InsertionSet<GlobalVar> rec_global_vars_;
145+
};
146+
147+
tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }
148+
149+
tvm::Array<Var> BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); }
150+
151+
tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }
152+
153+
tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); }
154+
155+
tvm::Array<GlobalVar> RecGlobalVars(const Expr& expr) { return VarVisitor().RecGlobalVars(expr); }
156+
157+
TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);
158+
159+
TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars);
160+
161+
TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);
162+
163+
TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars);
164+
165+
TVM_REGISTER_GLOBAL("relax.analysis.rec_global_vars").set_body_typed(RecGlobalVars);
166+
167+
} // namespace relax
168+
} // namespace tvm

0 commit comments

Comments
 (0)