File tree Expand file tree Collapse file tree 3 files changed +23
-10
lines changed Expand file tree Collapse file tree 3 files changed +23
-10
lines changed Original file line number Diff line number Diff line change 2121namespace tvm {
2222namespace ir {
2323
24- using Halide::Internal::equal;
25- using Halide::Internal::simplify;
24+ inline bool Equal (Expr a, Expr b) {
25+ return Halide::Internal::equal (a, b);
26+ }
27+
28+ inline bool Equal (Stmt a, Stmt b) {
29+ return Halide::Internal::equal (a, b);
30+ }
31+
32+ inline Expr Simplify (Expr a) {
33+ return Halide::Internal::simplify (a);
34+ }
35+
36+ inline Stmt Simplify (Stmt a) {
37+ return Halide::Internal::simplify (a);
38+ }
2639
2740/* !
2841 * \brief Schedule s' dependent operations.
Original file line number Diff line number Diff line change @@ -17,20 +17,20 @@ TVM_REGISTER_API(_pass_Simplify)
1717.set_body([](const ArgStack& args, RetValue *ret) {
1818 CHECK (args.at (0 ).type_id == kNodeHandle );
1919 if (dynamic_cast <Expr::ContainerType*>(args.at (0 ).sptr .get ())) {
20- *ret = simplify (args.at (0 ).operator Expr ());
20+ *ret = Simplify (args.at (0 ).operator Expr ());
2121 } else {
22- *ret = simplify (args.at (0 ).operator Stmt ());
22+ *ret = Simplify (args.at (0 ).operator Stmt ());
2323 }
2424 });
2525
26- TVM_REGISTER_API (_pass_equal )
26+ TVM_REGISTER_API (_pass_Equal )
2727.set_body([](const ArgStack& args, RetValue *ret) {
2828 CHECK (args.at (0 ).type_id == kNodeHandle );
2929 CHECK (args.at (1 ).type_id == kNodeHandle );
3030 if (dynamic_cast <Expr::ContainerType*>(args.at (0 ).sptr .get ())) {
31- *ret = equal (args.at (0 ).operator Expr (), args.at (1 ).operator Expr ());
31+ *ret = Equal (args.at (0 ).operator Expr (), args.at (1 ).operator Expr ());
3232 } else {
33- *ret = equal (args.at (0 ).operator Stmt (), args.at (1 ).operator Stmt ());
33+ *ret = Equal (args.at (0 ).operator Stmt (), args.at (1 ).operator Stmt ());
3434 }
3535 });
3636
Original file line number Diff line number Diff line change 33def test_simplify ():
44 x = tvm .Var ('x' )
55 e1 = tvm .ir_pass .Simplify (x + 2 + 1 )
6- assert (tvm .ir_pass .equal (e1 , x + 3 ))
6+ assert (tvm .ir_pass .Equal (e1 , x + 3 ))
77 e2 = tvm .ir_pass .Simplify (x * 3 + 5 * x )
8- assert (tvm .ir_pass .equal (e2 , x * 8 ))
8+ assert (tvm .ir_pass .Equal (e2 , x * 8 ))
99 e3 = tvm .ir_pass .Simplify (x - x / 3 * 3 )
10- assert (tvm .ir_pass .equal (e3 , tvm .make .Mod (x , 3 )))
10+ assert (tvm .ir_pass .Equal (e3 , tvm .make .Mod (x , 3 )))
1111
1212
1313def test_verify_ssa ():
You can’t perform that action at this time.
0 commit comments