Skip to content

Commit 3914946

Browse files
committed
[PASS] Export simplify and equal to python
1 parent 0992873 commit 3914946

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

include/tvm/ir_pass.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#ifndef TVM_IR_PASS_H_
1010
#define TVM_IR_PASS_H_
1111

12+
#include <ir/IREquality.h>
13+
#include <pass/Simplify.h>
1214
#include <tvm/ir_functor.h>
1315
#include <unordered_map>
1416
#include <vector>
@@ -19,6 +21,8 @@
1921
namespace tvm {
2022
namespace ir {
2123

24+
using Halide::Internal::equal;
25+
using Halide::Internal::simplify;
2226

2327
/*!
2428
* \brief Schedule s' dependent operations.

src/c_api/c_api_pass.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,27 @@ namespace ir {
1313
using ArgStack = const std::vector<APIVariantValue>;
1414
using RetValue = APIVariantValue;
1515

16+
TVM_REGISTER_API(_pass_Simplify)
17+
.set_body([](const ArgStack& args, RetValue *ret) {
18+
CHECK(args.at(0).type_id == kNodeHandle);
19+
if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
20+
*ret = simplify(args.at(0).operator Expr());
21+
} else {
22+
*ret = simplify(args.at(0).operator Stmt());
23+
}
24+
});
25+
26+
TVM_REGISTER_API(_pass_equal)
27+
.set_body([](const ArgStack& args, RetValue *ret) {
28+
CHECK(args.at(0).type_id == kNodeHandle);
29+
CHECK(args.at(1).type_id == kNodeHandle);
30+
if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
31+
*ret = equal(args.at(0).operator Expr(), args.at(1).operator Expr());
32+
} else {
33+
*ret = equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
34+
}
35+
});
36+
1637
// make from two arguments
1738
#define REGISTER_PASS1(PassName) \
1839
TVM_REGISTER_API(_pass_## PassName) \

tests/python/test_pass_basic.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
import tvm
22

3+
def test_simplify():
4+
x = tvm.Var('x')
5+
e1 = tvm.ir_pass.Simplify(x + 2 + 1)
6+
assert(tvm.ir_pass.equal(e1, x + 3))
7+
e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x)
8+
assert(tvm.ir_pass.equal(e2, x * 8))
9+
e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
10+
assert(tvm.ir_pass.equal(e3, tvm.make.Mod(x, 3)))
11+
12+
313
def test_verify_ssa():
414
x = tvm.Var('x')
515
y = tvm.Var()

0 commit comments

Comments
 (0)