Skip to content

Commit d069d65

Browse files
authored
Add pass to convert kwargs to args + populate optional args.
Differential Revision: D74510388 Pull Request resolved: #10857
1 parent e09f33c commit d069d65

File tree

3 files changed

+74
-2
lines changed

3 files changed

+74
-2
lines changed

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ python_unittest(
367367
"fbsource//third-party/pypi/parameterized:parameterized",
368368
"//caffe2:torch",
369369
"//executorch/backends/cadence/aot:compiler",
370+
"//executorch/backends/cadence/aot:graph_builder",
370371
"//executorch/backends/cadence/aot:ops_registrations",
371372
"//executorch/backends/cadence/aot:pass_utils",
372373
"//executorch/backends/cadence/aot:simplify_ops",

backends/cadence/aot/simplify_ops.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
CadencePassAttribute,
1717
register_cadence_pass,
1818
)
19-
2019
from executorch.exir.dialects._ops import ops as exir_ops
20+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2121
from executorch.exir.pass_base import ExportPass, ProxyValue
22+
from torch.fx.operator_schemas import get_signature_for_torch_op
2223

2324

2425
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -109,8 +110,44 @@ def call_operator(self, op, args, kwargs, meta):
109110
return super().call_operator(op, new_args, kwargs, meta)
110111

111112

113+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
114+
class BindOptionalArgsPass(ExportPass):
115+
"""Bind all optional args and kwargs."""
116+
117+
def call_operator(self, op, args, kwargs, meta):
118+
if not isinstance(op, EdgeOpOverload):
119+
return super().call_operator(op, args, kwargs, meta)
120+
assert callable(op)
121+
122+
torch_op_schemas = get_signature_for_torch_op(op._op)
123+
if len(torch_op_schemas) == 0:
124+
return super().call_operator(op, args, kwargs, meta)
125+
126+
matched_schemas = []
127+
# Iterate through all of the schema until we find one that matches
128+
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
129+
# values. If none matches, `new_args_and_kwargs` will be None
130+
for candidate_signature in torch_op_schemas:
131+
try:
132+
candidate_signature.bind(*args, **kwargs)
133+
matched_schemas.append(candidate_signature)
134+
except TypeError:
135+
continue
136+
137+
if len(matched_schemas) != 1:
138+
# Did not match any schema. Cannot normalize
139+
return super().call_operator(op, args, kwargs, meta)
140+
141+
sig = matched_schemas[0]
142+
bound_args = sig.bind(*args, **kwargs)
143+
bound_args.apply_defaults()
144+
145+
return super().call_operator(op, bound_args.args, bound_args.kwargs, meta)
146+
147+
112148
# This class encapsulates all the functions that simplify the op's args
113149
class CadenceSimplifyOpsInGraph:
114150
passes = [
115151
SimplifySliceOpPass,
152+
BindOptionalArgsPass,
116153
]

backends/cadence/aot/tests/test_simplify_ops_passes.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@
1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
1515
from executorch.backends.cadence.aot.compiler import export_to_edge
16+
from executorch.backends.cadence.aot.graph_builder import single_op_builder
1617
from executorch.backends.cadence.aot.pass_utils import count_node
17-
from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
18+
from executorch.backends.cadence.aot.simplify_ops import (
19+
BindOptionalArgsPass,
20+
SimplifySliceOpPass,
21+
)
1822
from executorch.exir.dialects._ops import ops as exir_ops
1923
from parameterized.parameterized import parameterized
2024
from torch.fx.passes.infra.pass_base import PassResult
@@ -112,3 +116,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
112116
self.assertEqual(
113117
count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
114118
)
119+
120+
def test_simplify_slice_op_args(self) -> None:
121+
x = torch.rand(4, 5)
122+
gm = single_op_builder(
123+
placeholders=(x,),
124+
op=exir_ops.edge.aten.slice_copy.Tensor,
125+
args=(x, 1),
126+
kwargs={"end": 3},
127+
)
128+
self.assertEqual(
129+
[
130+
(n.args[1:], n.kwargs)
131+
for n in gm.graph.find_nodes(
132+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
133+
)
134+
],
135+
[((1,), {"end": 3})],
136+
)
137+
138+
gm = BindOptionalArgsPass().call(gm).graph_module
139+
140+
self.assertEqual(
141+
[
142+
(n.args[1:], n.kwargs)
143+
for n in gm.graph.find_nodes(
144+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
145+
)
146+
],
147+
[((1, None, 3, 1), {})],
148+
)

0 commit comments

Comments
 (0)