3
3
import numbers
4
4
import typing
5
5
import enum
6
- from typing import Any , Callable , Dict , List , Optional , Tuple , cast
6
+ from typing import Any , Callable , Dict , List , Optional , Tuple , NamedTuple , cast
7
7
from torch ._jit_internal import boolean_dispatched
8
8
9
+ class ArgsKwargsPair (NamedTuple ):
10
+ """
11
+ Simple named tuple for wrapping args/kwargs pairs.
12
+ """
13
+ args : Tuple [Any , ...]
14
+ kwargs : Dict [str , Any ]
15
+
9
16
_manual_overrides : Dict [Callable , List [inspect .Signature ]] = {}
10
17
11
18
def _nonzero_schemas ():
@@ -140,11 +147,13 @@ def is_homogeneous_int_tuple(t):
140
147
141
148
def normalize_function (
142
149
target : Callable , args : Tuple [Any ], kwargs : Optional [Dict [str , Any ]] = None , arg_types : Optional [Tuple [Any ]] = None ,
143
- kwarg_types : Optional [Dict [str , Any ]] = None ) -> Optional [Dict [str , Any ]]:
150
+ kwarg_types : Optional [Dict [str , Any ]] = None ,
151
+ normalize_to_only_use_kwargs : bool = False ) -> Optional [ArgsKwargsPair ]:
144
152
"""
145
153
Returns normalized arguments to PyTorch functions. This means that
146
154
`args/kwargs` will be matched up to the functional's
147
- signature and return exclusively kwargs in positional order.
155
+ signature and return exclusively kwargs in positional order if
156
+ `normalize_to_only_use_kwargs` is True.
148
157
Also populates default values. Does not support positional-only
149
158
parameters or varargs parameters (*args, **kwargs). Does not support modules.
150
159
@@ -156,14 +165,15 @@ def normalize_function(
156
165
kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
157
166
arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
158
167
kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
168
+ normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
159
169
160
170
Returns:
161
171
162
- Returns normalized_kwargs , or `None` if not successful.
172
+ Returns normalized_args_and_kwargs , or `None` if not successful.
163
173
"""
164
174
if kwargs is None :
165
175
kwargs = {}
166
- new_kwargs = None
176
+ new_args_and_kwargs = None
167
177
if target in boolean_dispatched or target .__module__ in ['torch.nn.functional' , 'torch.functional' ]:
168
178
target_for_analysis = target
169
179
if target in boolean_dispatched :
@@ -180,15 +190,15 @@ def normalize_function(
180
190
181
191
assert callable (target_for_analysis )
182
192
sig = inspect .signature (inspect .unwrap (target_for_analysis ))
183
- new_kwargs = _args_kwargs_to_normalized_kwargs (sig , args , kwargs )
193
+ new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs (sig , args , kwargs , normalize_to_only_use_kwargs )
184
194
else :
185
195
assert callable (target )
186
196
torch_op_schemas = get_signature_for_torch_op (target )
187
197
matched_schemas = []
188
198
if torch_op_schemas :
189
199
# Iterate through all of the schema until we find one that matches
190
- # If one matches, populate `new_kwargs ` with the combined args/kwargs
191
- # values. If none matches, `new_kwargs ` will be None
200
+ # If one matches, populate `new_args_and_kwargs ` with the new args/kwargs
201
+ # values. If none matches, `new_args_and_kwargs ` will be None
192
202
for candidate_signature in torch_op_schemas :
193
203
try :
194
204
candidate_signature .bind (* args , ** kwargs )
@@ -201,7 +211,8 @@ def normalize_function(
201
211
pass
202
212
elif len (matched_schemas ) == 1 :
203
213
# Matched exactly one schema, unambiguous
204
- new_kwargs = _args_kwargs_to_normalized_kwargs (matched_schemas [0 ], args , kwargs )
214
+ new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs (matched_schemas [0 ], args , kwargs ,
215
+ normalize_to_only_use_kwargs )
205
216
else :
206
217
if arg_types is not None or kwarg_types is not None :
207
218
arg_types = arg_types if arg_types else cast (Tuple [Any ], ())
@@ -216,7 +227,8 @@ def normalize_function(
216
227
except TypeError as e :
217
228
sig_matches = False
218
229
if sig_matches :
219
- new_kwargs = _args_kwargs_to_normalized_kwargs (candidate_signature , args , kwargs )
230
+ new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs (candidate_signature , args , kwargs ,
231
+ normalize_to_only_use_kwargs )
220
232
break
221
233
else :
222
234
# Matched more than one schema. In this situation, the caller must provide the types of
@@ -226,14 +238,16 @@ def normalize_function(
226
238
f'the schema match was ambiguous! Please provide argument types to '
227
239
f'the normalize_arguments() call. Available schemas:\n { schema_printouts } ' )
228
240
229
- return new_kwargs
241
+ return new_args_and_kwargs
230
242
231
243
def normalize_module (
232
- root : torch .nn .Module , target : str , args : Tuple [Any ], kwargs : Optional [Dict [str , Any ]] = None ) -> Optional [Dict [str , Any ]]:
244
+ root : torch .nn .Module , target : str , args : Tuple [Any ], kwargs : Optional [Dict [str , Any ]] = None ,
245
+ normalize_to_only_use_kwargs : bool = False ) -> Optional [ArgsKwargsPair ]:
233
246
"""
234
247
Returns normalized arguments to PyTorch modules. This means that
235
248
`args/kwargs` will be matched up to the functional's
236
- signature and return exclusively kwargs in positional order.
249
+ signature and return exclusively kwargs in positional order if
250
+ `normalize_to_only_use_kwargs` is True.
237
251
Also populates default values. Does not support positional-only
238
252
parameters or varargs parameters (*args, **kwargs).
239
253
@@ -242,10 +256,11 @@ def normalize_module(
242
256
target (Callable): Function that we are normalizing
243
257
args (Tuple[Any]): Tuple of args to the function
244
258
kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
259
+ normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
245
260
246
261
Returns:
247
262
248
- Returns normalized_kwargs , or `None` if not successful.
263
+ Returns normalized_args_and_kwargs , or `None` if not successful.
249
264
"""
250
265
try :
251
266
submod = root .get_submodule (target )
@@ -258,27 +273,30 @@ def normalize_module(
258
273
sig = inspect .signature (inspect .unwrap (submod .forward ))
259
274
if kwargs is None :
260
275
kwargs = {}
261
- new_kwargs = _args_kwargs_to_normalized_kwargs (sig , args , kwargs )
262
- return new_kwargs
276
+ new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs (sig , args , kwargs ,
277
+ normalize_to_only_use_kwargs )
278
+ return new_args_and_kwargs
263
279
return None
264
280
265
- def _args_kwargs_to_normalized_kwargs (sig : inspect .Signature , args : Tuple [Any , ...],
266
- kwargs : Dict [str , Any ]) -> Optional [Dict [str , Any ]]:
281
+ def _args_kwargs_to_normalized_args_kwargs (sig : inspect .Signature , args : Tuple [Any , ...],
282
+ kwargs : Dict [str , Any ],
283
+ normalize_to_only_use_kwargs : bool ) -> Optional [ArgsKwargsPair ]:
267
284
"""
268
285
Given a call target, args, and kwargs, return the arguments normalized into
269
- a single kwargs dict , or None if the type signature is not supported by
286
+ an ArgsKwargsPair , or None if the type signature is not supported by
270
287
this normalization.
271
288
272
289
Args:
273
290
274
291
target (inspect.Signature): Signature object for the target
275
292
args (Tuple): Arguments that appear at the callsite for `target`
276
293
kwargs (Dict): Keyword arugments that appear at the callsite for `target`
294
+ normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
277
295
278
296
Returns:
279
297
280
- Optional[Dict ]: Normalized kwargs for `target`, or `None` if this target is not
281
- supported
298
+ Optional[ArgsKwargsPair ]: Normalized args and kwargs for `target`, or `None` if
299
+ this target is not supported.
282
300
"""
283
301
284
302
# Don't currently support positional-only
@@ -292,7 +310,11 @@ def _args_kwargs_to_normalized_kwargs(sig : inspect.Signature, args : Tuple[Any,
292
310
bound_args .apply_defaults ()
293
311
294
312
new_kwargs : Dict [str , Any ] = {}
295
- for param in sig .parameters :
296
- new_kwargs [param ] = bound_args .arguments [param ]
297
-
298
- return new_kwargs
313
+ new_args : List [Any ] = []
314
+ for i , param in enumerate (sig .parameters ):
315
+ if not normalize_to_only_use_kwargs and i < len (args ):
316
+ new_args .append (bound_args .arguments [param ])
317
+ else :
318
+ new_kwargs [param ] = bound_args .arguments [param ]
319
+
320
+ return ArgsKwargsPair (tuple (new_args ), new_kwargs )
0 commit comments