
Pass to replace Adaptive Avg. Pool with Aten Avg. Pool #10818

Open · wants to merge 1 commit into base: main
62 changes: 62 additions & 0 deletions backends/cadence/aot/replace_ops.py
@@ -16,6 +16,7 @@

# pyre-unsafe

import logging
import math
import operator
from operator import neg
@@ -2376,6 +2377,66 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(ExportPass):
    """
    Replace the aten adaptive avg_pool op with the aten avg_pool2d op.
    """

    def call_operator(self, op, args, kwargs, meta):
        # Only continue for the adaptive avg_pool op
        if op not in {exir_ops.edge.aten._adaptive_avg_pool2d.default}:
            return super().call_operator(op, args, kwargs, meta)

        # Get the input tensor
        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
        # Permute NCHW to NHWC for computation
        in_tensor_permuted = in_tensor.permute(0, 2, 3, 1)
        in_tensor_shape = in_tensor_permuted.shape

        output_size = args[1]
        num_dims = len(output_size)

        # TODO: If in_tensor_shape is not a multiple of output size,
        # this pass will not work. T224984800
        dim_multiples = [
            (in_tensor_shape[i + 1] % output_size[i]) == 0 for i in range(num_dims)
        ]
        if not all(dim_multiples):
            logging.info(
                f"Unable to replace adaptive average pool with average pool. Input tensor shape of {in_tensor_shape} is not a multiple of output size: {output_size}"
            )
            return super().call_operator(op, args, kwargs, meta)

        # Compute stride and kernel_size, then set default values for other arguments
        stride = [(in_tensor_shape[i + 1] // output_size[i]) for i in range(num_dims)]
        kernel_size = [
            in_tensor_shape[i + 1] - (output_size[i] - 1) * stride[i]
            for i in range(num_dims)
        ]
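        # Note: when the input size is an exact multiple of the output size
        # (checked above), kernel_size == stride, i.e. the windows tile the
        # input without overlap, exactly as adaptive pooling would.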
        padding = [0] * num_dims
        ceil_mode = False
        count_include_pad = True
        divisor_override = None

        # Create a new avg_pool node with the updated args
        new_args = (
            args[0],
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )
        return super().call_operator(
            exir_ops.edge.aten.avg_pool2d.default,
            new_args,
            kwargs,
            meta,
        )
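As a sanity check on the kernel/stride derivation, the rewrite can be reproduced in eager PyTorch. This is a minimal standalone sketch, not part of the PR; it assumes standard `torch.nn.functional` semantics and the shapes used in the tests below.

```python
# Minimal sketch (not part of this PR): when the spatial size is an exact
# multiple of the output size, avg_pool2d with the derived kernel/stride
# matches adaptive_avg_pool2d exactly.
import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 128, 128)
output_size = (8, 8)

h, w = x.shape[2], x.shape[3]
stride = [h // output_size[0], w // output_size[1]]  # [16, 16]
kernel_size = [
    h - (output_size[0] - 1) * stride[0],
    w - (output_size[1] - 1) * stride[1],
]  # [16, 16]

expected = F.adaptive_avg_pool2d(x, output_size)
actual = F.avg_pool2d(x, kernel_size, stride, padding=0)
assert torch.allclose(expected, actual, atol=1e-6)
```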


# This class encapsulates all the functions that replace/switch one op in the
# graph with another.
class CadenceReplaceOpsInGraph:
@@ -2412,6 +2473,7 @@ class CadenceReplaceOpsInGraph:
        ReplacePT2QuantWithCadenceQuantPass,
        ReplacePT2DequantWithCadenceDequantPass,
        ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
        ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
        ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
        ReplaceWhereWithFullArgsWithWhereScalar,
        ReplaceGeluWithApproximateGeluPass,
98 changes: 98 additions & 0 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -25,6 +25,7 @@
from executorch.backends.cadence.aot.replace_ops import (
    ForceChannelLastForConvPass,
    MakeSliceAndCatDimOutermostPass,
    ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
    ReplaceAddMMWithLinearPass,
    ReplaceAtenConvolutionWithJarvisConvolutionPass,
    ReplaceConstantPadNdWithSlicePass,
@@ -1939,3 +1940,100 @@ def test_empty_slice(self):
            ),
            1,
        )


class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase):
    def _get_adaptive_avg_pool_gm(
        self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]
    ) -> torch.fx.GraphModule:
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(*input_shape))
        adaptive_avg_pool2d = builder.call_operator(
            exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape)
        )
        builder.output([adaptive_avg_pool2d])
        return builder.get_graph_module()

    def test_replace_adaptive_avg_pool_with_aten_avg_pool(self):
        gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8))
        self.assertEqual(
            len(
                gm.graph.find_nodes(
                    op="call_function",
                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
                )
            ),
            1,
        )
        self.assertEqual(
            len(
                gm.graph.find_nodes(
                    op="call_function",
                    target=exir_ops.edge.aten.avg_pool2d.default,
                )
            ),
            0,
        )
        updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module
        self.assertEqual(
            len(
                updated_gm.graph.find_nodes(
                    op="call_function",
                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
                )
            ),
            0,
        )
        avg_pool2d_nodes = updated_gm.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
        )
        self.assertEqual(
            len(avg_pool2d_nodes),
            1,
        )
        avg_pool2d_node = avg_pool2d_nodes[0]

        self.assertEqual(avg_pool2d_node.args[1], [16, 16])  # kernel_size is 16x16
        self.assertEqual(avg_pool2d_node.args[2], [16, 16])  # stride is 16, 16
        self.assertEqual(avg_pool2d_node.args[3], [0, 0])  # padding is 0, 0
        self.assertEqual(avg_pool2d_node.args[4], False)  # ceil_mode is False
        self.assertEqual(avg_pool2d_node.args[5], True)  # count_include_pad is True
        self.assertEqual(avg_pool2d_node.args[6], None)  # divisor_override is None

    def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self):
        gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9))
        self.assertEqual(
            len(
                gm.graph.find_nodes(
                    op="call_function",
                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
                )
            ),
            1,
        )
        self.assertEqual(
            len(
                gm.graph.find_nodes(
                    op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
                )
            ),
            0,
        )
        # Shapes are not multiples of each other, so the pass will not trigger
        updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module
        self.assertEqual(
            len(
                updated_gm.graph.find_nodes(
                    op="call_function",
                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
                )
            ),
            1,
        )
        avg_pool2d_nodes = updated_gm.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
        )
        self.assertEqual(
            len(avg_pool2d_nodes),
            0,
        )
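The irregular case is skipped for a fundamental reason: when H % out != 0, adaptive pooling computes each output window as [floor(i * H / out), ceil((i + 1) * H / out)), producing windows of varying size and non-constant stride that no single avg_pool2d configuration can express. A small illustrative sketch of that boundary math (assumes standard PyTorch adaptive-pooling semantics; not code from this PR):

```python
# Why (H=128 -> out=9) has no fixed kernel/stride equivalent: the adaptive
# window boundaries are floor(i * H / out) to ceil((i + 1) * H / out).
import math

H, out = 128, 9
windows = [
    (math.floor(i * H / out), math.ceil((i + 1) * H / out)) for i in range(out)
]
print([end - start for start, end in windows])
# [15, 15, 15, 15, 16, 15, 15, 15, 15] -- window sizes are unequal
print([start for start, _ in windows])
# [0, 14, 28, 42, 56, 71, 85, 99, 113] -- stride varies between 14 and 15
```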