Skip to content

Use GraphBuilder in test fusion ops. #11078

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 162 additions & 133 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
import torch
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.fuse_ops import (
FuseCascadedTransposeOrPermuteOps,
FuseCascadedViewOps,
FuseFullThenReshapePass,
FuseMMWithAdd,
FuseMulScalarIntoDequantPass,
FuseMulTensorIntoDequantPass,
FuseQuantDequantToRequantizePass,
Expand All @@ -39,113 +42,133 @@ def check_op_counts(


class TestFusionPasses(TestFusionPassesBase):
def test_addmm_fusion(self):
class AddmmFeasible1(torch.nn.Module):
def forward(self, x, y, z):
t1 = torch.mm(x, y)
return torch.add(t1, z)

x = torch.randn(3, 5)
y = torch.randn(5, 6)
z = torch.randn(6)

graph_module = (
compiler.export_to_cadence(AddmmFeasible1(), (x, y, z))
.exported_program()
.graph_module
def test_fuse_mm_with_add(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
mm = builder.call_operator(
op=exir_ops.edge.aten.mm.default,
args=(x, y),
)
output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
builder.output([output])
original_graph = builder.get_graph_module()
converted_graph = FuseMMWithAdd()(original_graph).graph_module
converted_graph.graph.eliminate_dead_code()
self.assertEqual(
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
)
graph_module.graph.eliminate_dead_code()

# Assert that mm and add were fused to addmm
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)

class AddmmFeasible2(torch.nn.Module):
def forward(self, x, y, z):
t1 = y.view((8, 6))
t2 = torch.mm(x, t1)
t3 = t2.view((2, 2, 6))
return torch.add(t3, z)

x = torch.randn(4, 8)
y = torch.randn(2, 4, 6)
z = torch.randn(6)
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

graph_module = (
compiler.export_to_cadence(AddmmFeasible2(), (x, y, z))
.exported_program()
.graph_module
def test_fuse_view_mm_view_add(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
y_view = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
)
graph_module.graph.eliminate_dead_code()
# Assert that mm and add were fused to addmm
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)

# Bias is a singleton value, broadcastable to output of mm
class AddmmFeasible3(torch.nn.Module):
def forward(self, x, y):
t1 = torch.mm(x, y)
return torch.add(t1, torch.ones(1))

x = torch.randn(3, 5)
y = torch.randn(5, 6)

graph_module = (
compiler.export_to_cadence(AddmmFeasible3(), (x, y))
.exported_program()
.graph_module
mm = builder.call_operator(
op=exir_ops.edge.aten.mm.default,
args=(x, y_view),
)
mm_view = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
)
graph_module.graph.eliminate_dead_code()
# Assert that mm and add were fused to addmm
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
output = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
)
builder.output([output])
original_graph = builder.get_graph_module()
converted_graph = FuseMMWithAdd()(original_graph).graph_module
converted_graph.graph.eliminate_dead_code()
self.assertEqual(
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
)
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)

def test_keep_view_mm_view_add(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
# Bias is not broadcastable to output of mm
class AddmmInfeasible1(torch.nn.Module):
def forward(self, x, y, z):
t1 = y.view((8, 6))
t2 = torch.mm(x, t1)
t3 = t2.view((2, 2, 6))
return torch.add(t3, z)

x = torch.randn(4, 8)
y = torch.randn(2, 4, 6)
z = torch.randn(2, 2, 1)

graph_module = (
compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z))
.exported_program()
.graph_module
z = builder.placeholder("z", torch.randn(2, 2, 1, dtype=torch.float32))
y_view = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
)
mm = builder.call_operator(
op=exir_ops.edge.aten.mm.default,
args=(x, y_view),
)
graph_module.graph.eliminate_dead_code()
mm_view = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
)
output = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
)
builder.output([output])
original_graph = builder.get_graph_module()
converted_graph = FuseMMWithAdd()(original_graph).graph_module
converted_graph.graph.eliminate_dead_code()
# Assert that mm and add were not fused to addmm, since z cannot be
# broadcast to the output of mm.
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1)

# The add consuming the output of mm has more than one user.
class AddmmInfeasible2(torch.nn.Module):
def forward(self, x, y, z):
t1 = torch.mm(x, y)
t2 = torch.add(t1, z)
t3 = torch.add(t2, z)
return torch.add(t2, t3)
self.assertEqual(
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
)
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1)

