Skip to content

Commit 1598e32

Browse files
wweicicemelon
authored andcommitted
Add EtaExpand to transform API (#3406)
* Add EtaExpand to transform API * Add test case
1 parent f7d15f6 commit 1598e32

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

python/tvm/relay/transform.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,15 @@ def ToANormalForm():
406406
"""
407407
return _transform.ToANormalForm()
408408

409+
def EtaExpand():
410+
"""Add abstraction over a function
411+
412+
Returns
413+
-------
414+
ret: tvm.relay.Pass
415+
The registered pass that eta expands an expression.
416+
"""
417+
return _transform.EtaExpand()
409418

410419
def ToGraphNormalForm():
411420
"""Turn A Normal Form expression into Graph Normal Form expression

src/relay/pass/eta_expand.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,20 @@ Expr EtaExpand(const Expr& e, const Module& mod) {
6767

6868
TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand);
6969

70+
namespace transform {
71+
72+
Pass EtaExpand() {
73+
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
74+
[=](Function f, Module m, PassContext pc) {
75+
return Downcast<Function>(EtaExpand(f, m));
76+
};
77+
return CreateFunctionPass(pass_func, 1, "EtaExpand", {});
78+
}
79+
80+
TVM_REGISTER_API("relay._transform.EtaExpand")
81+
.set_body_typed(EtaExpand);
82+
83+
} // namespace transform
84+
7085
} // namespace relay
7186
} // namespace tvm

tests/python/relay/test_pass_eta_expand.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
from tvm import relay
18+
import tvm.relay.module as _module
19+
import tvm.relay.transform as _transform
1820

1921
def test_eta_expand_basic():
20-
mod = relay.Module()
2122
x = relay.var('x', 'int32')
22-
y = relay.var('y', 'int32')
2323
orig = relay.Function([x], x)
24-
got = relay.ir_pass.eta_expand(orig, mod)
24+
mod = _module.Module.from_expr(orig)
25+
seq = _transform.Sequential([_transform.EtaExpand()])
26+
with _transform.PassContext(opt_level=3):
27+
mod = seq(mod)
28+
29+
got = mod[mod.entry_func.name_hint]
30+
31+
y = relay.var('y', 'int32')
2532
expected = relay.Function([y], orig(y))
2633

2734
got = relay.ir_pass.infer_type(got, mod)

0 commit comments

Comments
 (0)