|
25 | 25 | from executorch.backends.cadence.aot.replace_ops import (
|
26 | 26 | ForceChannelLastForConvPass,
|
27 | 27 | MakeSliceAndCatDimOutermostPass,
|
| 28 | + ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, |
28 | 29 | ReplaceAddMMWithLinearPass,
|
29 | 30 | ReplaceAtenConvolutionWithJarvisConvolutionPass,
|
30 | 31 | ReplaceConstantPadNdWithSlicePass,
|
@@ -1971,3 +1972,103 @@ def test_empty_slice(self):
|
1971 | 1972 | ),
|
1972 | 1973 | 1,
|
1973 | 1974 | )
|
| 1975 | + |
| 1976 | +class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase): |
| 1977 | + def _get_adaptive_avg_pool_gm( |
| 1978 | + self, input_shape: Tuple[int], output_shape: Tuple[int] |
| 1979 | + ) -> torch.fx.GraphModule: |
| 1980 | + builder = GraphBuilder() |
| 1981 | + x = builder.placeholder("x", torch.randn(*input_shape)) |
| 1982 | + adaptive_avg_pool2d = builder.call_operator( |
| 1983 | + exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape) |
| 1984 | + ) |
| 1985 | + builder.output([adaptive_avg_pool2d]) |
| 1986 | + return builder.get_graph_module() |
| 1987 | + |
| 1988 | + def test_replace_adaptive_avg_pool_with_aten_avg_pool(self): |
| 1989 | + gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8)) |
| 1990 | + self.assertEqual( |
| 1991 | + len( |
| 1992 | + gm.graph.find_nodes( |
| 1993 | + op="call_function", |
| 1994 | + target=exir_ops.edge.aten._adaptive_avg_pool2d.default, |
| 1995 | + ) |
| 1996 | + ), |
| 1997 | + 1, |
| 1998 | + ) |
| 1999 | + self.assertEqual( |
| 2000 | + len( |
| 2001 | + gm.graph.find_nodes( |
| 2002 | + op="call_function", |
| 2003 | + target=exir_ops.edge.aten.avg_pool2d.default, |
| 2004 | + ) |
| 2005 | + ), |
| 2006 | + 0, |
| 2007 | + ) |
| 2008 | + updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module |
| 2009 | + self.assertEqual( |
| 2010 | + len( |
| 2011 | + updated_gm.graph.find_nodes( |
| 2012 | + op="call_function", target=exir_ops.edge.aten._adaptive_avg_pool2d.default |
| 2013 | + ) |
| 2014 | + ), |
| 2015 | + 0, |
| 2016 | + ) |
| 2017 | + avg_pool2d_nodes = updated_gm.graph.find_nodes( |
| 2018 | + op="call_function", target=exir_ops.edge.aten.avg_pool2d.default |
| 2019 | + ) |
| 2020 | + self.assertEqual( |
| 2021 | + len(avg_pool2d_nodes), |
| 2022 | + 1, |
| 2023 | + ) |
| 2024 | + avg_pool2d_node = avg_pool2d_nodes[0] |
| 2025 | + |
| 2026 | + self.assertEqual(avg_pool2d_node.args[1], [16, 16]) # kernel_size is 16x16 |
| 2027 | + self.assertEqual(avg_pool2d_node.args[2], [16, 16]) # stride is 16, 16 |
| 2028 | + self.assertEqual(avg_pool2d_node.args[3], [0, 0]) # padding is 0, 0 |
| 2029 | + self.assertEqual(avg_pool2d_node.args[4], False) # ceil_mode is False |
| 2030 | + self.assertEqual(avg_pool2d_node.args[5], True) # count_include_pad is True |
| 2031 | + self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None |
| 2032 | + |
| 2033 | + def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self): |
| 2034 | + gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9)) |
| 2035 | + self.assertEqual( |
| 2036 | + len( |
| 2037 | + gm.graph.find_nodes( |
| 2038 | + op="call_function", target=exir_ops.edge.aten._adaptive_avg_pool2d.default |
| 2039 | + ) |
| 2040 | + ), |
| 2041 | + 1, |
| 2042 | + ) |
| 2043 | + self.assertEqual( |
| 2044 | + len( |
| 2045 | + gm.graph.find_nodes( |
| 2046 | + op="call_function", target=exir_ops.edge.aten.avg_pool2d.default |
| 2047 | + ) |
| 2048 | + ), |
| 2049 | + 0, |
| 2050 | + ) |
| 2051 | + updated_gm = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()(gm).graph_module |
| 2052 | + self.assertEqual( |
| 2053 | + len( |
| 2054 | + updated_gm.graph.find_nodes( |
| 2055 | + op="call_function", target=exir_ops.edge.aten._adaptive_avg_pool2d.default |
| 2056 | + ) |
| 2057 | + ), |
| 2058 | + 0, |
| 2059 | + ) |
| 2060 | + avg_pool2d_nodes = updated_gm.graph.find_nodes( |
| 2061 | + op="call_function", target=exir_ops.edge.aten.avg_pool2d.default |
| 2062 | + ) |
| 2063 | + self.assertEqual( |
| 2064 | + len(avg_pool2d_nodes), |
| 2065 | + 1, |
| 2066 | + ) |
| 2067 | + avg_pool2d_node = avg_pool2d_nodes[0] |
| 2068 | + |
| 2069 | + self.assertEqual(avg_pool2d_node.args[1], [16, 16]) # kernel_size is 16x16 |
| 2070 | + self.assertEqual(avg_pool2d_node.args[2], [14, 14]) # stride is 14, 14 |
| 2071 | + self.assertEqual(avg_pool2d_node.args[3], [0, 0]) # padding is 0, 0 |
| 2072 | + self.assertEqual(avg_pool2d_node.args[4], False) # ceil_mode is False |
| 2073 | + self.assertEqual(avg_pool2d_node.args[5], True) # count_include_pad is True |
| 2074 | + self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None |
0 commit comments