Skip to content

Commit 4ef8205

Browse files
jfix71facebook-github-bot
authored andcommitted
[fx][normalize] Allow for args to be left as args (pytorch#55995)
Summary: Pull Request resolved: pytorch#55995 Normalization is kind of broken currently. But making default arguments visible still appears to work, and is nice functionality to still be able to rely on/use. Adds an option to `NormalizeArgs`'s `__init__` called `normalize_to_only_use_kwargs` which defaults to true, which if set to false will keep using the same signature as provided, but additionally set kwargs in kwargs. Test Plan: Added test to `test_fx_experimental`. Reviewed By: 842974287 Differential Revision: D27759448 fbshipit-source-id: 620061fcf46d8549ac70b62aede8b6740aee3778
1 parent 3fbc154 commit 4ef8205

File tree

4 files changed

+77
-45
lines changed

4 files changed

+77
-45
lines changed

test/test_fx_experimental.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -911,8 +911,8 @@ def forward(self, {params}):
911911
normalized_args2 = normalize_module(traced, node.target, node.args, node.kwargs)
912912
assert(normalized_args == normalized_args2)
913913
assert normalized_args
914-
node.args = ()
915-
node.kwargs = normalized_args
914+
node.args = normalized_args.args
915+
node.kwargs = normalized_args.kwargs
916916

917917
traced.recompile()
918918

@@ -1306,8 +1306,8 @@ def jit_infer_type(v):
13061306
continue
13071307
# Test normalize_function by itself
13081308
ref_out = op.op(*arg_values, **kwarg_values)
1309-
norm_kwargs = normalize_function(op.op, arg_values, kwarg_values, arg_types, kwarg_types)
1310-
test_out = op.op(**norm_kwargs)
1309+
norm_args_and_kwargs = normalize_function(op.op, arg_values, kwarg_values, arg_types, kwarg_types)
1310+
test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs)
13111311
self.assertEqual(test_out, ref_out)
13121312

13131313
# Test normalized_arguments as part of FX
@@ -1351,8 +1351,8 @@ def forward(self, {', '.join(param_names)}):
13511351
if node.op == 'call_function':
13521352
normalized_args = node.normalized_arguments(traced, arg_types, kwarg_types)
13531353
assert normalized_args
1354-
node.args = ()
1355-
node.kwargs = normalized_args
1354+
node.args = normalized_args.args
1355+
node.kwargs = normalized_args.kwargs
13561356
traced.recompile()
13571357

13581358
test_out = traced(*param_values)

torch/fx/experimental/normalize.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ class NormalizeArgs(Transformer):
1313
"""
1414
Normalize arguments to Python targets. This means that
1515
`args/kwargs` will be matched up to the module/functional's
16-
signature and rewritten to exclusively kwargs in positional order.
17-
Also populates default values. Does not support positional-only
18-
parameters or varargs parameters (*args, **kwargs).
16+
signature and rewritten to exclusively kwargs in positional order
17+
if `normalize_to_only_use_kwargs` is true. Also populates default
18+
values. Does not support positional-only parameters or varargs
19+
parameters (*args, **kwargs).
1920
2021
If the nodes have 'type' metadata, it will use it to disambiguate
2122
overloads. Otherwise, it will throw an error.
@@ -25,9 +26,11 @@ class NormalizeArgs(Transformer):
2526
traced = torch.fx.symbolic_trace(m)
2627
traced = NormalizeArgs(traced).transform()
2728
"""
28-
def __init__(self, module : torch.nn.Module):
29+
def __init__(self, module : torch.nn.Module,
30+
normalize_to_only_use_kwargs : bool = True):
2931
super().__init__(module)
3032
self.node_map: Dict[Proxy, Node] = {}
33+
self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs
3134

3235
def run_node(self, n: Node) -> Any:
3336
args, kwargs = self.fetch_args_kwargs_from_env(n)
@@ -51,17 +54,21 @@ def call_function(
5154
self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any],
5255
arg_types: Optional[Tuple[Any, ...]] = None, kwarg_types : Optional[Dict[str, Any]] = None):
5356
assert callable(target)
54-
new_kwargs = normalize_function(target, args, kwargs, arg_types, kwarg_types) # type: ignore[arg-type]
55-
if new_kwargs:
56-
return self.tracer.create_proxy('call_function', target, (), new_kwargs)
57+
new_args_and_kwargs = normalize_function(target, args, kwargs, arg_types, kwarg_types, # type: ignore[arg-type]
58+
self.normalize_to_only_use_kwargs)
59+
if new_args_and_kwargs:
60+
new_args, new_kwargs = new_args_and_kwargs
61+
return self.tracer.create_proxy('call_function', target, new_args, new_kwargs)
5762
else:
5863
return super().call_function(target, args, kwargs)
5964

