Skip to content

Commit 14b15f3

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Use GraphBuilder in test fusion ops. (#11078)
Summary: Pull Request resolved: #11078 Reviewed By: hsharma35 Differential Revision: D75183327
1 parent df5e7df commit 14b15f3

File tree

1 file changed

+162
-133
lines changed

1 file changed

+162
-133
lines changed

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 162 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
import torch
1515
from executorch.backends.cadence.aot import compiler
1616
from executorch.backends.cadence.aot.fuse_ops import (
17+
FuseCascadedViewOps,
1718
FuseFullThenReshapePass,
19+
FuseMMWithAdd,
1820
FuseMulScalarIntoDequantPass,
1921
FuseMulTensorIntoDequantPass,
2022
FuseQuantDequantToRequantizePass,
@@ -39,113 +41,133 @@ def check_op_counts(
3941

4042

4143
class TestFusionPasses(TestFusionPassesBase):
42-
def test_addmm_fusion(self):
43-
class AddmmFeasible1(torch.nn.Module):
44-
def forward(self, x, y, z):
45-
t1 = torch.mm(x, y)
46-
return torch.add(t1, z)
47-
48-
x = torch.randn(3, 5)
49-
y = torch.randn(5, 6)
50-
z = torch.randn(6)
51-
52-
graph_module = (
53-
compiler.export_to_cadence(AddmmFeasible1(), (x, y, z))
54-
.exported_program()
55-
.graph_module
44+
def test_fuse_mm_with_add(self):
45+
builder = GraphBuilder()
46+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
47+
y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
48+
z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
49+
mm = builder.call_operator(
50+
op=exir_ops.edge.aten.mm.default,
51+
args=(x, y),
52+
)
53+
output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
54+
builder.output([output])
55+
original_graph = builder.get_graph_module()
56+
converted_graph = FuseMMWithAdd()(original_graph).graph_module
57+
converted_graph.graph.eliminate_dead_code()
58+
self.assertEqual(
59+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
5660
)
57-
graph_module.graph.eliminate_dead_code()
58-
59-
# Assert that mm and add were fused to addmm
60-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
61-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
62-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
63-
64-
class AddmmFeasible2(torch.nn.Module):
65-
def forward(self, x, y, z):
66-
t1 = y.view((8, 6))
67-
t2 = torch.mm(x, t1)
68-
t3 = t2.view((2, 2, 6))
69-
return torch.add(t3, z)
70-
71-
x = torch.randn(4, 8)
72-
y = torch.randn(2, 4, 6)
73-
z = torch.randn(6)
61+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
62+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)
7463

75-
graph_module = (
76-
compiler.export_to_cadence(AddmmFeasible2(), (x, y, z))
77-
.exported_program()
78-
.graph_module
64+
def test_fuse_view_mm_view_add(self):
65+
builder = GraphBuilder()
66+
x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
67+
y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
68+
z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
69+
y_view = builder.call_operator(
70+
op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
7971
)
80-
graph_module.graph.eliminate_dead_code()
81-
# Assert that mm and add were fused to addmm
82-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
83-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
84-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
85-
86-
# Bias is a singleton value, broadcastable to output of mm
87-
class AddmmFeasible3(torch.nn.Module):
88-
def forward(self, x, y):
89-
t1 = torch.mm(x, y)
90-
return torch.add(t1, torch.ones(1))
91-
92-
x = torch.randn(3, 5)
93-
y = torch.randn(5, 6)
94-
95-
graph_module = (
96-
compiler.export_to_cadence(AddmmFeasible3(), (x, y))
97-
.exported_program()
98-
.graph_module
72+
mm = builder.call_operator(
73+
op=exir_ops.edge.aten.mm.default,
74+
args=(x, y_view),
75+
)
76+
mm_view = builder.call_operator(
77+
op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
9978
)
100-
graph_module.graph.eliminate_dead_code()
101-
# Assert that mm and add were fused to addmm
102-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
103-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
104-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
79+
output = builder.call_operator(
80+
op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
81+
)
82+
builder.output([output])
83+
original_graph = builder.get_graph_module()
84+
converted_graph = FuseMMWithAdd()(original_graph).graph_module
85+
converted_graph.graph.eliminate_dead_code()
86+
self.assertEqual(
87+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
88+
)
89+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
90+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)
10591

92+
def test_keep_view_mm_view_add(self):
93+
builder = GraphBuilder()
94+
x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32))
95+
y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32))
10696
# Bias is not broadcastable to output of mm
107-
class AddmmInfeasible1(torch.nn.Module):
108-
def forward(self, x, y, z):
109-
t1 = y.view((8, 6))
110-
t2 = torch.mm(x, t1)
111-
t3 = t2.view((2, 2, 6))
112-
return torch.add(t3, z)
113-
114-
x = torch.randn(4, 8)
115-
y = torch.randn(2, 4, 6)
116-
z = torch.randn(2, 2, 1)
117-
118-
graph_module = (
119-
compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z))
120-
.exported_program()
121-
.graph_module
97+
z = builder.placeholder("z", torch.randn(2, 2, 1, dtype=torch.float32))
98+
y_view = builder.call_operator(
99+
op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6])
100+
)
101+
mm = builder.call_operator(
102+
op=exir_ops.edge.aten.mm.default,
103+
args=(x, y_view),
122104
)
123-
graph_module.graph.eliminate_dead_code()
105+
mm_view = builder.call_operator(
106+
op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6])
107+
)
108+
output = builder.call_operator(
109+
op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z)
110+
)
111+
builder.output([output])
112+
original_graph = builder.get_graph_module()
113+
converted_graph = FuseMMWithAdd()(original_graph).graph_module
114+
converted_graph.graph.eliminate_dead_code()
124115
# Assert that mm and add were not fused to addmm, since z cannot be
125116
# broadcasted to the out of mm.
126-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1)
127-
128-
# The add consuming the output of mm has more than one users.
129-
class AddmmInfeasible2(torch.nn.Module):
130-
def forward(self, x, y, z):
131-
t1 = torch.mm(x, y)
132-
t2 = torch.add(t1, z)
133-
t3 = torch.add(t2, z)
134-
return torch.add(t2, t3)
117+
self.assertEqual(
118+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
119+
)
120+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
121+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1)
135122