x = torch.randn(3, 5)
y = torch.randn(5, 6)
z = torch.randn(6)
def test_fuse_mm_add_with_bias(self):
    """Check that FuseMMWithAdd fuses mm + add when the addend is a one-element bias.

    The graph is hand-built: two float32 placeholders feed an mm, and the
    mm result is added to a tensor created by full([1], 1). After running
    the pass the graph is expected to contain a single addmm and no
    remaining mm or add nodes.
    """
    builder = GraphBuilder()
    lhs = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
    rhs = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
    product = builder.call_operator(
        op=exir_ops.edge.aten.mm.default, args=(lhs, rhs)
    )
    # The bias is not a graph input; it is produced inside the graph by
    # full with size [1] and fill value 1.
    bias = builder.call_operator(op=exir_ops.edge.aten.full.default, args=([1], 1))
    biased = builder.call_operator(
        op=exir_ops.edge.aten.add.Tensor, args=(product, bias)
    )
    builder.output([biased])

    # Run the pass on the freshly built module and prune nodes left unused.
    fused_module = FuseMMWithAdd()(builder.get_graph_module()).graph_module
    fused_module.graph.eliminate_dead_code()

    # One addmm must remain; the original mm and add must both be gone.
    addmm_count = count_node(fused_module, exir_ops.edge.aten.addmm.default)
    self.assertEqual(addmm_count, 1)
    mm_count = count_node(fused_module, exir_ops.edge.aten.mm.default)
    self.assertEqual(mm_count, 0)
    add_count = count_node(fused_module, exir_ops.edge.aten.add.Tensor)
    self.assertEqual(add_count, 0)

graph_module = (
compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z))
.exported_program()
.graph_module
def test_keep_mm_add_with_multiple_users(self):
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
mm = builder.call_operator(
op=exir_ops.edge.aten.mm.default,
args=(x, y),
)
# The add consuming the output of mm has more than one user.
add1 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
add2 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(add1, z))
output = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor, args=(add1, add2)
)
graph_module.graph.eliminate_dead_code()
builder.output([output])
original_graph = builder.get_graph_module()
converted_graph = FuseMMWithAdd()(original_graph).graph_module
converted_graph.graph.eliminate_dead_code()
# Assert that mm and add were not fused to addmm, since add has multiple
# users.
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3)
self.assertEqual(
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
)
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)

# TODO(matthiascremon): enable that pass with new flow
@torch.no_grad()
Expand Down Expand Up @@ -184,63 +207,69 @@ def forward(self, x):
)

def test_permute_transpose_fusion(self):
class PermuteTranspose(torch.nn.Module):
def forward(self, x):
y = x.permute((0, 2, 4, 1, 3))
return y.transpose(0, 1)

x = torch.randn(3, 1, 3, 1, 4)
graph_module = (
compiler.export_to_cadence(PermuteTranspose(), (x,))
.exported_program()
.graph_module
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))
permute = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3])
)
output = builder.call_operator(
op=exir_ops.edge.aten.transpose_copy.int,
args=(permute, 1, 0),
)
graph_module.graph.eliminate_dead_code()
builder.output(output)
original_graph = builder.get_graph_module()
converted_graph = FuseCascadedTransposeOrPermuteOps()(
original_graph
).graph_module
converted_graph.graph.eliminate_dead_code()
# Assert that permute op was fused with transpose op
self.assertEqual(
count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1
count_node(converted_graph, exir_ops.edge.aten.permute_copy.default), 1
)
self.assertEqual(
count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0
count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0
)

def test_view_fusion(self):
class ViewFusion(torch.nn.Module):
def forward(self, x):
x = x.view([1, 8, 15])
x = x.view([1, 1, 120])
return x.view([1, 12, 10])

x = torch.randn(8, 5, 3)
graph_module = (
compiler.export_to_cadence(ViewFusion(), (x,))
.exported_program()
.graph_module
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
view1 = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
)
view2 = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(view1, [1, 1, 120])
)
output = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(view2, [1, 12, 10])
)
graph_module.graph.eliminate_dead_code()
builder.output(output)
original_graph = builder.get_graph_module()
converted_graph = FuseCascadedViewOps()(original_graph).graph_module
converted_graph.graph.eliminate_dead_code()
# Assert that only one view op remains
self.assertEqual(
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1
)

def test_view_fusion_branched(self):
class ViewFusion(torch.nn.Module):
def forward(self, x):
y = x.view([1, 8, 15])
z = y.view([1, 1, 120])
t = y.view([120, 1, 1])
return z, t

x = torch.randn(8, 5, 3)
graph_module = (
compiler.export_to_cadence(ViewFusion(), (x,))
.exported_program()
.graph_module
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
y = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
)
z = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(y, [1, 1, 120])
)
graph_module.graph.eliminate_dead_code()
t = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(y, [120, 1, 1])
)
builder.output([z, t])
original_graph = builder.get_graph_module()
converted_graph = FuseCascadedViewOps()(original_graph).graph_module
converted_graph.graph.eliminate_dead_code()
# y should be fused into each of z and t and then eliminated, leaving two views.
self.assertEqual(
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2
)

def test_force_quant_dequant_fusion(self):
Expand Down
Loading