[Pass] Support Function and If in Normalize pass. (apache#268)
* Support Function and If in Normalize pass.

* Use structural equality for expr_memo_.

* Change back to pointer equality for expr_memo_; add more tests.

* Rebase.
YuchenJin authored and junrushao committed Oct 16, 2022
1 parent ceb6048 commit b46ced7
Showing 3 changed files with 141 additions and 7 deletions.
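
In short, the Normalize pass rewrites relax IR into A-normal form (ANF), binding nested call expressions to fresh intermediate variables; this commit extends it to recurse into Function bodies and If branches. As a minimal sketch of the behavior the new code enables (using only the relax Python API exercised by the tests in this diff), one can build a non-normalized function by hand and run the pass over it:

import tvm
from tvm import relax, tir

# Hand-build a function whose body is a nested call, i.e. not yet in ANF.
# (The TVMScript parser normalizes automatically, hence the manual construction.)
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
t = relax.DynTensorType(ndim=2, dtype="float32")
x = relax.Var("x", [m, n], t)
f = relax.Function(
    [x],
    relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
    ret_type=t,
    ret_shape=relax.RuntimeDepShape(),
).with_attr("global_symbol", "mul_add")
mod = tvm.IRModule.from_expr(f)

# With this commit, the pass descends into the Function body and emits
# bindings for the nested calls (gv = add(x, x); gv1 = add(x, x)), so the
# body becomes a SeqExpr that returns multiply(gv, gv1).
mod = relax.transform.Normalize()(mod)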
src/relax/transform/normalize.cc (32 additions, 0 deletions)
@@ -40,6 +40,38 @@ class NormalizeMutator : public ExprMutatorBase {
    return builder_->Normalize(ExprMutatorBase::VisitExpr(expr));
  }

  Expr VisitExpr_(const FunctionNode* op) {
    Expr body = this->VisitWithNewScope(op->body);

    if (body.same_as(op->body)) {
      return GetRef<Expr>(op);
    } else {
      return Function(op->params, body, op->ret_type, op->ret_shape, op->attrs);
    }
  }

  Expr VisitExpr_(const IfNode* op) {
    Expr guard = this->VisitExpr(op->cond);
    Expr true_b = this->VisitWithNewScope(op->true_branch);
    Expr false_b = this->VisitWithNewScope(op->false_branch);
    if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
        op->false_branch.same_as(false_b)) {
      return GetRef<Expr>(op);
    } else {
      return If(guard, true_b, false_b, op->span);
    }
  }

  Expr VisitWithNewScope(const Expr& expr) {
    builder_->BeginBindingBlock();
    Expr ret = this->VisitExpr(expr);
    BindingBlock prologue = builder_->EndBlock();
    if (!prologue->bindings.empty()) {
      ret = SeqExpr({prologue}, ret);
    }
    return ret;
  }

  Expr VisitExpr_(const SeqExprNode* op) final {
    bool all_blocks_unchanged = true;
    Array<BindingBlock> blocks;
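The heart of the change is VisitWithNewScope: it opens a fresh binding block, visits the sub-expression (builder_->Normalize then emits bindings for any nested calls into that block), and wraps the result in a SeqExpr only when bindings were actually emitted. Function bodies and both If branches go through it, so emitted bindings stay local to their own scope, and an already-normalized body comes back unchanged, which is why the pass is a no-op on ANF input (see test_normalize_no_op below). A sketch of the observable effect on If, mirroring test_normalize_if further down and assuming the same Python API the tests use:

import tvm
from tvm import relax, tir

cond = relax.Var("cond", [], type_annotation=relax.DynTensorType(0, "bool"))
x = relax.Var("x", [tir.IntImm("int64", 1)], type_annotation=relax.DynTensorType(1, "float32"))
y = relax.Var("y")

# Both branches are nested calls; each branch is visited with a new scope and
# becomes its own SeqExpr whose prologue holds the intermediates (gv, gv1).
true_branch = relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x))
false_branch = relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x))
body = relax.SeqExpr(
    [relax.BindingBlock([relax.VarBinding(y, relax.If(cond, true_branch, false_branch))])],
    y,
)
f = relax.Function(
    [cond, x],
    body,
    ret_type=relax.DynTensorType(1, "float32"),
    ret_shape=relax.RuntimeDepShape(),
).with_attr("global_symbol", "f")
after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f))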
tests/python/relax/test_parser.py (19 additions, 0 deletions)
@@ -871,5 +871,24 @@ def k(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor:
    check_shape(gv2_bind.var, ("n", "n"))


def test_class_normalize():
    @tvm.script.ir_module
    class InputModule:
        @R.function
        def mul_add(x: Tensor) -> Tensor:
            return R.multiply(R.add(x, x), R.add(x, x))

    # The parser automatically normalizes the input AST to the following ANF form
    @tvm.script.ir_module
    class OutputModule:
        @R.function
        def mul_add(x: Tensor) -> Tensor:
            gv = relax.add(x, x)
            gv1 = relax.add(x, x)
            return R.multiply(gv, gv1)

    assert_structural_equal(InputModule, OutputModule)


if __name__ == "__main__":
    pytest.main([__file__])
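
Because the parser normalizes at parse time (the property test_class_normalize above checks), running the pass over freshly parsed TVMScript should be structurally a no-op. A quick sanity check, reusing InputModule from the test above:

import tvm
from tvm import relax

# InputModule is already in ANF after parsing, so Normalize changes nothing.
after = relax.transform.Normalize()(InputModule)
tvm.ir.assert_structural_equal(after, InputModule)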
tests/python/relax/test_transform.py (90 additions, 7 deletions)
@@ -492,24 +492,107 @@ def foo(x: Tensor((d,), "float32")):
    assert cast_expr.dtype == "int64"


-def test_to_anf():
+def test_normalize_function():
    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")
    type_anno = relax.DynTensorType(ndim=2, dtype="float16")
    x = relax.Var("x", [m, n], type_anno)

    # Note: the parser automatically normalizes the IR written in TVMScript,
    # so we manually construct the function here.
    mul_add = relax.Function(
        [x],
        relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
        ret_type=type_anno,
        ret_shape=relax.RuntimeDepShape(),
    )
    mul_add = mul_add.with_attr("global_symbol", "mul_add")
    before_mod = tvm.IRModule.from_expr(mul_add)

    after_mod = relax.transform.Normalize()(before_mod)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def mul_add(x: Tensor((m, n), "float16")) -> Tensor(None, "float16", ndim=2):
            gv = R.add(x, x)
            gv1 = relax.add(x, x)
            return R.multiply(gv, gv1)

    assert_structural_equal(after_mod, Expected)


def test_normalize_if():
    cond = relax.Var("cond", [], type_annotation=relax.DynTensorType(0, "bool"))
    x = relax.Var("x", [tir.IntImm("int64", 1)], type_annotation=relax.DynTensorType(1, "float32"))
    # TODO(relax-team): add type and shape inference for IfNode
    y = relax.Var("y")

    # Note: the parser automatically normalizes the IR written in TVMScript,
    # so we manually construct the function and If here.
    f = relax.Function(
        [cond, x],
        relax.SeqExpr(
            [
                relax.BindingBlock(
                    [
                        relax.VarBinding(
                            y,
                            relax.If(
                                cond,
                                relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)),
                                relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x)),
                            ),
                        )
                    ]
                )
            ],
            y,
        ),
        ret_type=relax.DynTensorType(1, "float32"),
        ret_shape=relax.RuntimeDepShape(),
    )

    f = f.with_attr("global_symbol", "f")
    before_mod = tvm.IRModule.from_expr(f)
    after_mod = relax.transform.Normalize()(before_mod)

    @tvm.script.ir_module
-    class TestNormalizeInputModule:
+    class Expected:
        @R.function
        def f(
            cond: Tensor((), "bool"), x: Tensor((1,), "float32")
        ) -> Tensor(None, "float32", ndim=1):
            if cond:
                gv = R.add(x, x)
                gv1 = R.add(x, x)
                y = R.multiply(gv, gv1)
            else:
                gv = R.multiply(x, x)
                gv1 = R.multiply(x, x)
                y = R.add(gv, gv1)
            return y

    assert_structural_equal(after_mod, Expected)


def test_normalize_no_op():
    # the normalize pass should be a no-op for IR already in ANF
    @tvm.script.ir_module
    class ANFMod1:
        @R.function
        def f(x: Tensor(_, "float32")):
            gv = relax.add(x, x)
            gv1 = relax.add(gv, gv)
            gv2 = relax.add(gv, gv1)
            return (gv, gv2)

-    before_mod = TestNormalizeInputModule
+    before_mod = ANFMod1
    after_mod = relax.transform.Normalize()(before_mod)
    assert_structural_equal(before_mod, after_mod, map_free_vars=True)


def test_to_anf_no_op():
    @tvm.script.ir_module
-    class TestANFNoOp:
+    class ANFMod2:
        @R.function
        def foo(x: Tensor((m, n), "float32")):
            with relax.dataflow():
@@ -518,7 +601,7 @@ def foo(x: Tensor((m, n), "float32")):
                relax.output(gv0)
            return gv0

-    mod = TestANFNoOp
+    mod = ANFMod2
    mod_post = relax.transform.Normalize()(mod)

    assert_structural_equal(mod, mod_post)
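
As a closing usage note (a sketch, on the assumption that Normalize composes like any other module-level pass): it is typically run early so that downstream passes can rely on the IR being in ANF.

import tvm
from tvm import relax

# Hypothetical pipeline: normalize first, then append later relax passes.
seq = tvm.transform.Sequential([relax.transform.Normalize()])
# mod = seq(mod)  # `mod` being any relax IRModule, e.g. built as in the tests above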
