Skip to content

Commit 8a8daf4

Browse files
Eashan Gargfacebook-github-bot
Eashan Garg
authored andcommitted
Pass to replace Adaptive Avg. Pool with Aten Avg. Pool (#10818)
Summary: Seeing exir_ops.edge.aten._adaptive_avg_pool2d.default nodes in some graphs, pass to replace these with exir_ops.edge.aten.avg_pool2d.default Reviewed By: zonglinpeng Differential Revision: D74559775
1 parent d0360b7 commit 8a8daf4

File tree

2 files changed

+151
-0
lines changed

2 files changed

+151
-0
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,6 +2401,55 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
24012401
return result
24022402

24032403

2404+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
2405+
class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(ExportPass):
2406+
"""
2407+
Replace the aten adaptive avg_pool op with the aten avg_pool2d op.
2408+
"""
2409+
2410+
def call_operator(self, op, args, kwargs, meta):
2411+
# Only continue for avg_pool op
2412+
if op not in {exir_ops.edge.aten._adaptive_avg_pool2d.default}:
2413+
return super().call_operator(op, args, kwargs, meta)
2414+
2415+
# Get the input tensor
2416+
in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
2417+
# Permute NCHW to NHWC for computation
2418+
in_tensor_permuted = in_tensor.permute(0, 2, 3, 1)
2419+
in_tensor_shape = in_tensor_permuted.shape
2420+
2421+
output_size = args[1]
2422+
num_dims = len(output_size)
2423+
2424+
# Compute stride and kernel_size, then set default values for other arguments
2425+
stride = [(in_tensor_shape[i + 1] // output_size[i]) for i in range(num_dims)]
2426+
kernel_size = [
2427+
in_tensor_shape[i + 1] - (output_size[i] - 1) * stride[i]
2428+
for i in range(num_dims)
2429+
]
2430+
padding = [0] * num_dims
2431+
ceil_mode = False
2432+
count_include_pad = True
2433+
divisor_override = None
2434+
2435+
# Create a new avg_pool node with the updated args
2436+
new_args = (
2437+
args[0],
2438+
kernel_size,
2439+
stride,
2440+
padding,
2441+
ceil_mode,
2442+
count_include_pad,
2443+
divisor_override,
2444+
)
2445+
return super().call_operator(
2446+
exir_ops.edge.aten.avg_pool2d.default,
2447+
new_args,
2448+
kwargs,
2449+
meta,
2450+
)
2451+
2452+
24042453
# This class encapsulates all the functions that replace/switch one op in the
24052454
# graph with another.
24062455
class CadenceReplaceOpsInGraph:
@@ -2438,6 +2487,7 @@ class CadenceReplaceOpsInGraph:
24382487
ReplacePT2QuantWithCadenceQuantPass,
24392488
ReplacePT2DequantWithCadenceDequantPass,
24402489
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
2490+
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
24412491
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
24422492
ReplaceWhereWithFullArgsWithWhereScalar,
24432493
ReplaceGeluWithApproximateGeluPass,

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from executorch.backends.cadence.aot.replace_ops import (
2626
ForceChannelLastForConvPass,
2727
MakeSliceAndCatDimOutermostPass,
28+
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
2829
ReplaceAddMMWithLinearPass,
2930
ReplaceAtenConvolutionWithJarvisConvolutionPass,
3031
ReplaceConstantPadNdWithSlicePass,
@@ -1971,3 +1972,103 @@ def test_empty_slice(self):
19711972
),
19721973
1,
19731974
)
1975+
1976+
class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase):
1977+
def _get_adaptive_avg_pool_gm(
1978+
self, input_shape: Tuple[int], output_shape: Tuple[int]
1979+
) -> torch.fx.GraphModule:
1980+
builder = GraphBuilder()
1981+
x = builder.placeholder("x", torch.randn(*input_shape))
1982+
adaptive_avg_pool2d = builder.call_operator(
1983+
exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape)
1984+
)
1985+
builder.output([adaptive_avg_pool2d])
1986+
return builder.get_graph_module()
1987+
1988+
def test_replace_adaptive_avg_pool_with_aten_avg_pool(self):
1989+
gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8))
1990+
self.assertEqual(
1991+
len(
1992+
gm.graph.find_nodes(
1993+
op="call_function",
1994+
target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
1995+
)
1996+
),
1997+
1,
1998+
)
1999+
self.assertEqual(
2000+
len(
2001+
gm.graph.find_nodes(
2002+
op="call_function",
2003+
target=exir_ops.edge.aten.avg_pool2d.default,
2004+
)
2005+
),
2006+
0,
2007+
)
2008+
updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module
2009+
self.assertEqual(
2010+
len(
2011+
updated_gm.graph.find_nodes(
2012+
op="call_function", target=exir_ops.edge.aten._adaptive_avg_pool2d.default
2013+
)
2014+
),
2015+
0,
2016+
)
2017+
avg_pool2d_nodes = updated_gm.graph.find_nodes(
2018+
op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
2019+
)
2020+
self.assertEqual(
2021+
len(avg_pool2d_nodes),
2022+
1,
2023+
)
2024+
avg_pool2d_node = avg_pool2d_nodes[0]
2025+
2026+
self.assertEqual(avg_pool2d_node.args[1], [16, 16]) # kernel_size is 16x16
2027+
self.assertEqual(avg_pool2d_node.args[2], [16, 16]) # stride is 16, 16
2028+
self.assertEqual(avg_pool2d_node.args[3], [0, 0]) # padding is 0, 0
2029+
self.assertEqual(avg_pool2d_node.args[4], False) # ceil_mode is False
2030+
self.assertEqual(avg_pool2d_node.args[5], True) # count_include_pad is True
2031+
self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None
2032+
2033+
def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self):
2034+
gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9))
2035+
self.assertEqual(
2036+
len(
2037+
gm.graph.find_nodes(
2038+
op="call_function", target=exir_ops.edge.aten._adaptive_avg_pool2d.default
2039+
)
2040+
),
2041+
1,
2042+
)
2043+
self.assertEqual(
2044+
len(
2045+
gm.graph.find_nodes(
2046+
op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
2047+
)
2048+
),
2049+
0,
2050+
)
2051+
updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module
2052+
self.assertEqual(
2053+
len(
2054+
updated_gm.graph.find_nodes(
2055+
op="call_function", target=exir_ops.edge.aten._adaptive_avg_pool2d.default
2056+
)
2057+
),
2058+
0,
2059+
)
2060+
avg_pool2d_nodes = updated_gm.graph.find_nodes(
2061+
op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
2062+
)
2063+
self.assertEqual(
2064+
len(avg_pool2d_nodes),
2065+
1,
2066+
)
2067+
avg_pool2d_node = avg_pool2d_nodes[0]
2068+
2069+
self.assertEqual(avg_pool2d_node.args[1], [16, 16]) # kernel_size is 16x16
2070+
self.assertEqual(avg_pool2d_node.args[2], [14, 14]) # stride is 14, 14
2071+
self.assertEqual(avg_pool2d_node.args[3], [0, 0]) # padding is 0, 0
2072+
self.assertEqual(avg_pool2d_node.args[4], False) # ceil_mode is False
2073+
self.assertEqual(avg_pool2d_node.args[5], True) # count_include_pad is True
2074+
self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None

0 commit comments

Comments
 (0)