Commit 6c6cf65

Update on "Introduce hydra framework with backwards compatibility"
[ghstack-poisoned]
2 parents 44fa6dc + 00d3b62

13 files changed: +1402, -165 lines

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 99 additions & 165 deletions
@@ -13,8 +13,6 @@
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.compiler import export_to_edge
 from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
@@ -53,16 +51,15 @@ class TestRemoveOpsPasses(unittest.TestCase):
     )
     @torch.no_grad()
     def test_remove_to_ops(self, shape: Tuple[int]):
-        class M(torch.nn.Module):
-            def forward(self, x: torch.Tensor):
-                return exir_ops.edge.aten.to(x, dtype=torch.float32)
-
-        model = M()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-        p = RemoveToOpsPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        x = builder.call_operator(
+            op=exir_ops.edge.aten.to.dtype,
+            args=(x, torch.float32),
+        )
+        builder.output([x])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(PassResult, RemoveToOpsPass()(original)).graph_module

         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.to.dtype),
@@ -83,31 +80,24 @@ def forward(self, x: torch.Tensor):
     )
     @torch.no_grad()
     def test_remove_nop_add_op_pass(self, shape: Tuple[int]):
-        class FullX(torch.nn.Module):
-            def forward(self, t: torch.Tensor):
-                return torch.add(torch.full(shape, 0), t)
-
-        class FullY(torch.nn.Module):
-            def forward(self, t: torch.Tensor):
-                return torch.add(t, torch.full(shape, 0))
-
-        model = FullX()
-        t = torch.full(shape, 3)
-        graph_module = export_to_edge(model, (t,)).exported_program().graph_module
-
-        p = RemoveNopAddOpPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-        self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
-            0,
-        )
-
-        model = FullY()
-        graph_module = export_to_edge(model, (t,)).exported_program().graph_module
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        zeros = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=(shape, 0)
+        )
+        left_add = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(zeros, x),
+        )
+        right_add = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(left_add, zeros),
+        )
+        builder.output([right_add])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopAddOpPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
             0,
@@ -122,31 +112,24 @@ def forward(self, t: torch.Tensor):
     )
     @torch.no_grad()
     def test_remove_nop_mul_op_pass(self, shape: Tuple[int]):
-        class FullX(torch.nn.Module):
-            def forward(self, t: torch.Tensor):
-                return torch.mul(torch.full(shape, 0), t)
-
-        class FullY(torch.nn.Module):
-            def forward(self, t: torch.Tensor):
-                return torch.mul(t, torch.full(shape, 0))
-
-        model = FullX()
-        t = torch.full(shape, 3)
-        graph_module = export_to_edge(model, (t,)).exported_program().graph_module
-
-        p = RemoveNopMulOpPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-        self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
-            0,
-        )
-
-        model = FullY()
-        graph_module = export_to_edge(model, (t,)).exported_program().graph_module
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        zeros = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=(shape, 0)
+        )
+        left_mul = builder.call_operator(
+            op=exir_ops.edge.aten.mul.Tensor,
+            args=(zeros, x),
+        )
+        right_mul = builder.call_operator(
+            op=exir_ops.edge.aten.mul.Tensor,
+            args=(left_mul, zeros),
+        )
+        builder.output([right_mul])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopMulOpPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
             0,
@@ -159,18 +142,16 @@ def forward(self, t: torch.Tensor):
     )
     @torch.no_grad()
    def test_remove_alias_copy(self, shape: Tuple[int]):
-        class M(torch.nn.Module):
-            def forward(self, x: torch.Tensor):
-                return exir_ops.edge.aten.alias_copy(x)
-
-        model = M()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
-        p = RemoveAliasCopyOpPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        alias = builder.call_operator(
+            op=exir_ops.edge.aten.alias_copy.default, args=(x,)
+        )
+        builder.output([alias])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveAliasCopyOpPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.alias_copy.default),
             0,
@@ -183,19 +164,16 @@ def forward(self, x: torch.Tensor):
     )
     @torch.no_grad()
     def test_remove_detach_copy(self, shape: Tuple[int]):
-        # aten::detach is converted to aten::alias_copy after functionalization & decomposition.
-        class M(torch.nn.Module):
-            def forward(self, x: torch.Tensor):
-                return exir_ops.edge.aten.detach_copy(x)
-
-        model = M()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
-        p = RemoveDetachCopyPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        detach = builder.call_operator(
+            op=exir_ops.edge.aten.detach_copy.default, args=(x,)
+        )
+        builder.output([detach])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveDetachCopyPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.detach_copy.default),
             0,
@@ -210,95 +188,51 @@ def forward(self, x: torch.Tensor):
     def test_remove_zero_sized_constant_pad_nd(
         self, shape: Tuple[int], padding: Tuple[int]
     ):
-        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
-        class Padding(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.padding = padding
-
-            def forward(self, x: torch.Tensor):
-                return F.pad(x, self.padding)
-
-        model = Padding()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
-        p = RemoveZeroSizedConstantPadNd()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        pad = builder.call_operator(
+            op=exir_ops.edge.aten.constant_pad_nd.default, args=(x, padding)
+        )
+        builder.output([pad])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveZeroSizedConstantPadNd()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default),
             0,
         )

     def test_remove_expand(self):
-        class Expand(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.expand_copy(x, [2, 3, 5])
-
-        x = torch.ones(2, 3, 5)
-        p = RemoveNopExpandOpPass()
-        graph_module = export_to_edge(Expand(), (x,)).exported_program().graph_module
-        graph_module = p(graph_module).graph_module
-        # Assert that expand op is optimized away, since it is a nop
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn([2, 3, 5], dtype=torch.float32))
+        expand = builder.call_operator(
+            op=exir_ops.edge.aten.expand_copy.default, args=(x, [2, 3, 5])
+        )
+        builder.output([expand])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopExpandOpPass()(original)
+        ).graph_module
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.expand_copy.default), 0
+            count_node(graph_after_passes, exir_ops.edge.aten.expand_copy.default), 0
         )

     def test_remove_zero_arg_cat(self):
-        class Cat(torch.nn.Module):
-            def forward(self, x, y):
-                return torch.ops.aten.cat((x, y), 0)
-
-        x = torch.ones(1, 0, 3, 5)
-        y = torch.ones(2, 0, 3, 5)
-        graph_module = (
-            compiler.export_to_cadence(Cat(), (x, y)).exported_program().graph_module
-        )
-        # Assert that cat op is optimized away, since it concatenates
-        # two zero-sized tensors
-        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
-
-    def test_remove_single_arg_cat(self):
-        class Cat(torch.nn.Module):
-            def forward(self, x, y):
-                z = torch.ones(0, 5)
-                # z is an empty tensor, and concatenation of x with z will
-                # be x. So we can safely eliminate the following cat op.
-                x1 = torch.ops.aten.cat((x, z))
-                x2 = torch.add(x1, 2.4, 3.1)
-                y1 = torch.add(y, 1, 2)
-                return torch.add(x2, y1)
-
-        x = torch.ones(3, 5)
-        y = torch.ones(3, 5)
-        graph_module = export_to_edge(Cat(), (x, y)).exported_program().graph_module
-        new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
-        new_graph_module.graph.eliminate_dead_code()
-        # Assert that x1 is optimized away
-        self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0)
-
-    def test_remove_zero_sized_cat(self):
-        class Cat(torch.nn.Module):
-            def __init__(self, dim: int):
-                super().__init__()
-                self.dim = dim
-
-            def forward(self, tensors):
-                return torch.cat(tensors, self.dim)
-
-        shapes, dim, dtype, _max = [(1, 0, 3), (2, 0, 3)], 0, torch.float32, 127
-
-        in_tensors = [(torch.rand(shape) * _max).to(dtype=dtype) for shape in shapes]
-
-        model = Cat(dim)
-        graph_module = (
-            export_to_edge(model, (in_tensors,)).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn([1, 0, 3, 5], dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn([2, 0, 3, 5], dtype=torch.float32))
+        concat = builder.call_operator(
+            op=exir_ops.edge.aten.cat.default, args=([x, y], 0)
+        )
+        builder.output([concat])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveZeroSizedCatArgsPass()(original)
+        ).graph_module
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0
         )
-        new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
-        new_graph_module.graph.eliminate_dead_code()
-        self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0)

     def test_remove_clone(self):
         class Clone(torch.nn.Module):
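
Note: every hunk above follows the same migration. Each test now constructs its graph directly with GraphBuilder instead of exporting an nn.Module through export_to_edge. A minimal standalone sketch of the new pattern, using test_remove_to_ops as the template (GraphBuilder, RemoveToOpsPass, count_node, and the call shapes come from the diff; the remove_ops/pass_utils import paths and the script framing are assumptions):

from typing import cast

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import count_node  # assumed path
from executorch.backends.cadence.aot.remove_ops import RemoveToOpsPass  # assumed path
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult

# Build a graph whose only op is a no-op dtype cast (float32 -> float32).
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(2, 3, dtype=torch.float32))
cast_node = builder.call_operator(
    op=exir_ops.edge.aten.to.dtype, args=(x, torch.float32)
)
builder.output([cast_node])
original = builder.get_graph_module()

# Run the pass and confirm the redundant cast was removed.
graph_after_passes = cast(PassResult, RemoveToOpsPass()(original)).graph_module
assert count_node(graph_after_passes, exir_ops.edge.aten.to.dtype) == 0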

backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ def __init__(self):
             exir_ops.edge.aten.clone.default: self._default_condition,
             torch.ops.aten.alias.default: self._default_condition,
             exir_ops.edge.aten.alias.default: self._default_condition,
+            exir_ops.edge.aten.alias_copy.default: self._default_condition,
             exir_ops.edge.aten.lift_fresh_copy.default: self._default_condition,
             # remove this target if '_skip_dim_order' is set to False
             exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition,
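
For context, this one-line change registers alias_copy in the pass's target-to-condition map; any node whose target is in the map and whose condition holds is treated as redundant. A hypothetical sketch of how such a map is typically consumed (the torch.fx calls are real APIs, but this is not RemoveRedundancy's actual traversal):

# Hypothetical consumer of a target-to-condition map; names other than
# the torch.fx APIs are illustrative, not RemoveRedundancy's real code.
def remove_redundant_ops(graph_module, redundant_to_condition):
    for node in list(graph_module.graph.nodes):
        condition = redundant_to_condition.get(node.target)
        if condition is None or not condition(node):
            continue
        # The op is a pure pass-through: forward its input to all users,
        # then drop the node from the graph.
        node.replace_all_uses_with(node.args[0])
        graph_module.graph.erase_node(node)
    graph_module.recompile()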

backends/qualcomm/tests/models.py

Lines changed: 11 additions & 0 deletions
@@ -10,6 +10,17 @@
 # module with related operator only


+# Ensure alias_copy is removed in remove_redundancy pass
+class Alias(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        alias_x = torch.ops.aten.alias.default(x)
+        return self.relu(alias_x)
+
+
 class And(torch.nn.Module):
     def __init__(self, pos, neg):
         super().__init__()
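
The comment in the new module states its purpose: export functionalizes aten.alias into its alias_copy variant, which the new remove_redundancy entry then matches. A hedged way to observe that rewrite outside the test harness (standard torch.export APIs; seeing alias_copy in the printed graph is the expectation, not a guarantee across versions):

import torch

class Alias(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(torch.ops.aten.alias.default(x))

# Functionalization during export rewrites view/alias ops into their
# *_copy variants, so aten.alias should appear as aten.alias_copy here.
ep = torch.export.export(Alias(), (torch.randn(1, 10),))
ep = ep.run_decompositions()
print(ep.graph_module.code)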

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 11 additions & 0 deletions
@@ -124,6 +124,11 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
         sample_input = (torch.randn(1, 512, 7, 7),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_alias(self):
+        module = Alias()  # noqa: F405
+        sample_input = (torch.randn(1, 10),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_amax(self):
         modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(4, 4),)
@@ -1162,6 +1167,12 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_alias(self):
+        module = Alias()  # noqa: F405
+        sample_input = (torch.randn(1, 10),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_amax(self):
         modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(4, 4),)

codegen/test/TARGETS

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain xplat-only targets.
+
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
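
The new TARGETS file is only a shim; the real definitions live in targets.bzl, which this commit does not include. A hypothetical minimal targets.bzl following ExecuTorch's usual convention (the runtime_wrapper load path and the python_test target are assumptions, shown only to illustrate what define_common_targets might declare):

# Hypothetical codegen/test/targets.bzl; illustrative only.
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    # Declared here so both fbcode and xplat builds pick the target up
    # through the TARGETS shim that calls define_common_targets().
    runtime.python_test(
        name = "codegen_test",
        srcs = ["test_codegen.py"],  # hypothetical test source
        deps = [],
    )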
