File tree Expand file tree Collapse file tree 3 files changed +35
-0
lines changed Expand file tree Collapse file tree 3 files changed +35
-0
lines changed Original file line number Diff line number Diff line change 99#ifndef TVM_IR_PASS_H_
1010#define TVM_IR_PASS_H_
1111
12+ #include < ir/IREquality.h>
13+ #include < pass/Simplify.h>
1214#include < tvm/ir_functor.h>
1315#include < unordered_map>
1416#include < vector>
1921namespace tvm {
2022namespace ir {
2123
24+ using Halide::Internal::equal;
25+ using Halide::Internal::simplify;
2226
2327/* !
2428 * \brief Schedule s' dependent operations.
Original file line number Diff line number Diff line change @@ -13,6 +13,27 @@ namespace ir {
1313using ArgStack = const std::vector<APIVariantValue>;
1414using RetValue = APIVariantValue;
1515
16+ TVM_REGISTER_API (_pass_Simplify)
17+ .set_body([](const ArgStack& args, RetValue *ret) {
18+ CHECK (args.at (0 ).type_id == kNodeHandle );
19+ if (dynamic_cast <Expr::ContainerType*>(args.at (0 ).sptr .get ())) {
20+ *ret = simplify (args.at (0 ).operator Expr ());
21+ } else {
22+ *ret = simplify (args.at (0 ).operator Stmt ());
23+ }
24+ });
25+
26+ TVM_REGISTER_API (_pass_equal)
27+ .set_body([](const ArgStack& args, RetValue *ret) {
28+ CHECK (args.at (0 ).type_id == kNodeHandle );
29+ CHECK (args.at (1 ).type_id == kNodeHandle );
30+ if (dynamic_cast <Expr::ContainerType*>(args.at (0 ).sptr .get ())) {
31+ *ret = equal (args.at (0 ).operator Expr (), args.at (1 ).operator Expr ());
32+ } else {
33+ *ret = equal (args.at (0 ).operator Stmt (), args.at (1 ).operator Stmt ());
34+ }
35+ });
36+
1637// make from two arguments
1738#define REGISTER_PASS1 (PassName ) \
1839 TVM_REGISTER_API (_pass_## PassName) \
Original file line number Diff line number Diff line change 11import tvm
22
3+ def test_simplify ():
4+ x = tvm .Var ('x' )
5+ e1 = tvm .ir_pass .Simplify (x + 2 + 1 )
6+ assert (tvm .ir_pass .equal (e1 , x + 3 ))
7+ e2 = tvm .ir_pass .Simplify (x * 3 + 5 * x )
8+ assert (tvm .ir_pass .equal (e2 , x * 8 ))
9+ e3 = tvm .ir_pass .Simplify (x - x / 3 * 3 )
10+ assert (tvm .ir_pass .equal (e3 , tvm .make .Mod (x , 3 )))
11+
12+
313def test_verify_ssa ():
414 x = tvm .Var ('x' )
515 y = tvm .Var ()
You can’t perform that action at this time.
0 commit comments