 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
 
-from typing import cast
+from typing import cast, TypeAlias
 
 import torch.fx
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
+Slices: TypeAlias = list[tuple[int, int, int]]
 
-def conv_remainder(input_length, pad, dilation, weight, stride):
+conv2d_op = exir_ops.edge.aten.convolution.default
+max_pooling_op = exir_ops.edge.aten.max_pool2d.default
+avg_pooling_op = exir_ops.edge.aten.avg_pool2d.default
+slice_op = exir_ops.edge.aten.slice_copy.Tensor
+
+valid_operators = [conv2d_op, max_pooling_op, avg_pooling_op]
+
+
+def conv_remainder(input_length, pad, dilation, weight, stride) -> int:
     """
     Returns the remainder of input_length given the padding, dilation,
     stride, and kernel size.
     """
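+    # Example: input_length=8, pad=0, dilation=1, weight=3, stride=2 gives
+    # (8 + 0 - 1 * (3 - 1) - 1) % 2 = 1, i.e. one trailing row/column that
+    # PyTorch would silently truncate.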
     return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride
 
 
-class SizeAdjustConv2DPass(ExportPass):
+def pooling_remainder(input_size, pad, kernel_size, stride) -> int:
+    """
+    Returns the remainder of input_size given the padding, stride, and
+    kernel size.
+    """
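+    # Example: input_size=7, pad=0, kernel_size=2, stride=2 gives
+    # (7 + 0 - 2) % 2 = 1, so one trailing row/column does not fit the
+    # kernel/stride combination and must be sliced away.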
+    return (input_size + 2 * pad - kernel_size) % stride
+
+
+def get_slices_conv2d(conv_node: torch.fx.Node) -> Slices:
+    slices = []
+
+    input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = conv_node.args
+    weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
+    input_shape = cast(torch.fx.Node, input_node).meta["val"].shape
+
+    for stride, pad, dilation, dim in zip(
+        cast(list, stride_hw),
+        cast(list, pad_hw),
+        cast(list, dilation_hw),
+        (2, 3),
+    ):
+        remainder = conv_remainder(
+            input_shape[dim], pad, dilation, weight_shape[dim], stride
+        )
+        if remainder > pad:
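+            # slice_copy args are (dim, start, end): keep everything up to
+            # the excess trailing rows/columns that the kernel never reaches.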
+            adjustment = remainder - pad
+            args = (dim, 0, input_shape[dim] - adjustment)
+            slices.append(args)
+
+    return slices
+
+
+def get_slices_pooling(pooling_node: torch.fx.Node) -> Slices:
+    slices = []
+
+    input_node = pooling_node.args[0]
+    kernel_size = pooling_node.args[1]
+    stride = pooling_node.args[2]
+    padding = pooling_node.args[3] if len(pooling_node.args) >= 4 else [0, 0]
+
+    # For the loop below, padding must be a list
+    if isinstance(padding, int):
+        padding = [padding, padding]
+
+    input_shape = cast(torch.fx.Node, input_node).meta["val"].shape
+
+    for kernel_length, stride_length, pad_size, dim in zip(
+        cast(list, kernel_size),
+        cast(list, stride),
+        cast(list, padding),
+        (2, 3),
+    ):
+        remainder = pooling_remainder(
+            input_shape[dim], pad_size, kernel_length, stride_length
+        )
+        if remainder > pad_size:
+            adjustment = remainder - pad_size
+            args = (dim, 0, input_shape[dim] - adjustment)
+            slices.append(args)
+
+    return slices
+
+
+def get_slices(node: torch.fx.Node) -> Slices:
+    """
+    Returns the slice arguments needed to adjust the input of the given node.
+    """
+    if node.target == conv2d_op:
+        return get_slices_conv2d(node)
+    elif node.target == max_pooling_op or node.target == avg_pooling_op:
+        return get_slices_pooling(node)
+    else:
+        raise ValueError(f"Unsupported node target, expected one of {valid_operators}")
+
+
+def is_valid_operator(node: torch.fx.Node) -> bool:
+    if node.target == conv2d_op:
+        return True
+    elif node.target == max_pooling_op:
+        dilation = node.args[4] if len(node.args) >= 5 else 1
+        ceil_mode = node.args[5] if len(node.args) >= 6 else False
+
+        # Dilation should be handled first by DecomposeMaxPool2DPass
+        if isinstance(dilation, int):
+            if dilation > 1:
+                raise ValueError(
+                    "Expected max_pool2d with dilation = 1, has DecomposeMaxPool2DPass been run?"
+                )
+        else:
+            dilation = cast(list, dilation)
+            if dilation[0] > 1 or dilation[1] > 1:
+                raise ValueError(
+                    "Expected max_pool2d with dilation = [1, 1], has DecomposeMaxPool2DPass been run?"
+                )
+
+        # If using ceil mode for rounding, the input does not need adjusting
+        return not ceil_mode
+    elif node.target == avg_pooling_op:
+        ceil_mode = node.args[4] if len(node.args) >= 5 else False
+        count_include_pad = node.args[5] if len(node.args) >= 6 else True
+        divisor_override = node.args[6] if len(node.args) >= 7 else None
+
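+        # Non-default values change how edge elements are counted in the
+        # average, so slicing the input would not be a semantics-preserving
+        # rewrite in those cases.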
+        return not ceil_mode and not count_include_pad and divisor_override is None
+
+    return False
+
+
+class SizeAdjustInputPass(ExportPass):
     """
-    Adjust the convolution input size to match the kernel size, padding, stride,
-    and dilation parameters. Pytorch allows the input and kernel shape to not
-    "match", in which case the remaining rows/columns are truncated. However,
-    matching the size is a requirement in the TOSA specification. In case the
-    input and kernel shape do not match, the following is done to meet the
-    specification:
+    Adjusts the input size to Conv2D and Pooling operators. PyTorch allows
+    the input and kernel shape to not "match", in which case the remaining
+    rows/columns are truncated. However, matching the size is a requirement
+    in the TOSA specification. In case the input and kernel shape do not
+    match, the following is performed to meet the specification:
 
     1) The padding is truncated (done in the node visitor)
     2) (if necessary) The input is truncated (done in this pass).
@@ -71,52 +185,33 @@ class SizeAdjustConv2DPass(ExportPass):
     input.
     """
 
-    conv2d_op = exir_ops.edge.aten.convolution.default
-    slice_op = exir_ops.edge.aten.slice_copy.Tensor
-
-    def call(self, graph_module: torch.fx.GraphModule):
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph = graph_module.graph
         modified_graph = False
         for node in graph.nodes:
             if node.op != "call_function":
                 continue
-            if node.target != self.conv2d_op:
+            if not is_valid_operator(node):
                 continue
 
-            conv_node = cast(torch.fx.Node, node)
-            input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = (
-                conv_node.args
-            )
-            weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
-            input_shape = cast(torch.fx.Node, input_node).meta["val"].shape
-
-            slice_args = []
-            for stride, pad, dilation, dim in zip(
-                cast(list, stride_hw),
-                cast(list, pad_hw),
-                cast(list, dilation_hw),
-                (2, 3),
-            ):
-                remainder = conv_remainder(
-                    input_shape[dim], pad, dilation, weight_shape[dim], stride
-                )
-                if remainder > pad:
-                    adjustment = remainder - pad
-                    args = (dim, 0, input_shape[dim] - adjustment)
-                    slice_args.append(args)
+            target_node = cast(torch.fx.Node, node)
+            slice_args = get_slices(target_node)
+
             if len(slice_args) == 0:
                 continue
 
+            parent_node = node.args[0]
             with graph_module.graph.inserting_before(node):
-                last_node = cast(torch.fx.Node, input_node)
+                last_node = cast(torch.fx.Node, parent_node)
                 for args in slice_args:
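+                    # Insert one slice_copy per mismatched dimension; the
+                    # nodes are chained so the last one carries the final,
+                    # TOSA-compatible shape.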
-                    slice_node = create_node(graph, self.slice_op, (last_node,) + args)
+                    slice_node = create_node(graph, slice_op, (last_node,) + args)
                     last_node = slice_node
-                conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
+                node.replace_input_with(cast(torch.fx.Node, parent_node), last_node)
                 modified_graph = True
 
         if modified_graph:
             graph_module = super().call(graph_module).graph_module
             graph.eliminate_dead_code()
             graph_module.recompile()
+
         return PassResult(graph_module, True)
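
To make the slice computation concrete, here is a minimal standalone sketch (not part of the patch; slice_args_for_conv is an illustrative helper) that mirrors conv_remainder and get_slices_conv2d on plain Python values:

def conv_remainder(input_length, pad, dilation, weight, stride):
    return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride

def slice_args_for_conv(input_hw, kernel_hw, stride_hw, pad_hw, dilation_hw):
    # One (dim, start, end) triple per mismatched spatial dimension;
    # dims 2 and 3 are H and W in NCHW layout.
    slices = []
    for dim, (size, k, s, p, d) in enumerate(
        zip(input_hw, kernel_hw, stride_hw, pad_hw, dilation_hw), start=2
    ):
        remainder = conv_remainder(size, p, d, k, s)
        if remainder > p:
            slices.append((dim, 0, size - (remainder - p)))
    return slices

# An 8x8 input with a 3x3 kernel, stride 2 and no padding leaves one extra
# row and one extra column, so both spatial dims are sliced from 8 to 7:
print(slice_args_for_conv((8, 8), (3, 3), (2, 2), (0, 0), (1, 1)))
# -> [(2, 0, 7), (3, 0, 7)]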