
Commit 6f44a79

Arm backend: Adjust pooling input when not divisible by stride (#11854)
* Rename SizeAdjustConv2DPass to SizeAdjustInputPass
* Add support for adjusting input size for pooling operators
1 parent 7ede6be commit 6f44a79

File tree

6 files changed: +182 -47 lines changed

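The adjustment this commit implements boils down to a remainder check. Below is a minimal sketch of the arithmetic, mirroring the pooling_remainder/get_slices logic added in the pass file (the concrete shapes are taken from the new test cases):

def pooling_remainder(input_size: int, pad: int, kernel_size: int, stride: int) -> int:
    # Rows/columns left over after the last full pooling window
    return (input_size + 2 * pad - kernel_size) % stride

# kernel 3, stride 2, pad 1 on a 112x112 input:
# remainder 1 <= pad 1, so adjusting the padding alone suffices.
assert pooling_remainder(112, 1, 3, 2) == 1

# kernel 3, stride 3, pad 1 on a 54x54 input:
# remainder 2 > pad 1, so the pass slices the input to 54 - (2 - 1) = 53.
assert pooling_remainder(54, 1, 3, 3) == 2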

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@
     ReplaceScalarWithTensorArgPassTOSAMI,
 )
 from .scalars_to_attribute_pass import ScalarsToAttributePass  # noqa
-from .size_adjust_conv2d_pass import SizeAdjustConv2DPass  # noqa
+from .size_adjust_input_pass import SizeAdjustInputPass  # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass  # noqa
 from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass  # noqa
 from .replace_inf_values_pass import ReplaceInfValues  # noqa # usort: skip

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
@@ -64,7 +64,7 @@
     ReplaceScalarWithTensorArgPassTOSAMI,
     RetraceFoldedDtypesPass,
     ScalarsToAttributePass,
-    SizeAdjustConv2DPass,
+    SizeAdjustInputPass,
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
@@ -125,13 +125,13 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule

         self.add_pass(DecomposeGroupedConv())
         self.add_pass(RemoveClonePass())
-        self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
+        self.add_pass(SizeAdjustInputPass())
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())

@@ -187,13 +187,13 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule

         self.add_pass(DecomposeGroupedConv())
         self.add_pass(RemoveClonePass())
-        self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
+        self.add_pass(SizeAdjustInputPass())
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
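The pass also moves later in both pipelines so that it runs after DecomposeMaxPool2DPass, which is expected to have eliminated dilation before is_valid_operator (below) sees a max_pool2d node. For intuition on why the pass is needed at all, here is a small eager-mode sketch of PyTorch's silent floor rounding (shapes are illustrative, not taken from the pipelines):

import torch

# With ceil_mode=False, PyTorch floors the output size when
# (H + 2*pad - kernel) is not divisible by the stride; TOSA has no
# such implicit truncation, so the graph must be adjusted explicitly.
pool = torch.nn.MaxPool2d(kernel_size=3, stride=3, padding=1)
x = torch.rand(1, 16, 54, 54)
print(pool(x).shape)  # torch.Size([1, 16, 18, 18]), i.e. floor((54+2-3)/3)+1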

backends/arm/_passes/{size_adjust_conv2d_pass.py → size_adjust_input_pass.py}

Lines changed: 134 additions & 39 deletions
@@ -1,35 +1,149 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe

-from typing import cast
+from typing import cast, TypeAlias

 import torch.fx
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

+Slices: TypeAlias = list[tuple[int, int, int]]

-def conv_remainder(input_length, pad, dilation, weight, stride):
+conv2d_op = exir_ops.edge.aten.convolution.default
+max_pooling_op = exir_ops.edge.aten.max_pool2d.default
+avg_pooling_op = exir_ops.edge.aten.avg_pool2d.default
+slice_op = exir_ops.edge.aten.slice_copy.Tensor
+
+valid_operators = [conv2d_op, max_pooling_op, avg_pooling_op]
+
+
+def conv_remainder(input_length, pad, dilation, weight, stride) -> int:
     """
     Returns the remainder of input_length, given the padding, dilation, stride,
     and kernel size.
     """
     return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride


-class SizeAdjustConv2DPass(ExportPass):
+def pooling_remainder(input_size, pad, kernel_size, stride) -> int:
+    """
+    Returns the remainder of input_size, given the padding, stride, and
+    kernel size.
+    """
+    return (input_size + 2 * pad - kernel_size) % stride
+
+
+def get_slices_conv2d(conv_node: torch.fx.Node) -> Slices:
+    slices = []
+
+    input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = conv_node.args
+    weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
+    input_shape = cast(torch.fx.Node, input_node).meta["val"].shape
+
+    for stride, pad, dilation, dim in zip(
+        cast(list, stride_hw),
+        cast(list, pad_hw),
+        cast(list, dilation_hw),
+        (2, 3),
+    ):
+        remainder = conv_remainder(
+            input_shape[dim], pad, dilation, weight_shape[dim], stride
+        )
+        if remainder > pad:
+            adjustment = remainder - pad
+            args = (dim, 0, input_shape[dim] - adjustment)
+            slices.append(args)
+
+    return slices
+
+
+def get_slices_pooling(pooling_node: torch.fx.Node) -> Slices:
+    slices = []
+
+    input_node = pooling_node.args[0]
+    kernel_size = pooling_node.args[1]
+    stride = pooling_node.args[2]
+    padding = pooling_node.args[3] if len(pooling_node.args) >= 4 else [0, 0]
+
+    # For the loop below, padding must be a list
+    if isinstance(padding, int):
+        padding = [padding, padding]
+
+    input_shape = cast(torch.fx.Node, input_node).meta["val"].shape
+
+    for kernel_length, stride_length, pad_size, dim in zip(
+        cast(list, kernel_size),
+        cast(list, stride),
+        cast(list, padding),
+        (2, 3),
+    ):
+        remainder = pooling_remainder(
+            input_shape[dim], pad_size, kernel_length, stride_length
+        )
+        if remainder > pad_size:
+            adjustment = remainder - pad_size
+            args = (dim, 0, input_shape[dim] - adjustment)
+            slices.append(args)
+
+    return slices
+
+
+def get_slices(node: torch.fx.Node) -> Slices:
+    """
+    Returns the slice arguments for the given graph node.
+    """
+    if node.target == conv2d_op:
+        return get_slices_conv2d(node)
+    elif node.target == max_pooling_op or node.target == avg_pooling_op:
+        return get_slices_pooling(node)
+    else:
+        raise ValueError(f"Unsupported node target, was expecting {valid_operators}")
+
+
+def is_valid_operator(node: torch.fx.Node) -> bool:
+    if node.target == conv2d_op:
+        return True
+    elif node.target == max_pooling_op:
+        dilation = node.args[4] if len(node.args) >= 5 else 1
+        ceil_mode = node.args[5] if len(node.args) >= 6 else False
+
+        # Dilation should be handled first by DecomposeMaxPool2DPass
+        if isinstance(dilation, int):
+            if dilation > 1:
+                raise ValueError(
+                    "Expected max_pool2d with dilation = 1, has DecomposeMaxPool2DPass been run?"
+                )
+        else:
+            dilation = cast(list, dilation)
+            if dilation[0] > 1 or dilation[1] > 1:
+                raise ValueError(
+                    "Expected max_pool2d with dilation = [1, 1], has DecomposeMaxPool2DPass been run?"
+                )
+
+        # If using ceil mode for rounding, the input does not need adjusting
+        return not ceil_mode
+    elif node.target == avg_pooling_op:
+        ceil_mode = node.args[4] if len(node.args) >= 5 else False
+        count_include_pad = node.args[5] if len(node.args) >= 6 else True
+        divisor_override = node.args[6] if len(node.args) >= 7 else None
+
+        return not ceil_mode and not count_include_pad and divisor_override is None
+
+    return False
+
+
+class SizeAdjustInputPass(ExportPass):
     """
-    Adjust the convolution input size to match the kernel size, padding, stride,
-    and dilation parameters. Pytorch allows the input and kernel shape to not
-    "match", in which case the remaining rows/columns are truncated. However,
-    matching the size is a requirement in the TOSA specification. In case the
-    input and kernel shape do not match, the following is done to meet the
-    specification:
+    Adjusts the input size to Conv2D and Pooling operators. PyTorch allows
+    the input and kernel shape to not "match", in which case the remaining
+    rows/columns are truncated. However, matching the size is a requirement
+    in the TOSA specification. In case the input and kernel shape do not
+    match, the following is performed to meet the specification:

     1) The padding is truncated (done in the node visitor)
     2) (if necessary) The input is truncated (done in this pass).
@@ -71,52 +185,33 @@ class SizeAdjustConv2DPass(ExportPass):
     input.
     """

-    conv2d_op = exir_ops.edge.aten.convolution.default
-    slice_op = exir_ops.edge.aten.slice_copy.Tensor
-
-    def call(self, graph_module: torch.fx.GraphModule):
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph = graph_module.graph
         modified_graph = False
         for node in graph.nodes:
             if node.op != "call_function":
                 continue
-            if node.target != self.conv2d_op:
+            if not is_valid_operator(node):
                 continue

-            conv_node = cast(torch.fx.Node, node)
-            input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = (
-                conv_node.args
-            )
-            weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
-            input_shape = cast(torch.fx.Node, input_node).meta["val"].shape
-
-            slice_args = []
-            for stride, pad, dilation, dim in zip(
-                cast(list, stride_hw),
-                cast(list, pad_hw),
-                cast(list, dilation_hw),
-                (2, 3),
-            ):
-                remainder = conv_remainder(
-                    input_shape[dim], pad, dilation, weight_shape[dim], stride
-                )
-                if remainder > pad:
-                    adjustment = remainder - pad
-                    args = (dim, 0, input_shape[dim] - adjustment)
-                    slice_args.append(args)
+            target_node = cast(torch.fx.Node, node)
+            slice_args = get_slices(target_node)
+
             if len(slice_args) == 0:
                 continue

+            parent_node = node.args[0]
             with graph_module.graph.inserting_before(node):
-                last_node = cast(torch.fx.Node, input_node)
+                last_node = cast(torch.fx.Node, parent_node)
                 for args in slice_args:
-                    slice_node = create_node(graph, self.slice_op, (last_node,) + args)
+                    slice_node = create_node(graph, slice_op, (last_node,) + args)
                     last_node = slice_node
-                conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
+                node.replace_input_with(cast(torch.fx.Node, parent_node), last_node)
                 modified_graph = True

         if modified_graph:
             graph_module = super().call(graph_module).graph_module
             graph.eliminate_dead_code()
             graph_module.recompile()
+
         return PassResult(graph_module, True)
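The slice nodes the pass inserts are equivalent to the following eager-mode trim (a sketch using torch.slice_copy in place of the edge-dialect slice_copy.Tensor node; the shape comes from the new test cases):

import torch

x = torch.rand(1, 16, 54, 54)

# For kernel 3, stride 3, pad 1, get_slices_pooling yields (dim, 0, 53)
# for dims 2 and 3; each tuple becomes one slice node before the operator.
trimmed = torch.slice_copy(x, dim=2, start=0, end=53)
trimmed = torch.slice_copy(trimmed, dim=3, start=0, end=53)
print(trimmed.shape)  # torch.Size([1, 16, 53, 53])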

backends/arm/test/models/test_nn_functional.py

Lines changed: 0 additions & 1 deletion
@@ -82,7 +82,6 @@ def forward(self, *args):
     "test_data",
     module_tests,
     xfails={
-        "max_pool1d": "ValueError: Invalid TOSA graph",
         "affine_grid": "Int64 input. Partition handling fails since arange int64 output is split between 2 partitions.",
     },
 )

backends/arm/test/ops/test_avg_pool2d.py

Lines changed: 19 additions & 3 deletions
@@ -51,15 +51,15 @@ def forward(self, *args, **kwargs):
         AvgPool2d((4, 6), (1, 2), (2, 3)),
         (torch.rand(1, 16, 50, 32),),
     ),
-    "non_divisible_window": lambda: (
+    "non_divisible_window_adjust_padding": lambda: (
         AvgPool2d(3, 2, 1, count_include_pad=False),
         (torch.rand(1, 16, 112, 112),),
     ),
-    "non_divisible_window_height": lambda: (
+    "non_divisible_window_adjust_padding_height": lambda: (
         AvgPool2d(3, (2, 1), 1),
         (torch.rand(1, 16, 56, 56),),
     ),
-    "non_divisible_window_width": lambda: (
+    "non_divisible_window_adjust_padding_width": lambda: (
         AvgPool2d(3, (1, 2), 1, count_include_pad=False),
         (torch.rand(1, 16, 56, 56),),
     ),
@@ -91,6 +91,22 @@ def forward(self, *args, **kwargs):
         AvgPool2d(3, 2, 1, True, True, divisor_override=2),
         (torch.rand(1, 1, 14, 14),),
     ),
+    "non_divisible_no_padding": lambda: (
+        AvgPool2d(3, 2, 0),
+        (torch.rand(1, 16, 56, 56),),
+    ),
+    "non_divisible_window_adjust_padding+input": lambda: (
+        AvgPool2d(3, 3, 1, count_include_pad=False),
+        (torch.rand(1, 16, 54, 54),),
+    ),
+    "non_divisible_window_height_adjust_padding+input": lambda: (
+        AvgPool2d(3, (3, 1), 1),
+        (torch.rand(1, 16, 54, 54),),
+    ),
+    "non_divisible_window_width_adjust_padding+input": lambda: (
+        AvgPool2d(3, (1, 3), 1, count_include_pad=False),
+        (torch.rand(1, 16, 54, 54),),
+    ),
 }
backends/arm/test/ops/test_max_pool.py

Lines changed: 25 additions & 0 deletions
@@ -39,6 +39,31 @@
         torch.rand(1, 16, 56, 56),
         [3, (1, 2), 1, 1, True],
     ),
+    "non_divisible_window_adjust_padding": lambda: (
+        torch.rand(1, 16, 112, 112),
+        [3, 2, 1],
+    ),
+    "non_divisible_window_height_adjust_padding": lambda: (
+        torch.rand(1, 16, 56, 56),
+        [3, (2, 1), 1],
+    ),
+    "non_divisible_window_width_adjust_padding": lambda: (
+        torch.rand(1, 16, 56, 56),
+        [3, (1, 2), 1],
+    ),
+    "non_divisible_no_padding": lambda: (torch.rand(1, 16, 56, 56), [3, 2, 0]),
+    "non_divisible_window_adjust_padding+input": lambda: (
+        torch.rand(1, 16, 54, 54),
+        [3, 3, 1],
+    ),
+    "non_divisible_window_height_adjust_padding+input": lambda: (
+        torch.rand(1, 16, 54, 54),
+        [3, (3, 1), 1],
+    ),
+    "non_divisible_window_width_adjust_padding+input": lambda: (
+        torch.rand(1, 16, 54, 54),
+        [3, (1, 3), 1],
+    ),
 }

 test_data_suite_mult_batches = {
