From afa20869b4e0c0f3de272892a4ce531e75b01f06 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 11 Aug 2017 18:17:08 -0700 Subject: [PATCH] [PASS] More improvement of canonical (#314) --- src/arithmetic/canonical.cc | 3 +++ tests/python/unittest/test_pass_simplify.py | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 1a48779b79e3..c904b92d8ccb 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -191,6 +191,9 @@ class Canonical::Internal : public IRMutator { ret_entry_.max_level = stack_.back().max_level; stack_.pop_back(); CHECK(expr.defined()); + if (const IntImm* op = expr.as()) { + return Mutate_(op, expr); + } return expr; } // call produce to get a cache entry. diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py index c6cf79d153b4..2cc8825e37f3 100644 --- a/tests/python/unittest/test_pass_simplify.py +++ b/tests/python/unittest/test_pass_simplify.py @@ -27,6 +27,16 @@ def test_basic(): assert str(ret.value) == "(m - 1)" +def test_canonical(): + x = tvm.var("x") + z = tvm.const(3) + ret = tvm.ir_pass.CanonicalSimplify(x / (z*z) - x / (z*z)) + assert(tvm.ir_pass.Equal(ret, 0)) + + ret = tvm.ir_pass.CanonicalSimplify(x / (z+z) - x / (z+z)) + assert(tvm.ir_pass.Equal(ret, 0)) + if __name__ == "__main__": test_basic() test_simplify() + test_canonical()