Skip to content

Commit 4646af4

Browse files
Eashan Garg authored and facebook-github-bot committed
Pass to replace Adaptive Avg. Pool with Aten Avg. Pool (#10818)
Summary: Seeing exir_ops.edge.aten._adaptive_avg_pool2d.default nodes in some graphs; add a pass to replace these with exir_ops.edge.aten.avg_pool2d.default. Differential Revision: D74559775
1 parent 6cceab6 commit 4646af4

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# pyre-unsafe
1818

19+
import logging
1920
import math
2021
import operator
2122
from operator import neg
@@ -2376,6 +2377,66 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
23762377
return result
23772378

23782379

2380+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(ExportPass):
    """
    Replace the aten adaptive avg_pool op with the aten avg_pool2d op.

    adaptive_avg_pool2d(input, output_size) is equivalent to
    avg_pool2d(input, kernel_size, stride) whenever each spatial dimension
    of the input is an exact multiple of the corresponding output dimension:
    stride = in // out and kernel_size = in - (out - 1) * stride (which
    collapses to stride under exact divisibility). Nodes whose shapes do not
    satisfy the divisibility requirement are left untouched.
    """

    def call_operator(self, op, args, kwargs, meta):
        # Only rewrite adaptive average pool nodes; pass everything else through.
        if op != exir_ops.edge.aten._adaptive_avg_pool2d.default:
            return super().call_operator(op, args, kwargs, meta)

        # Get the input tensor (unwrap ProxyValue if needed).
        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
        # Input layout is NCHW, so the spatial dims are at indices 2 and 3.
        # Reading the shape directly avoids materializing a permuted view of
        # the tensor just to inspect its dimensions.
        spatial_shape = in_tensor.shape[2:]

        output_size = args[1]
        num_dims = len(output_size)

        # TODO: If the spatial shape is not a multiple of output size,
        # this pass will not work. T224984800
        # Also guard against a zero in output_size to avoid ZeroDivisionError.
        if any(
            output_size[i] == 0 or spatial_shape[i] % output_size[i] != 0
            for i in range(num_dims)
        ):
            # Lazy %-style args: the message is only formatted if INFO is enabled.
            logging.info(
                "Unable to replace adaptive average pool with average pool. "
                "Input spatial shape of %s is not a multiple of output size: %s",
                tuple(spatial_shape),
                output_size,
            )
            return super().call_operator(op, args, kwargs, meta)

        # Compute stride and kernel_size; with exact divisibility the windows
        # tile the input exactly, making avg_pool2d numerically identical to
        # adaptive_avg_pool2d. Then set default values for the other arguments.
        stride = [spatial_shape[i] // output_size[i] for i in range(num_dims)]
        kernel_size = [
            spatial_shape[i] - (output_size[i] - 1) * stride[i]
            for i in range(num_dims)
        ]
        padding = [0] * num_dims  # no padding needed: windows tile exactly
        ceil_mode = False
        count_include_pad = True
        divisor_override = None

        # Create a new avg_pool2d node with the computed args.
        new_args = (
            args[0],
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )
        return super().call_operator(
            exir_ops.edge.aten.avg_pool2d.default,
            new_args,
            kwargs,
            meta,
        )
2438+
2439+
23792440
# This class encapsulates all the functions that replace/switch one op in the
23802441
# graph with another.
23812442
class CadenceReplaceOpsInGraph:
@@ -2412,6 +2473,7 @@ class CadenceReplaceOpsInGraph:
24122473
ReplacePT2QuantWithCadenceQuantPass,
24132474
ReplacePT2DequantWithCadenceDequantPass,
24142475
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
2476+
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
24152477
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
24162478
ReplaceWhereWithFullArgsWithWhereScalar,
24172479
ReplaceGeluWithApproximateGeluPass,

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from executorch.backends.cadence.aot.replace_ops import (
2626
ForceChannelLastForConvPass,
2727
MakeSliceAndCatDimOutermostPass,
28+
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
2829
ReplaceAddMMWithLinearPass,
2930
ReplaceAtenConvolutionWithJarvisConvolutionPass,
3031
ReplaceConstantPadNdWithSlicePass,
@@ -1939,3 +1940,100 @@ def test_empty_slice(self):
19391940
),
19401941
1,
19411942
)
1943+
1944+
1945+
class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase):
    def _get_adaptive_avg_pool_gm(
        self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]
    ) -> torch.fx.GraphModule:
        """Build a graph module containing a single _adaptive_avg_pool2d node."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(*input_shape))
        adaptive_avg_pool2d = builder.call_operator(
            exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape)
        )
        builder.output([adaptive_avg_pool2d])
        return builder.get_graph_module()

    def _count_nodes(self, gm: torch.fx.GraphModule, target) -> int:
        """Return the number of call_function nodes in gm with the given target."""
        return len(gm.graph.find_nodes(op="call_function", target=target))

    def test_replace_adaptive_avg_pool_with_aten_avg_pool(self):
        # 128 is a multiple of 8, so the pass should fire.
        gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8))
        self.assertEqual(
            self._count_nodes(gm, exir_ops.edge.aten._adaptive_avg_pool2d.default),
            1,
        )
        self.assertEqual(
            self._count_nodes(gm, exir_ops.edge.aten.avg_pool2d.default),
            0,
        )
        updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module
        self.assertEqual(
            self._count_nodes(
                updated_gm, exir_ops.edge.aten._adaptive_avg_pool2d.default
            ),
            0,
        )
        avg_pool2d_nodes = updated_gm.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
        )
        self.assertEqual(len(avg_pool2d_nodes), 1)
        avg_pool2d_node = avg_pool2d_nodes[0]

        # 128 / 8 = 16, so kernel and stride both collapse to 16x16.
        self.assertEqual(avg_pool2d_node.args[1], [16, 16])  # kernel_size
        self.assertEqual(avg_pool2d_node.args[2], [16, 16])  # stride
        self.assertEqual(avg_pool2d_node.args[3], [0, 0])  # padding
        self.assertEqual(avg_pool2d_node.args[4], False)  # ceil_mode
        self.assertEqual(avg_pool2d_node.args[5], True)  # count_include_pad
        self.assertEqual(avg_pool2d_node.args[6], None)  # divisor_override

    def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self):
        # 128 is not a multiple of 9, so the pass must leave the graph unchanged.
        gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9))
        self.assertEqual(
            self._count_nodes(gm, exir_ops.edge.aten._adaptive_avg_pool2d.default),
            1,
        )
        self.assertEqual(
            self._count_nodes(gm, exir_ops.edge.aten.avg_pool2d.default),
            0,
        )
        updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module
        self.assertEqual(
            self._count_nodes(
                updated_gm, exir_ops.edge.aten._adaptive_avg_pool2d.default
            ),
            1,
        )
        self.assertEqual(
            self._count_nodes(updated_gm, exir_ops.edge.aten.avg_pool2d.default),
            0,
        )

0 commit comments

Comments
 (0)