6065
def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
6166
assert isinstance(target, str)
62-
new_kwargs = normalize_module(self.module, target, args, kwargs) # type: ignore[arg-type]
63-
if new_kwargs:
64-
return super().call_module(target, (), new_kwargs)
67+
new_args_and_kwargs = normalize_module(self.module, target, args, kwargs, # type: ignore[arg-type]
68+
self.normalize_to_only_use_kwargs)
69+
if new_args_and_kwargs:
70+
new_args, new_kwargs = new_args_and_kwargs
71+
return super().call_module(target, new_args, new_kwargs)
6572
else:
6673
return super().call_module(target, args, kwargs)
6774

torch/fx/node.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
import builtins
66
import types
7-
from torch.fx.operator_schemas import normalize_function, normalize_module
7+
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
88

99
if TYPE_CHECKING:
1010
from .graph import Graph
@@ -446,11 +446,13 @@ def is_impure(self):
446446

447447
def normalized_arguments(
448448
self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
449-
kwarg_types : Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
449+
kwarg_types : Optional[Dict[str, Any]] = None,
450+
normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
450451
"""
451452
Returns normalized arguments to Python targets. This means that
452453
`args/kwargs` will be matched up to the module/functional's
453-
signature and return exclusively kwargs in positional order.
454+
signature and return exclusively kwargs in positional order
455+
if `normalize_to_only_use_kwargs` is true.
454456
Also populates default values. Does not support positional-only
455457
parameters or varargs parameters.
456458
@@ -462,10 +464,11 @@ def normalized_arguments(
462464
root (torch.nn.Module): Module upon which to resolve module targets.
463465
arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
464466
kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
467+
normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
465468
466469
Returns:
467470
468-
Returns normalized_kwargs, or `None` if not successful.
471+
Returns NamedTuple ArgsKwargsPair, or `None` if not successful.
469472
"""
470473
if self.op == 'call_function':
471474
assert callable(self.target)

torch/fx/operator_schemas.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,16 @@
33
import numbers
44
import typing
55
import enum
6-
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
6+
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast
77
from torch._jit_internal import boolean_dispatched
88

9+
class ArgsKwargsPair(NamedTuple):
10+
"""
11+
Simple named tuple for wrapping args/kwargs pairs.
12+
"""
13+
args: Tuple[Any, ...]
14+
kwargs: Dict[str, Any]
15+
916
_manual_overrides : Dict[Callable, List[inspect.Signature]] = {}
1017

1118
def _nonzero_schemas():
@@ -140,11 +147,13 @@ def is_homogeneous_int_tuple(t):
140147

141148
def normalize_function(
142149
target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None,
143-
kwarg_types : Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
150+
kwarg_types : Optional[Dict[str, Any]] = None,
151+
normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
144152
"""
145153
Returns normalized arguments to PyTorch functions. This means that
146154
`args/kwargs` will be matched up to the functional's
147-
signature and return exclusively kwargs in positional order.
155+
signature and return exclusively kwargs in positional order if
156+
`normalize_to_only_use_kwargs` is True.
148157
Also populates default values. Does not support positional-only
149158
parameters or varargs parameters (*args, **kwargs). Does not support modules.
150159
@@ -156,14 +165,15 @@ def normalize_function(
156165
kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
157166
arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
158167
kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
168+
normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
159169
160170
Returns:
161171
162-
Returns normalized_kwargs, or `None` if not successful.
172+
Returns normalized_args_and_kwargs, or `None` if not successful.
163173
"""
164174
if kwargs is None:
165175
kwargs = {}
166-
new_kwargs = None
176+
new_args_and_kwargs = None
167177
if target in boolean_dispatched or target.__module__ in ['torch.nn.functional', 'torch.functional']:
168178
target_for_analysis = target
169179
if target in boolean_dispatched:
@@ -180,15 +190,15 @@ def normalize_function(
180190

181191
assert callable(target_for_analysis)
182192
sig = inspect.signature(inspect.unwrap(target_for_analysis))
183-
new_kwargs = _args_kwargs_to_normalized_kwargs(sig, args, kwargs)
193+
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)
184194
else:
185195
assert callable(target)
186196
torch_op_schemas = get_signature_for_torch_op(target)
187197
matched_schemas = []
188198
if torch_op_schemas:
189199
# Iterate through all of the schema until we find one that matches
190-
# If one matches, populate `new_kwargs` with the combined args/kwargs
191-
# values. If none matches, `new_kwargs` will be None
200+
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
201+
# values. If none matches, `new_args_and_kwargs` will be None
192202
for candidate_signature in torch_op_schemas:
193203
try:
194204
candidate_signature.bind(*args, **kwargs)
@@ -201,7 +211,8 @@ def normalize_function(
201211
pass
202212
elif len(matched_schemas) == 1:
203213
# Matched exactly one schema, unambiguous
204-
new_kwargs = _args_kwargs_to_normalized_kwargs(matched_schemas[0], args, kwargs)
214+
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs,
215+
normalize_to_only_use_kwargs)
205216
else:
206217
if arg_types is not None or kwarg_types is not None:
207218
arg_types = arg_types if arg_types else cast(Tuple[Any], ())
@@ -216,7 +227,8 @@ def normalize_function(
216227
except TypeError as e:
217228
sig_matches = False
218229
if sig_matches:
219-
new_kwargs = _args_kwargs_to_normalized_kwargs(candidate_signature, args, kwargs)
230+
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
231+
normalize_to_only_use_kwargs)
220232
break
221233
else:
222234
# Matched more than one schema. In this situation, the caller must provide the types of
@@ -226,14 +238,16 @@ def normalize_function(
226238
f'the schema match was ambiguous! Please provide argument types to '
227239
f'the normalize_arguments() call. Available schemas:\n{schema_printouts}')
228240

229-
return new_kwargs
241+
return new_args_and_kwargs
230242

231243
def normalize_module(
232-
root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
244+
root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None,
245+
normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
233246
"""
234247
Returns normalized arguments to PyTorch modules. This means that
235248
`args/kwargs` will be matched up to the functional's
236-
signature and return exclusively kwargs in positional order.
249+
signature and return exclusively kwargs in positional order if
250+
`normalize_to_only_use_kwargs` is True.
237251
Also populates default values. Does not support positional-only
238252
parameters or varargs parameters (*args, **kwargs).
239253
@@ -242,10 +256,11 @@ def normalize_module(
242256
target (Callable): Function that we are normalizing
243257
args (Tuple[Any]): Tuple of args to the function
244258
kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
259+
normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
245260
246261
Returns:
247262
248-
Returns normalized_kwargs, or `None` if not successful.
263+
Returns normalized_args_and_kwargs, or `None` if not successful.
249264
"""
250265
try:
251266
submod = root.get_submodule(target)
@@ -258,27 +273,30 @@ def normalize_module(
258273
sig = inspect.signature(inspect.unwrap(submod.forward))
259274
if kwargs is None:
260275
kwargs = {}
261-
new_kwargs = _args_kwargs_to_normalized_kwargs(sig, args, kwargs)
262-
return new_kwargs
276+
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,
277+
normalize_to_only_use_kwargs)
278+
return new_args_and_kwargs
263279
return None
264280

265-
def _args_kwargs_to_normalized_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
266-
kwargs : Dict[str, Any]) -> Optional[Dict[str, Any]]:
281+
def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
282+
kwargs : Dict[str, Any],
283+
normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]:
267284
"""
268285
Given a call target, args, and kwargs, return the arguments normalized into
269-
a single kwargs dict, or None if the type signature is not supported by
286+
an ArgsKwargsPair, or None if the type signature is not supported by
270287
this normalization.
271288
272289
Args:
273290
274291
target (inspect.Signature): Signature object for the target
275292
args (Tuple): Arguments that appear at the callsite for `target`
276293
kwargs (Dict): Keyword arugments that appear at the callsite for `target`
294+
normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
277295
278296
Returns:
279297
280-
Optional[Dict]: Normalized kwargs for `target`, or `None` if this target is not
281-
supported
298+
Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
299+
this target is not supported.
282300
"""
283301

284302
# Don't currently support positional-only
@@ -292,7 +310,11 @@ def _args_kwargs_to_normalized_kwargs(sig : inspect.Signature, args : Tuple[Any,
292310
bound_args.apply_defaults()
293311

294312
new_kwargs : Dict[str, Any] = {}
295-
for param in sig.parameters:
296-
new_kwargs[param] = bound_args.arguments[param]
297-
298-
return new_kwargs
313+
new_args : List[Any] = []
314+
for i, param in enumerate(sig.parameters):
315+
if not normalize_to_only_use_kwargs and i < len(args):
316+
new_args.append(bound_args.arguments[param])
317+
else:
318+
new_kwargs[param] = bound_args.arguments[param]
319+
320+
return ArgsKwargsPair(tuple(new_args), new_kwargs)

0 commit comments

Comments
 (0)