@@ -417,9 +417,21 @@ def _sum_input_wrangler(
417
417
def _upsample_bilinear2d_input_wrangler (
418
418
args : list [Any ], kwargs : dict [str , Any ]
419
419
) -> tuple [list [Any ], dict [str , Any ]]:
420
+ # Wrangler for the signature difference between
421
+ # 'nn.functional.upsample_bilinear'
422
+ # and
423
+ # 'aten::upsample_bilinear2d'
424
+ # https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html
420
425
if "size" in kwargs :
421
426
args .append (np .array (kwargs ["size" ], dtype = np .int64 ))
422
427
del kwargs ["size" ] # promote tensor type kwargs to args
428
+ else :
429
+ args .append (None )
430
+ if "align_corners" in kwargs :
431
+ args .append (kwargs ["align_corners" ])
432
+ del kwargs ["align_corners" ]
433
+ else :
434
+ args .append (True ) # Fill in the default value
423
435
if "scale_factor" in kwargs :
424
436
kwargs ["scales_h" ] = kwargs ["scale_factor" ]
425
437
kwargs ["scales_w" ] = kwargs ["scale_factor" ]
@@ -430,12 +442,26 @@ def _upsample_bilinear2d_input_wrangler(
430
442
def _upsample_bilinear2d_vec_input_wrangler (
431
443
args : list [Any ], kwargs : dict [str , Any ]
432
444
) -> tuple [list [Any ], dict [str , Any ]]:
445
+ # Wrangler for the signature difference between
446
+ # 'nn.functional.upsample_bilinear'
447
+ # and
448
+ # 'aten::upsample_bilinear2d.vec'
449
+ # https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html
433
450
if "size" in kwargs :
434
451
args .append (np .array (kwargs ["size" ], dtype = np .int64 ))
435
452
del kwargs ["size" ] # promote tensor type kwargs to args
453
+ else :
454
+ args .append (None )
455
+ if "align_corners" in kwargs :
456
+ args .append (kwargs ["align_corners" ])
457
+ del kwargs ["align_corners" ]
458
+ else :
459
+ args .append (True ) # Fill in the default value
436
460
if "scale_factor" in kwargs :
437
- kwargs ["scale_factors" ] = [kwargs ["scale_factor" ]] * 2
438
- del kwargs ["scale_factor" ] # adapt the function signature
461
+ args .append ([kwargs ["scale_factor" ]] * 2 )
462
+ del kwargs ["scale_factor" ] # promote tensor type kwargs to args
463
+ else :
464
+ args .append (None )
439
465
return args , kwargs
440
466
441
467
0 commit comments