Use GraphBuilder in test_replace_ops_passes. #1 #11344


Merged: 1 commit, Jun 6, 2025

Changes from all commits
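
This change replaces test setups that built a small `nn.Module`, exported it with `export_to_edge` or `quantize_and_export_to_edge`, and then ran a replacement pass, with setups that construct the edge-dialect graph directly via `GraphBuilder`. A minimal sketch of the new pattern, assuming `GraphBuilder` is importable from `executorch.backends.cadence.aot.graph_builder` (an assumed path; the builder calls themselves mirror the diff below):

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder  # assumed path
from executorch.exir.dialects._ops import ops as exir_ops

# Build a one-op edge-dialect graph directly; no nn.Module and no export step.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
pad = builder.call_operator(
    op=exir_ops.edge.aten.constant_pad_nd.default,
    args=(x, [0, 0, 0, 0]),
)
builder.output([pad])
graph_module = builder.get_graph_module()
# A replacement pass can now run on graph_module, and the tests can
# assert on operator counts in the transformed graph.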
265 changes: 118 additions & 147 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
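
Many of the rewritten tests use the shorter `single_op_builder(placeholders=..., op=..., args=...)` form. Its implementation is not shown in this diff; the sketch below is a hypothetical equivalent inferred from how the tests call it (the identity-based substitution of placeholder tensors into `args` is an assumption):

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder  # assumed path

def single_op_builder_sketch(placeholders, op, args):
    # Hypothetical helper: each tensor in `placeholders` becomes a graph
    # input, `op` is called once, and its result is the sole graph output.
    builder = GraphBuilder()
    ph = {id(t): builder.placeholder(f"p_{i}", t) for i, t in enumerate(placeholders)}
    # Tensor args that are also placeholders are swapped for their placeholder
    # nodes (matched by object identity); scalars and None pass through as-is.
    node = builder.call_operator(op=op, args=tuple(ph.get(id(a), a) for a in args))
    builder.output([node])
    return builder.get_graph_module()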
@@ -57,7 +57,6 @@
 from executorch.exir.passes import dead_code_elimination_pass

 from parameterized.parameterized import parameterized
-from torch._ops import OpOverload
 from torch.fx.passes.infra.pass_base import PassResult

@@ -87,36 +86,46 @@ def assertTargetCountsEqual(

     @parameterized.expand(
         [
-            # Regular MM
-            [(64, 33), (33, 128)],
-            # Batched MM
-            [(2, 48, 48), (2, 48, 48)],
-        ]
+            (
+                "regular",
+                (64, 33),  # x_shape
+                (33, 128),  # y_shape
+            ),
+            (
+                "batched",
+                (2, 48, 48),  # x_shape
+                (2, 48, 48),  # y_shape
+            ),
+        ],
     )
     @torch.no_grad()
     def test_replace_matmul_with_transposed_matmul(
         self,
+        _,
         x_shape: Tuple[int],
         y_shape: Tuple[int],
     ) -> None:
-        class MatMul(torch.nn.Module):
-            def __init__(self) -> None:
-                super(MatMul, self).__init__()
-
-            def forward(self, x, y):
-                return torch.matmul(x, y)
-
-        model = MatMul()
-        X = torch.randn(x_shape)
-        Y = torch.randn(y_shape)
-        p = ReplaceMatmulWithTransposedMatmulPass()
-        inputs = (X, Y)
-        graph_module = (
-            quantize_and_export_to_edge(model, inputs).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*x_shape, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(*y_shape, dtype=torch.float32))
+        matmul = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_matmul.default,
+            args=(
+                x,
+                0,  # X_zero_point
+                y,
+                0,  # Y_zero_point
+                None,  # bias
+                1,  # out_multiplier
+                0,  # out_shift
+                0,  # out_zero_point
+                False,  # transposed=False
+            ),
         )
-        # pyre-fixme[16]: Optional type has no attribute `graph_module`
-        graph_after_passes = p(graph_module).graph_module
-
+        builder.output([matmul])
+        original = builder.get_graph_module()
+        p = ReplaceMatmulWithTransposedMatmulPass()
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
             1,
@@ -130,33 +139,24 @@ def forward(self, x, y):

     @parameterized.expand(
         [
-            [(3, 5), (0, 0)],
-            [
-                (20, 1, 80),
-                (0, 0),
-            ],
-        ]
+            ("2d", (3, 5), [0, 0]),  # (shape, padding)
+            ("3d", (20, 1, 80), [0, 0, 0]),  # (shape, padding)
+        ],
     )
     @torch.no_grad()
     def test_replace_constant_pad_nd_with_slice(
-        self, shape: Tuple[int], padding: Tuple[int]
+        self, _, shape: Tuple[int], padding: Tuple[int]
     ):
-        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
-        class Padding(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.padding = padding
-
-            def forward(self, x: torch.Tensor):
-                return F.pad(x, self.padding)
-
-        model = Padding()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        pad = builder.call_operator(
+            op=exir_ops.edge.aten.constant_pad_nd.default,
+            args=(x, [0, 0, 0, 0]),
+        )
+        builder.output([pad])
+        original = builder.get_graph_module()
         p = ReplaceConstantPadNdWithSlicePass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.slice.Tensor),
             1,
@@ -169,142 +169,140 @@ def forward(self, x: torch.Tensor):

     @parameterized.expand(
         [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
         ]
     )
     @torch.no_grad()
-    def test_add_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
-        class Add(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.add.Scalar(x, other)
-
-        model = Add()
+    def test_add_replace_scalar_with_tensor_arg(
+        self, _, shape: Tuple[int], other: float
+    ):
         x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.add.Scalar,
+            args=(x, other),
+        )
         p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
             1,
         )

         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.add.Scalar),
             0,
         )

     @parameterized.expand(
         [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
-            [(10), 42949],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
         ]
     )
     @torch.no_grad()
-    def test_sub_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
-        class Sub(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.sub.Scalar(x, other)
-
-        model = Sub()
+    def test_sub_replace_scalar_with_tensor_arg(
+        self, _, shape: Tuple[int], other: float
+    ):
         x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.sub.Scalar,
+            args=(x, other),
+        )
         p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.sub.Tensor),
             1,
         )

         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.sub.Scalar),
             0,
         )

     @parameterized.expand(
         [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
-            [(513), 3],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
         ]
     )
     @torch.no_grad()
-    def test_mul_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
-        class Mul(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.mul.Scalar(x, other)
-
-        model = Mul()
+    def test_mul_replace_scalar_with_tensor_arg(
+        self, _, shape: Tuple[int], other: float
+    ):
         x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.mul.Scalar,
+            args=(x, other),
+        )
         p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
             1,
         )

         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.mul.Scalar),
             0,
         )

     @parameterized.expand(
         [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
         ]
     )
     @torch.no_grad()
     def test_div_replace_scalar_with_tensor_arg(
         self,
+        _,
         shape: Tuple[int],
         other: float,
     ):
-        class Div(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.div.Scalar(x, other)
-
-        model = Div()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        x = torch.randn(*shape)
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.div.Scalar,
+            args=(x, other),
+        )
         p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.div.Tensor),
             1,
         )

         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.div.Scalar),
             0,
         )

     @parameterized.expand(
         [
-            [(2, 3, 5, 6)],
-            [(7, 6, 5)],
-            [(4, 4)],
-            [(316)],
+            ["4d", (2, 3, 5, 6)],
+            ["3d", (7, 6, 5)],
+            ["2d", (4, 4)],
+            ["1d", (316)],
         ]
     )
     @torch.no_grad()
-    def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]):
-        model = torch.nn.ReLU()
+    def test_replace_functionally_equivalent_op_targets_relu(
+        self, _, shape: Tuple[int]
+    ):
         x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.relu_.default,
+            args=(x,),
+        )
         p = ReplaceFunctionallyEquivalentOpTargets()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.relu.default),
             1,
@@ -315,56 +313,29 @@ def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]
         )

     @parameterized.expand(
-        [
-            # split the only dimension
-            [(50,), i, 0]
-            for i in range(2, 7)
-        ]
-        + [
-            # split the leading dim
-            [(10, 2, 3), i, 0]
-            for i in range(2, 7)
-        ]
-        + [
-            # split the trailing dim
-            [(3, 3, 6), i, 2]
-            for i in range(2, 6)
-        ]
-        + [
-            # split the dim in the middle
-            [(3, 5, 14, 2, 3), i, 2]
-            for i in range(2, 7)
-        ]
+        [["split_linear_tensor", (50,), i, 0] for i in range(2, 7)]
+        + [["split_leading_dim", (10, 2, 3), i, 0] for i in range(2, 7)]
+        + [["split_trailing_dim", (3, 3, 6), i, 2] for i in range(2, 6)]
+        + [["split_middle_dim", (3, 5, 14, 2, 3), i, 2] for i in range(2, 7)]
     )
     @torch.no_grad()
     def test_replace_functionally_equivalent_op_targets_unsafe_split(
-        self, shape: Tuple[int], split_size: int, dim: int
+        self, _, shape: Tuple[int], split_size: int, dim: int
     ):
-        class TensorSplitWithSizes(torch.nn.Module):
-            def __init__(self, split_size: int, dim: int, op: OpOverload):
-                super().__init__()
-                self.split_size = split_size
-                self.dim = dim
-                self.op = op
-
-            def forward(self, x: torch.Tensor):
-                return self.op(x, self.split_size, self.dim)
-
         x = torch.randn(shape)
-        model = TensorSplitWithSizes(split_size, dim, torch.unsafe_split)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.unsafe_split.Tensor,
+            args=(x, split_size, dim),
+        )
         p = ReplaceFunctionallyEquivalentOpTargets()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
-            count_node(
-                graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default
-            ),
+            count_node(graph_after_passes, exir_ops.edge.aten.split_copy.Tensor),
             1,
         )
         self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor),
-            0,
+            count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0
         )

     @parameterized.expand(