136-
x = torch.randn(3, 5)
137-
y = torch.randn(5, 6)
138-
z = torch.randn(6)
123+
def test_fuse_mm_add_with_bias(self):
124+
builder = GraphBuilder()
125+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
126+
y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
127+
mm = builder.call_operator(
128+
op=exir_ops.edge.aten.mm.default,
129+
args=(x, y),
130+
)
131+
bias = builder.call_operator(op=exir_ops.edge.aten.full.default, args=([1], 1))
132+
output = builder.call_operator(
133+
op=exir_ops.edge.aten.add.Tensor, args=(mm, bias)
134+
)
135+
builder.output([output])
136+
original_graph = builder.get_graph_module()
137+
converted_graph = FuseMMWithAdd()(original_graph).graph_module
138+
converted_graph.graph.eliminate_dead_code()
139+
self.assertEqual(
140+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1
141+
)
142+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0)
143+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0)
139144

140-
graph_module = (
141-
compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z))
142-
.exported_program()
143-
.graph_module
145+
def test_keep_mm_add_with_multiple_users(self):
146+
builder = GraphBuilder()
147+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
148+
y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32))
149+
z = builder.placeholder("z", torch.randn(6, dtype=torch.float32))
150+
mm = builder.call_operator(
151+
op=exir_ops.edge.aten.mm.default,
152+
args=(x, y),
153+
)
154+
# The add consuming the output of mm has more than one users.
155+
add1 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
156+
add2 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(add1, z))
157+
output = builder.call_operator(
158+
op=exir_ops.edge.aten.add.Tensor, args=(add1, add2)
144159
)
145-
graph_module.graph.eliminate_dead_code()
160+
builder.output([output])
161+
original_graph = builder.get_graph_module()
162+
converted_graph = FuseMMWithAdd()(original_graph).graph_module
163+
converted_graph.graph.eliminate_dead_code()
146164
# Assert that mm and add were not fused to addmm, since add has multiple
147165
# users.
148-
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3)
166+
self.assertEqual(
167+
count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
168+
)
169+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
170+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)
149171

150172
# TODO(matthiascremon): enable that pass with new flow
151173
@torch.no_grad()
@@ -184,63 +206,70 @@ def forward(self, x):
184206
)
185207

