Commit ea8db06

Use GraphBuilder in unit tests for ops removal #2.
Differential Revision: D75034439
Pull Request resolved: #11011
1 parent 77e342d

1 file changed: +118 -116 lines

backends/cadence/aot/tests/test_remove_ops_passes.py

@@ -235,36 +235,28 @@ def test_remove_zero_arg_cat(self):
         )

     def test_remove_clone(self):
-        class Clone(torch.nn.Module):
-            def forward(self, x, y):
-                t1 = x.clone()
-                t2 = y.clone()
-                return t1 + t2
-
-        x = torch.ones(3, 5)
-        y = torch.ones(3, 5)
-        graph_module = export_to_edge(Clone(), (x, y)).exported_program().graph_module
-        new_graph_module = RemoveCloneOpPass()(graph_module).graph_module
-        new_graph_module.graph.eliminate_dead_code()
-        # Assert that t1 and t2 are optimized away
-        self.assertEqual(count_node(new_graph_module, torch.ops.aten.clone.out), 0)
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32))
+        clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,))
+        builder.output([clone])
+        original = builder.get_graph_module()
+        graph_after_passes = RemoveCloneOpPass()(original).graph_module
+        self.assertEqual(
+            count_node(graph_after_passes, torch.ops.aten.clone.default), 0
+        )

     def test_remove_contiguous(self):
-        class Contiguous(torch.nn.Module):
-            def forward(self, x, y):
-                t1 = x.contiguous()
-                t2 = y.contiguous()
-                return t1 + t2
-
-        x = torch.ones(3, 5)
-        y = torch.ones(3, 5)
-        graph_module = (
-            export_to_edge(Contiguous(), (x, y)).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32))
+        contiguous = builder.call_operator(
+            op=exir_ops.edge.aten.contiguous.default, args=(x,)
+        )
+        builder.output([contiguous])
+        original = builder.get_graph_module()
+        graph_after_passes = RemoveContiguousOpPass()(original).graph_module
+        self.assertEqual(
+            count_node(graph_after_passes, torch.ops.aten.contiguous.default), 0
         )
-        new_graph_module = RemoveContiguousOpPass()(graph_module).graph_module
-        new_graph_module.graph.eliminate_dead_code()
-        # Assert that t1 and t2 are optimized away
-        self.assertEqual(count_node(new_graph_module, torch.ops.aten.contiguous.out), 0)

     @parameterized.expand(
         [
@@ -274,119 +266,129 @@ def forward(self, x, y):
     )
     @torch.no_grad()
     def test_remove_nop_view(self, shape, new_shape):
-        class View(torch.nn.Module):
-            def __init__(self, new_shape):
-                super().__init__()
-                self.new_shape = new_shape
-
-            def forward(self, x: torch.Tensor):
-                return x.view(self.new_shape)
-
-        model = View(new_shape)
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-        p = RemoveNopSliceOrViewOpPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-        graph_after_passes.graph.eliminate_dead_code()
-        # Assert that view op was removed
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(x, new_shape)
+        )
+        builder.output([view])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSliceOrViewOpPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0
         )

     def test_remove_nop_slice(self):
-        class Slice(torch.nn.Module):
-            def forward(self, x):
-                return torch.slice_copy(x, dim=0, start=0, step=1)
-
-        x = torch.ones(3, 5)
-        model = Slice()
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-        p = RemoveNopSliceOrViewOpPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-        graph_after_passes.graph.eliminate_dead_code()
-        # Assert that slice op was removed
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        slice_ = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(
+                x,
+                0,  # dim
+                0,  # start
+                3,  # end
+            ),
+        )
+        builder.output([slice_])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSliceOrViewOpPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
         )

-    def test_remove_nop_select(self):
-        class SelectFeasible1(torch.nn.Module):
-            def forward(self, x):
-                y = x.select(0, 0)
-                z = y.view([1, 5, 6])
-                return z
-
-        x = torch.ones(1, 5, 6)
-        graph_module = (
-            export_to_edge(SelectFeasible1(), (x,)).exported_program().graph_module
+    def test_remove_nop_select_before_view(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
+        select = builder.call_operator(
+            op=exir_ops.edge.aten.select_copy.int,
+            args=(
+                x,
+                0,  # dim
+                0,  # index
+            ),
         )
-        self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
+        view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(select, [1, 5, 6]),  # new shape
         )
-        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
-        # Assert that select op was removed
+        builder.output([view])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSelectOpPass()(original)
+        ).graph_module
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
+            count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
         )

-        class SelectFeasible2(torch.nn.Module):
-            def forward(self, x, y):
-                x = x.select(0, 0)
-                z = x + y
-                return z
-
-        x = torch.ones(1, 5, 6)
-        y = torch.ones(1, 5, 6)
-        graph_module = (
-            export_to_edge(SelectFeasible2(), (x, y)).exported_program().graph_module
-        )
-        self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
+    def test_remove_nop_select_before_add(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
+        select = builder.call_operator(
+            op=exir_ops.edge.aten.select_copy.int,
+            args=(
+                x,
+                0,  # dim
+                0,  # index
+            ),
         )
-        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
-        # Assert that select op was removed
+        add = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(select, y))
+        builder.output([add])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSelectOpPass()(original)
+        ).graph_module
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
+            count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
         )

-        class SelectFeasible3(torch.nn.Module):
-            def forward(self, x, y):
-                x = x.select(0, 0)
-                z = x * y
-                return z
-
-        x = torch.ones(1, 5, 6)
-        y = torch.ones(1, 5, 6)
-        graph_module = (
-            export_to_edge(SelectFeasible3(), (x, y)).exported_program().graph_module
-        )
-        self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
+    def test_remove_nop_select_before_mul(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
+        select = builder.call_operator(
+            op=exir_ops.edge.aten.select_copy.int,
+            args=(
+                x,
+                0,  # dim
+                0,  # index
+            ),
         )
-        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
-        # Assert that select op was removed
+        mul = builder.call_operator(op=exir_ops.edge.aten.mul.Tensor, args=(select, y))
+        builder.output([mul])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSelectOpPass()(original)
+        ).graph_module
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
+            count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
         )

-        class SelectFeasible4(torch.nn.Module):
-            def forward(self, x, y):
-                x = x.select(0, 0)
-                z = x / y
-                return z
-
-        x = torch.ones(1, 5, 6)
-        y = torch.ones(1, 5, 6)
-        graph_module = (
-            export_to_edge(SelectFeasible4(), (x, y)).exported_program().graph_module
-        )
-        self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
+    def test_remove_nop_select_before_div(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
+        select = builder.call_operator(
+            op=exir_ops.edge.aten.select_copy.int,
+            args=(
+                x,
+                0,  # dim
+                0,  # index
+            ),
         )
-        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
-        # Assert that select op was removed
+        div = builder.call_operator(op=exir_ops.edge.aten.div.Tensor, args=(select, y))
+        builder.output([div])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSelectOpPass()(original)
+        ).graph_module
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
+            count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
         )

     def test_remove_nop_quant_dequant(self):
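
Every test in this diff reduces to the same GraphBuilder steps: declare placeholders, emit the nop edge-dialect op, fetch the graph module, run the removal pass, and count surviving nodes. Below is a minimal, self-contained sketch of that pattern for the clone case. The import paths are assumptions inferred from this file's location and may differ; GraphBuilder, call_operator, count_node, and the pass invocation are used exactly as in the diff above.

from typing import cast

import torch

# Assumed import paths (hypothetical; adjust to the actual module layout):
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.remove_ops import RemoveCloneOpPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult

builder = GraphBuilder()
# Graph input; the example tensor supplies shape and dtype metadata.
x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32))
# The nop op under test, built directly in the edge dialect
# (no export_to_edge round trip needed).
clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,))
builder.output([clone])
original = builder.get_graph_module()
# Run the pass and assert the clone op is gone.
graph_after_passes = cast(PassResult, RemoveCloneOpPass()(original)).graph_module
assert count_node(graph_after_passes, torch.ops.aten.clone.default) == 0

Building the graph directly makes each test hermetic: the pass input no longer depends on what export_to_edge happens to emit, which is why the eliminate_dead_code() calls and the intermediate count_node(..., 1) sanity checks could be dropped.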
