Skip to content

Commit 7d14542

Browse files
committed
Fix signature to match aten spec; Fix wrangler to bridge from source function torch.nn.functional.upsample_bilinear
1 parent 2cb4245 commit 7d14542

File tree

2 files changed

+35
-8
lines changed

2 files changed

+35
-8
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2295,10 +2295,10 @@ def aten_upsample_bicubic2d_backward(
22952295
@torch_op("aten::upsample_bilinear2d", trace_only=True)
22962296
def aten_upsample_bilinear2d(
22972297
self: TReal,
2298-
output_size: Optional[INT64] = None,
2298+
output_size: Optional[INT64],
2299+
align_corners: bool,
22992300
scales_h: Optional[float] = None,
23002301
scales_w: Optional[float] = None,
2301-
align_corners: bool = False,
23022302
) -> TReal:
23032303
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
23042304

@@ -2319,13 +2319,14 @@ def aten_upsample_bilinear2d(
23192319
@torch_op("aten::upsample_bilinear2d.vec", trace_only=True)
23202320
def aten_upsample_bilinear2d_vec(
23212321
self: TReal,
2322-
output_size: Optional[INT64] = None,
2323-
align_corners: bool = False,
2324-
scale_factors: Optional[Sequence[float]] = None,
2322+
output_size: Optional[INT64],
2323+
align_corners: bool,
2324+
scale_factors: Optional[Sequence[float]],
23252325
) -> TReal:
2326+
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
23262327
scales_h = scale_factors[0] if scale_factors is not None else None
23272328
scales_w = scale_factors[1] if scale_factors is not None else None
2328-
return aten_upsample_bilinear2d(self, output_size, scales_h, scales_w, align_corners)
2329+
return aten_upsample_bilinear2d(self, output_size, align_corners, scales_h, scales_w)
23292330

23302331

23312332
@torch_op("aten::upsample_bilinear2d", private=True)

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,21 @@ def _sum_input_wrangler(
417417
def _upsample_bilinear2d_input_wrangler(
418418
args: list[Any], kwargs: dict[str, Any]
419419
) -> tuple[list[Any], dict[str, Any]]:
420+
# Wrangler for the signature difference between
421+
# 'nn.functional.upsample_bilinear'
422+
# and
423+
# 'aten::upsample_bilinear2d'
424+
# https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html
420425
if "size" in kwargs:
421426
args.append(np.array(kwargs["size"], dtype=np.int64))
422427
del kwargs["size"] # promote tensor type kwargs to args
428+
else:
429+
args.append(None)
430+
if "align_corners" in kwargs:
431+
args.append(kwargs["align_corners"])
432+
del kwargs["align_corners"]
433+
else:
434+
args.append(True) # Fill in the default value
423435
if "scale_factor" in kwargs:
424436
kwargs["scales_h"] = kwargs["scale_factor"]
425437
kwargs["scales_w"] = kwargs["scale_factor"]
@@ -430,12 +442,26 @@ def _upsample_bilinear2d_input_wrangler(
430442
def _upsample_bilinear2d_vec_input_wrangler(
431443
args: list[Any], kwargs: dict[str, Any]
432444
) -> tuple[list[Any], dict[str, Any]]:
445+
# Wrangler for the signature difference between
446+
# 'nn.functional.upsample_bilinear'
447+
# and
448+
# 'aten::upsample_bilinear2d.vec'
449+
# https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html
433450
if "size" in kwargs:
434451
args.append(np.array(kwargs["size"], dtype=np.int64))
435452
del kwargs["size"] # promote tensor type kwargs to args
453+
else:
454+
args.append(None)
455+
if "align_corners" in kwargs:
456+
args.append(kwargs["align_corners"])
457+
del kwargs["align_corners"]
458+
else:
459+
args.append(True) # Fill in the default value
436460
if "scale_factor" in kwargs:
437-
kwargs["scale_factors"] = [kwargs["scale_factor"]] * 2
438-
del kwargs["scale_factor"] # adapt the function signature
461+
args.append([kwargs["scale_factor"]] * 2)
462+
del kwargs["scale_factor"] # promote tensor type kwargs to args
463+
else:
464+
args.append(None)
439465
return args, kwargs
440466

441467

0 commit comments

Comments
 (0)