diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 5dcc003d395f..8beb2b6b5a58 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -40,6 +40,38 @@ class NormalizeMutator : public ExprMutatorBase { return builder_->Normalize(ExprMutatorBase::VisitExpr(expr)); } + Expr VisitExpr_(const FunctionNode* op) { + Expr body = this->VisitWithNewScope(op->body); + + if (body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, body, op->ret_type, op->ret_shape, op->attrs); + } + } + + Expr VisitExpr_(const IfNode* op) { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitWithNewScope(op->true_branch); + Expr false_b = this->VisitWithNewScope(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } + } + + Expr VisitWithNewScope(const Expr& expr) { + builder_->BeginBindingBlock(); + Expr ret = this->VisitExpr(expr); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + ret = SeqExpr({prologue}, ret); + } + return ret; + } + Expr VisitExpr_(const SeqExprNode* op) final { bool all_blocks_unchanged = true; Array blocks; diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index aec36836b1eb..b7f196d3997d 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -871,5 +871,24 @@ def k(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor: check_shape(gv2_bind.var, ("n", "n")) +def test_class_normalize(): + @tvm.script.ir_module + class InputModule: + @R.function + def mul_add(x: Tensor) -> Tensor: + return R.multiply(R.add(x, x), R.add(x, x)) + + # The parser automatically normalizes the input AST to the following ANF form + @tvm.script.ir_module + class OutputModule: + @R.function + def mul_add(x: Tensor) -> Tensor: + gv = relax.add(x, x) + gv1 = relax.add(x, x) + return R.multiply(gv, gv1) + + assert_structural_equal(InputModule, OutputModule) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 7faba3ad678e..0a3272bfe883 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -492,9 +492,94 @@ def foo(x: Tensor((d,), "float32")): assert cast_expr.dtype == "int64" -def test_to_anf(): +def test_normalize_function(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + type_anno = relax.DynTensorType(ndim=2, dtype="float16") + x = relax.Var("x", [m, n], type_anno) + + # Note: the parser automatically normalize the IR written in TVMScript, + # so we manually construct the function here. + mul_add = relax.Function( + [x], + relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)), + ret_type=type_anno, + ret_shape=relax.RuntimeDepShape(), + ) + mul_add = mul_add.with_attr("global_symbol", "mul_add") + before_mod = tvm.IRModule.from_expr(mul_add) + + after_mod = relax.transform.Normalize()(before_mod) + + @tvm.script.ir_module + class Expected: + @R.function + def mul_add(x: Tensor((m, n), "float16")) -> Tensor(None, "float16", ndim=2): + gv = R.add(x, x) + gv1 = relax.add(x, x) + return R.multiply(gv, gv1) + + assert_structural_equal(after_mod, Expected) + + +def test_normalize_if(): + cond = relax.Var("cond", [], type_annotation=relax.DynTensorType(0, "bool")) + x = relax.Var("x", [tir.IntImm("int64", 1)], type_annotation=relax.DynTensorType(1, "float32")) + # TODO(relax-team): add type and shape inference for IfNode + y = relax.Var("y") + + # Note: the parser automatically normalize the IR written in TVMScript, + # so we manually construct the function and If here. + f = relax.Function( + [cond, x], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + y, + relax.If( + cond, + relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)), + relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x)), + ), + ) + ] + ) + ], + y, + ), + ret_type=relax.DynTensorType(1, "float32"), + ret_shape=relax.RuntimeDepShape(), + ) + + f = f.with_attr("global_symbol", "f") + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + @tvm.script.ir_module - class TestNormalizeInputModule: + class Expected: + @R.function + def f( + cond: Tensor((), "bool"), x: Tensor((1,), "float32") + ) -> Tensor(None, "float32", ndim=1): + if cond: + gv = R.add(x, x) + gv1 = R.add(x, x) + y = R.multiply(gv, gv1) + else: + gv = R.multiply(x, x) + gv1 = R.multiply(x, x) + y = R.add(gv, gv1) + return y + + assert_structural_equal(after_mod, Expected) + + +def test_normalize_no_op(): + # the normalize pass should be no-op for IR in ANF + @tvm.script.ir_module + class ANFMod1: @R.function def f(x: Tensor(_, "float32")): gv = relax.add(x, x) @@ -502,14 +587,12 @@ def f(x: Tensor(_, "float32")): gv2 = relax.add(gv, gv1) return (gv, gv2) - before_mod = TestNormalizeInputModule + before_mod = ANFMod1 after_mod = relax.transform.Normalize()(before_mod) assert_structural_equal(before_mod, after_mod, map_free_vars=True) - -def test_to_anf_no_op(): @tvm.script.ir_module - class TestANFNoOp: + class ANFMod2: @R.function def foo(x: Tensor((m, n), "float32")): with relax.dataflow(): @@ -518,7 +601,7 @@ def foo(x: Tensor((m, n), "float32")): relax.output(gv0) return gv0 - mod = TestANFNoOp + mod = ANFMod2 mod_post = relax.transform.Normalize()(mod) assert_structural_equal(mod, mod_post)