Use GraphBuilder in unit tests for ops removal #2. #11011

Merged: 1 commit, May 23, 2025
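
This PR migrates the Cadence ops-removal unit tests from exporting small nn.Module wrappers through export_to_edge to building edge-dialect graphs directly with GraphBuilder, so each test exercises the pass on exactly the ops it targets. Below is a minimal sketch of the migrated pattern, modeled on test_remove_clone in the diff; the import paths are assumptions about the Cadence AoT layout and are not shown in this diff.

    # Sketch of the GraphBuilder-based test pattern. Import paths are assumed
    # (they are not visible in this diff) and may differ in the repository.
    import torch
    from executorch.backends.cadence.aot.graph_builder import GraphBuilder
    from executorch.backends.cadence.aot.pass_utils import count_node
    from executorch.backends.cadence.aot.remove_ops import RemoveCloneOpPass
    from executorch.exir.dialects._ops import ops as exir_ops

    # Build a tiny edge-dialect graph directly: one placeholder feeding a clone op.
    builder = GraphBuilder()
    x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32))
    clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,))
    builder.output([clone])
    original = builder.get_graph_module()

    # Run the pass under test and check that no clone nodes survive.
    graph_after_passes = RemoveCloneOpPass()(original).graph_module
    assert count_node(graph_after_passes, torch.ops.aten.clone.default) == 0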
backends/cadence/aot/tests/test_remove_ops_passes.py: 234 changes (118 additions, 116 deletions)
@@ -235,36 +235,28 @@ def test_remove_zero_arg_cat(self):
         )
 
     def test_remove_clone(self):
-        class Clone(torch.nn.Module):
-            def forward(self, x, y):
-                t1 = x.clone()
-                t2 = y.clone()
-                return t1 + t2
-
-        x = torch.ones(3, 5)
-        y = torch.ones(3, 5)
-        graph_module = export_to_edge(Clone(), (x, y)).exported_program().graph_module
-        new_graph_module = RemoveCloneOpPass()(graph_module).graph_module
-        new_graph_module.graph.eliminate_dead_code()
-        # Assert that t1 and t2 are optimized away
-        self.assertEqual(count_node(new_graph_module, torch.ops.aten.clone.out), 0)
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32))
+        clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,))
+        builder.output([clone])
+        original = builder.get_graph_module()
+        graph_after_passes = RemoveCloneOpPass()(original).graph_module
+        self.assertEqual(
+            count_node(graph_after_passes, torch.ops.aten.clone.default), 0
+        )

     def test_remove_contiguous(self):
-        class Contiguous(torch.nn.Module):
-            def forward(self, x, y):
-                t1 = x.contiguous()
-                t2 = y.contiguous()
-                return t1 + t2
-
-        x = torch.ones(3, 5)
-        y = torch.ones(3, 5)
-        graph_module = (
-            export_to_edge(Contiguous(), (x, y)).exported_program().graph_module
-        )
-        new_graph_module = RemoveContiguousOpPass()(graph_module).graph_module
-        new_graph_module.graph.eliminate_dead_code()
-        # Assert that t1 and t2 are optimized away
-        self.assertEqual(count_node(new_graph_module, torch.ops.aten.contiguous.out), 0)
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32))
+        contiguous = builder.call_operator(
+            op=exir_ops.edge.aten.contiguous.default, args=(x,)
+        )
+        builder.output([contiguous])
+        original = builder.get_graph_module()
+        graph_after_passes = RemoveContiguousOpPass()(original).graph_module
+        self.assertEqual(
+            count_node(graph_after_passes, torch.ops.aten.contiguous.default), 0
+        )

     @parameterized.expand(
         [
@@ -274,119 +266,129 @@ def forward(self, x, y):
     )
     @torch.no_grad()
     def test_remove_nop_view(self, shape, new_shape):
-        class View(torch.nn.Module):
-            def __init__(self, new_shape):
-                super().__init__()
-                self.new_shape = new_shape
-
-            def forward(self, x: torch.Tensor):
-                return x.view(self.new_shape)
-
-        model = View(new_shape)
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-        p = RemoveNopSliceOrViewOpPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-        graph_after_passes.graph.eliminate_dead_code()
-        # Assert that view op was removed
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default, args=(x, new_shape)
+        )
+        builder.output([view])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSliceOrViewOpPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0
         )

     def test_remove_nop_slice(self):
-        class Slice(torch.nn.Module):
-            def forward(self, x):
-                return torch.slice_copy(x, dim=0, start=0, step=1)
-
-        x = torch.ones(3, 5)
-        model = Slice()
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-        p = RemoveNopSliceOrViewOpPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-        graph_after_passes.graph.eliminate_dead_code()
-        # Assert that slice op was removed
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
+        slice_ = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(
+                x,
+                0,  # dim
+                0,  # start
+                3,  # end
+            ),
+        )
+        builder.output([slice_])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSliceOrViewOpPass()(original)
+        ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
         )

-    def test_remove_nop_select(self):
-        class SelectFeasible1(torch.nn.Module):
-            def forward(self, x):
-                y = x.select(0, 0)
-                z = y.view([1, 5, 6])
-                return z
-
-        x = torch.ones(1, 5, 6)
-        graph_module = (
-            export_to_edge(SelectFeasible1(), (x,)).exported_program().graph_module
+    def test_remove_nop_select_before_view(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
+        select = builder.call_operator(
+            op=exir_ops.edge.aten.select_copy.int,
+            args=(
+                x,
+                0,  # dim
+                0,  # index
+            ),
         )
-        self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
+        view = builder.call_operator(
+            op=exir_ops.edge.aten.view_copy.default,
+            args=(select, [1, 5, 6]),  # new shape
         )
-        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
-        # Assert that select op was removed
+        builder.output([view])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSelectOpPass()(original)
+        ).graph_module
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
+            count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
         )

-        class SelectFeasible2(torch.nn.Module):
-            def forward(self, x, y):
-                x = x.select(0, 0)
-                z = x + y
-                return z
-
-        x = torch.ones(1, 5, 6)
-        y = torch.ones(1, 5, 6)
-        graph_module = (
-            export_to_edge(SelectFeasible2(), (x, y)).exported_program().graph_module
-        )
-        self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
+    def test_remove_nop_select_before_add(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
+        select = builder.call_operator(
+            op=exir_ops.edge.aten.select_copy.int,
+            args=(
+                x,
+                0,  # dim
+                0,  # index
+            ),
         )
-        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
-        # Assert that select op was removed
+        add = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(select, y))
+        builder.output([add])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSelectOpPass()(original)
+        ).graph_module
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
+            count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
         )

-        class SelectFeasible3(torch.nn.Module):
-            def forward(self, x, y):
-                x = x.select(0, 0)
-                z = x * y
-                return z
-
-        x = torch.ones(1, 5, 6)
-        y = torch.ones(1, 5, 6)
-        graph_module = (
-            export_to_edge(SelectFeasible3(), (x, y)).exported_program().graph_module
-        )
-        self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
+    def test_remove_nop_select_before_mul(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
+        select = builder.call_operator(
+            op=exir_ops.edge.aten.select_copy.int,
+            args=(
+                x,
+                0,  # dim
+                0,  # index
+            ),
        )
-        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
-        # Assert that select op was removed
+        mul = builder.call_operator(op=exir_ops.edge.aten.mul.Tensor, args=(select, y))
+        builder.output([mul])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSelectOpPass()(original)
+        ).graph_module
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
+            count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
         )

-        class SelectFeasible4(torch.nn.Module):
-            def forward(self, x, y):
-                x = x.select(0, 0)
-                z = x / y
-                return z
-
-        x = torch.ones(1, 5, 6)
-        y = torch.ones(1, 5, 6)
-        graph_module = (
-            export_to_edge(SelectFeasible4(), (x, y)).exported_program().graph_module
-        )
-        self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
+    def test_remove_nop_select_before_div(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
+        select = builder.call_operator(
+            op=exir_ops.edge.aten.select_copy.int,
+            args=(
+                x,
+                0,  # dim
+                0,  # index
+            ),
         )
-        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
-        # Assert that select op was removed
+        div = builder.call_operator(op=exir_ops.edge.aten.div.Tensor, args=(select, y))
+        builder.output([div])
+        original = builder.get_graph_module()
+        graph_after_passes = cast(
+            PassResult, RemoveNopSelectOpPass()(original)
+        ).graph_module
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
+            count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
         )

     def test_remove_nop_quant_dequant(self):