Skip to content

Commit 9572d98

Browse files
yidawangzhiics
authored andcommitted
[Fix] Fix the logic of the number of nodes checking in op fusion (#4074)
* move the number of nodes constraint in op fusion up to the dom tree level * add test case of limiting the max number of ops to be fused * uncomment other test cases
1 parent 283afac commit 9572d98

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

src/relay/pass/fuse_ops.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -623,9 +623,7 @@ class GraphPartitioner {
623623
* \param parent The parent group.
624624
*/
625625
void MergeFromTo(Group* child, Group* parent) {
626-
// refuse the fusion if too many ops are going to be fused together
627-
if (child->num_nodes + parent->num_nodes > kMaxFusedOps)
628-
return;
626+
// update the number of nodes of the parent group
629627
parent->num_nodes += child->num_nodes;
630628
child = child->FindRoot();
631629
parent = parent->FindRoot();
@@ -701,6 +699,10 @@ class GraphPartitioner {
701699
CHECK(!graph_node->extern_ref);
702700
size_t dom_parent_gindex = dom_node->parent->gnode->index;
703701

702+
// refuse the fusion if too many ops are going to be fused together
703+
if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps)
704+
continue;
705+
704706
if (phase == 2) {
705707
// Fuse injective ops into intermediate tuples, if any
706708
if (group_node->pattern > kInjective) continue;

tests/python/relay/test_pass_fuse_ops.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,39 @@ def test_split():
552552
mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
553553
mod = transform.FuseOps()(mod)
554554

555+
def test_fuse_max():
556+
"""Test the constraint of number of nodes in op fusion."""
557+
max_fused_ops = 256
558+
# n is the number of nodes to be fused, should be less than 2*max_fused_ops
559+
n = 300
560+
def before():
561+
x = relay.var("x", shape=(10, 20))
562+
y = x
563+
for i in range(n):
564+
y = relay.exp(y)
565+
return relay.Function([x], y)
566+
567+
def expected():
568+
x = relay.var("p", shape=(10, 20))
569+
y = x
570+
for i in range(max_fused_ops):
571+
y = relay.exp(y)
572+
f1 = relay.Function([x], y)
573+
x = relay.var("x", shape=(10, 20))
574+
z = relay.Call(f1, [x])
575+
xx = relay.var("pp", shape=(10, 20))
576+
yy = xx
577+
for i in range(n-max_fused_ops):
578+
yy = relay.exp(yy)
579+
f2 = relay.Function([xx], yy)
580+
zz = relay.Call(f2, [z])
581+
return relay.Function([x], zz)
582+
583+
z = before()
584+
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
585+
zz = run_opt_pass(z, transform.FuseOps())
586+
after = run_opt_pass(expected(), transform.InferType())
587+
assert relay.analysis.alpha_equal(zz, after)
555588

556589
if __name__ == "__main__":
557590
test_fuse_simple()
@@ -568,3 +601,4 @@ def test_split():
568601
test_fuse_parallel_injective()
569602
test_immutable()
570603
test_split()
604+
test_fuse_max()

0 commit comments

Comments
 (0)