186208
def test_permute_transpose_fusion(self):
187-
class PermuteTranspose(torch.nn.Module):
188-
def forward(self, x):
189-
y = x.permute((0, 2, 4, 1, 3))
190-
return y.transpose(0, 1)
191-
192-
x = torch.randn(3, 1, 3, 1, 4)
193-
graph_module = (
194-
compiler.export_to_cadence(PermuteTranspose(), (x,))
195-
.exported_program()
196-
.graph_module
209+
builder = GraphBuilder()
210+
x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))
211+
permute = builder.call_operator(
212+
op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3])
213+
)
214+
output = builder.call_operator(
215+
op=exir_ops.edge.aten.transpose_copy.int,
216+
args=(permute, 0, 1),
197217
)
198-
graph_module.graph.eliminate_dead_code()
218+
builder.output(output)
219+
original_graph = builder.get_graph_module()
220+
# Question: This pass can not be applied because [0, 2, 4] != [2, 0, 4] in can_fuse_for_chain. Do I use the right pass?
221+
converted_graph = FuseTransposeOrPermuteOpPairsPass()(
222+
original_graph
223+
).graph_module
224+
converted_graph.graph.eliminate_dead_code()
199225
# Assert that permute op was fused with transpose op
200226
self.assertEqual(
201-
count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1
227+
count_node(converted_graph, exir_ops.edge.aten.permute_copy.default), 1
202228
)
203229
self.assertEqual(
204-
count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0
230+
count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0
205231
)
206232

207233
def test_view_fusion(self):
208-
class ViewFusion(torch.nn.Module):
209-
def forward(self, x):
210-
x = x.view([1, 8, 15])
211-
x = x.view([1, 1, 120])
212-
return x.view([1, 12, 10])
213-
214-
x = torch.randn(8, 5, 3)
215-
graph_module = (
216-
compiler.export_to_cadence(ViewFusion(), (x,))
217-
.exported_program()
218-
.graph_module
234+
builder = GraphBuilder()
235+
x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
236+
view1 = builder.call_operator(
237+
op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
238+
)
239+
view2 = builder.call_operator(
240+
op=exir_ops.edge.aten.view_copy.default, args=(view1, [1, 1, 120])
241+
)
242+
output = builder.call_operator(
243+
op=exir_ops.edge.aten.view_copy.default, args=(view2, [1, 12, 10])
219244
)
220-
graph_module.graph.eliminate_dead_code()
245+
builder.output(output)
246+
original_graph = builder.get_graph_module()
247+
converted_graph = FuseCascadedViewOps()(original_graph).graph_module
248+
converted_graph.graph.eliminate_dead_code()
221249
# Assert that only one view op remains
222250
self.assertEqual(
223-
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
251+
count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1
224252
)
225253

226254
def test_view_fusion_branched(self):
227-
class ViewFusion(torch.nn.Module):
228-
def forward(self, x):
229-
y = x.view([1, 8, 15])
230-
z = y.view([1, 1, 120])
231-
t = y.view([120, 1, 1])
232-
return z, t
233-
234-
x = torch.randn(8, 5, 3)
235-
graph_module = (
236-
compiler.export_to_cadence(ViewFusion(), (x,))
237-
.exported_program()
238-
.graph_module
255+
builder = GraphBuilder()
256+
x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32))
257+
y = builder.call_operator(
258+
op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
259+
)
260+
z = builder.call_operator(
261+
op=exir_ops.edge.aten.view_copy.default, args=(y, [1, 1, 120])
239262
)
240-
graph_module.graph.eliminate_dead_code()
263+
t = builder.call_operator(
264+
op=exir_ops.edge.aten.view_copy.default, args=(y, [120, 1, 1])
265+
)
266+
builder.output([z, t])
267+
original_graph = builder.get_graph_module()
268+
converted_graph = FuseCascadedViewOps()(original_graph).graph_module
269+
converted_graph.graph.eliminate_dead_code()
241270
# z and t should be fused and y should be eliminated.
242271
self.assertEqual(
243-
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2
272+
count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2
244273
)
245274

246275
def test_force_quant_dequant_fusion(self):

0 commit comments

Comments
 (0)