Skip to content

Commit 6bae6d5

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Use single_op_builder in simplify unit tests. (#11158)
Summary: Use single_op_builder in simplify unit tests. Reviewed By: hsharma35 Differential Revision: D75309572
1 parent bc47f5a commit 6bae6d5

File tree

1 file changed

+30
-80
lines changed

1 file changed

+30
-80
lines changed

backends/cadence/aot/tests/test_simplify_ops_passes.py

Lines changed: 30 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
15-
from executorch.backends.cadence.aot.compiler import export_to_edge
1615
from executorch.backends.cadence.aot.graph_builder import single_op_builder
1716
from executorch.backends.cadence.aot.pass_utils import count_node
1817
from executorch.backends.cadence.aot.simplify_ops import (
@@ -40,82 +39,47 @@ def test_simplify_slice_scatter_op(
4039
end: Optional[int] = None,
4140
step: int = 1,
4241
):
43-
class SliceScatter(torch.nn.Module):
44-
def __init__(
45-
self, dim: int, start: Optional[int], end: Optional[int], step: int
46-
):
47-
super().__init__()
48-
self.dim = dim
49-
self.start = start
50-
self.end = end
51-
self.step = step
52-
53-
def forward(self, x: torch.Tensor, y: torch.Tensor):
54-
return torch.slice_scatter(
55-
x, y, self.dim, self.start, self.end, self.step
56-
)
57-
58-
model = SliceScatter(dim, start, end, step)
59-
x = torch.randn(in_shape)
60-
y = torch.randn(src_shape)
61-
graph_module = export_to_edge(model, (x, y)).exported_program().graph_module
62-
63-
p = SimplifySliceOpPass()
64-
65-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
66-
67-
self.assertEqual(
68-
count_node(graph_after_passes, exir_ops.edge.aten.slice_scatter.default), 0
42+
x = torch.randn(*in_shape)
43+
y = torch.randn(*src_shape)
44+
gm = single_op_builder(
45+
placeholders=(x, y),
46+
op=exir_ops.edge.aten.slice_scatter.default,
47+
args=(x, y, dim, start, end, step),
6948
)
49+
p = SimplifySliceOpPass()
50+
gm = cast(PassResult, p(gm)).graph_module
51+
self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_scatter.default), 0)
7052

7153
@parameterized.expand(
7254
[
73-
[(3, 16, 5), (3, 0, 5), 1, 15, 3, 3],
55+
[(3, 16, 5), 1, 15, 3, 3],
7456
]
7557
)
7658
@torch.no_grad()
7759
def test_simplify_slice_op(
7860
self,
7961
in_shape: Tuple[int],
80-
src_shape: Tuple[int],
8162
dim: int,
8263
start: Optional[int] = None,
8364
end: Optional[int] = None,
8465
step: int = 1,
8566
):
86-
class SliceCopy(torch.nn.Module):
87-
def __init__(
88-
self, dim: int, start: Optional[int], end: Optional[int], step: int
89-
):
90-
super().__init__()
91-
self.dim = dim
92-
self.start = start
93-
self.end = end
94-
self.step = step
95-
96-
def forward(self, x: torch.Tensor) -> torch.Tensor:
97-
return torch.slice_copy(
98-
x, dim=self.dim, start=self.start, end=self.end, step=self.step
99-
)
100-
101-
# Create a model with single slice copy op.
102-
model = SliceCopy(dim, start, end, step)
103-
x = torch.randn(in_shape)
104-
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
105-
self.assertEqual(
106-
count_node(graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
67+
x = torch.randn(*in_shape)
68+
gm = single_op_builder(
69+
placeholders=(x,),
70+
op=exir_ops.edge.aten.slice_copy.Tensor,
71+
args=(
72+
x,
73+
dim,
74+
start,
75+
end,
76+
step,
77+
),
10778
)
108-
10979
p = SimplifySliceOpPass()
110-
111-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
112-
113-
self.assertEqual(
114-
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
115-
)
116-
self.assertEqual(
117-
count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
118-
)
80+
gm = cast(PassResult, p(gm)).graph_module
81+
self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 0)
82+
self.assertEqual(count_node(gm, exir_ops.edge.aten.full.default), 1)
11983

12084
def test_simplify_slice_op_args(self) -> None:
12185
x = torch.rand(4, 5)
@@ -125,24 +89,10 @@ def test_simplify_slice_op_args(self) -> None:
12589
args=(x, 1),
12690
kwargs={"end": 3},
12791
)
128-
self.assertEqual(
129-
[
130-
(n.args[1:], n.kwargs)
131-
for n in gm.graph.find_nodes(
132-
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
133-
)
134-
],
135-
[((1,), {"end": 3})],
136-
)
137-
92+
original_slice_copy = list(gm.graph.nodes)[1]
93+
self.assertEqual(original_slice_copy.args[1:], (1,))
94+
self.assertEqual(original_slice_copy.kwargs, {"end": 3})
13895
gm = BindOptionalArgsPass().call(gm).graph_module
139-
140-
self.assertEqual(
141-
[
142-
(n.args[1:], n.kwargs)
143-
for n in gm.graph.find_nodes(
144-
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
145-
)
146-
],
147-
[((1, None, 3, 1), {})],
148-
)
96+
modified_slice_copy = list(gm.graph.nodes)[1]
97+
self.assertEqual(modified_slice_copy.args[1:], (1, None, 3, 1))
98+
self.assertEqual(modified_slice_copy.kwargs, {})

0 commit comments

Comments
 (0)