Skip to content

Commit d289a80

Browse files
hsharma35facebook-github-bot
authored andcommitted
Add a pass to replace nodes with empty tensors with full.
Summary: Remove subgraphs of ops that produce empty tensors at the end. `ReplaceEmptyTensorsWithFullPass` both does the replacement and dead code elimination. Reviewed By: zonglinpeng Differential Revision: D68907459
1 parent a5c7609 commit d289a80

File tree

3 files changed

+82
-3
lines changed

3 files changed

+82
-3
lines changed

backends/cadence/aot/graph_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from typing import Optional, Sequence, Union
77

88
import torch
9-
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
9+
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue, Argument
1010
from torch._dispatch.python import enable_python_dispatcher
1111
from torch._subclasses import FakeTensor, FakeTensorMode
12-
from torch.fx.node import Argument, Target
12+
from torch.fx.node import Target
1313
from torch.utils import _pytree as pytree
1414

1515

backends/cadence/aot/replace_ops.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,11 +2070,32 @@ def call_operator(
20702070
meta,
20712071
)
20722072

2073+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
2074+
class ReplaceEmptyTensorsWithFullPass(ExportPass):
2075+
"""Replaces nodes that produce empty tensors with full nodes."""
2076+
2077+
def call_operator(self, op, args, kwargs, meta):
2078+
val = meta.data.get("val", None)
2079+
if isinstance(val, torch.Tensor) and val.numel() == 0:
2080+
return super().call_operator(
2081+
exir_ops.edge.aten.full.default,
2082+
args=(val.shape, 0),
2083+
kwargs={"dtype": val.dtype},
2084+
meta=meta,
2085+
)
2086+
return super().call_operator(op, args, kwargs, meta)
2087+
2088+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2089+
ret = super().call(graph_module)
2090+
modified = ret.graph_module.graph.eliminate_dead_code() or ret.modified
2091+
return PassResult(ret.graph_module, modified)
2092+
20732093

20742094
# This class encapsulates all the functions that replace/switch one op in the
20752095
# graph with another.
20762096
class CadenceReplaceOpsInGraph:
20772097
passes = [
2098+
ReplaceEmptyTensorsWithFullPass,
20782099
ReplaceFunctionallyEquivalentOpTargets,
20792100
ReplaceTCopyWithTransposePass,
20802101
ReplacePermuteWithTransposePass,

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import torch.nn.functional as F
88
from executorch.backends.cadence.aot import compiler
99
from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
10-
from executorch.backends.cadence.aot.graph_builder import single_op_builder
10+
from executorch.backends.cadence.aot.graph_builder import (
11+
GraphBuilder,
12+
single_op_builder,
13+
)
1114
from executorch.backends.cadence.aot.pass_utils import count_node
1215
from executorch.backends.cadence.aot.replace_ops import (
1316
ForceChannelLastForConvPass,
@@ -18,6 +21,7 @@
1821
ReplaceConstantPadNdWithSlicePass,
1922
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
2023
ReplaceConvWithIm2RowAndLinear,
24+
ReplaceEmptyTensorsWithFullPass,
2125
ReplaceFunctionallyEquivalentOpTargets,
2226
ReplaceIm2RowWithViewPass,
2327
ReplaceLinearWithFullyConnectedOpPass,
@@ -1681,3 +1685,57 @@ def test_cat_insert_transpose(self):
16811685
count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
16821686
3,
16831687
)
1688+
1689+
1690+
class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase):
1691+
def _get_slice_empty_gm(self) -> torch.fx.GraphModule:
1692+
builder = GraphBuilder()
1693+
x = builder.placeholder("x", torch.randn(4))
1694+
# This is empty (numel == 0).
1695+
slice0 = builder.call_operator(
1696+
exir_ops.edge.aten.slice_copy.Tensor, (x, 0, 0, 0)
1697+
)
1698+
# Copy of x.
1699+
slice1 = builder.call_operator(exir_ops.edge.aten.slice_copy.Tensor, (x,))
1700+
cat = builder.call_operator(
1701+
exir_ops.edge.aten.cat.default,
1702+
((slice0, slice1),),
1703+
)
1704+
builder.output([cat])
1705+
return builder.get_graph_module()
1706+
1707+
def test_slice_no_transpose_if_already_outermost(self):
1708+
gm = self._get_slice_empty_gm()
1709+
self.assertEqual(
1710+
len(
1711+
gm.graph.find_nodes(
1712+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
1713+
)
1714+
),
1715+
2,
1716+
)
1717+
self.assertEqual(
1718+
len(
1719+
gm.graph.find_nodes(
1720+
op="call_function", target=exir_ops.edge.aten.full.default
1721+
)
1722+
),
1723+
0,
1724+
)
1725+
updated_gm = ReplaceEmptyTensorsWithFullPass()(gm).graph_module
1726+
self.assertEqual(
1727+
len(
1728+
updated_gm.graph.find_nodes(
1729+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
1730+
)
1731+
),
1732+
1,
1733+
)
1734+
self.assertEqual(
1735+
len(
1736+
updated_gm.graph.find_nodes(
1737+
op="call_function", target=exir_ops.edge.aten.full.default
1738+
)
1739+
),
1740+
1,
1741+
)

0 commit comments

Comments
 (0)