Skip to content

Commit

Permalink
Fix signature to match aten spec; Fix wrangler to bridge from source …
Browse files Browse the repository at this point in the history
…function torch.nn.functional.upsample_bilinear
  • Loading branch information
BowenBao committed Jan 16, 2024
1 parent 2cb4245 commit 7d14542
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
13 changes: 7 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2295,10 +2295,10 @@ def aten_upsample_bicubic2d_backward(
@torch_op("aten::upsample_bilinear2d", trace_only=True)
def aten_upsample_bilinear2d(
self: TReal,
output_size: Optional[INT64] = None,
output_size: Optional[INT64],
align_corners: bool,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
align_corners: bool = False,
) -> TReal:
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""

Expand All @@ -2319,13 +2319,14 @@ def aten_upsample_bilinear2d(
@torch_op("aten::upsample_bilinear2d.vec", trace_only=True)
def aten_upsample_bilinear2d_vec(
self: TReal,
output_size: Optional[INT64] = None,
align_corners: bool = False,
scale_factors: Optional[Sequence[float]] = None,
output_size: Optional[INT64],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> TReal:
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
scales_h = scale_factors[0] if scale_factors is not None else None
scales_w = scale_factors[1] if scale_factors is not None else None
return aten_upsample_bilinear2d(self, output_size, scales_h, scales_w, align_corners)
return aten_upsample_bilinear2d(self, output_size, align_corners, scales_h, scales_w)


@torch_op("aten::upsample_bilinear2d", private=True)
Expand Down
30 changes: 28 additions & 2 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,21 @@ def _sum_input_wrangler(
def _upsample_bilinear2d_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Wrangler for the signature difference between
# 'nn.functional.upsample_bilinear'
# and
# 'aten::upsample_bilinear2d'
# https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html
if "size" in kwargs:
args.append(np.array(kwargs["size"], dtype=np.int64))
del kwargs["size"] # promote tensor type kwargs to args
else:

Check warning on line 428 in onnxscript/tests/function_libs/torch_lib/ops_test_data.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tests/function_libs/torch_lib/ops_test_data.py#L428

Added line #L428 was not covered by tests
args.append(None)
if "align_corners" in kwargs:

Check warning on line 430 in onnxscript/tests/function_libs/torch_lib/ops_test_data.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tests/function_libs/torch_lib/ops_test_data.py#L430

Added line #L430 was not covered by tests
args.append(kwargs["align_corners"])
del kwargs["align_corners"]
else:

Check warning on line 433 in onnxscript/tests/function_libs/torch_lib/ops_test_data.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tests/function_libs/torch_lib/ops_test_data.py#L432-L433

Added lines #L432 - L433 were not covered by tests
args.append(True) # Fill in the default value
if "scale_factor" in kwargs:
kwargs["scales_h"] = kwargs["scale_factor"]
kwargs["scales_w"] = kwargs["scale_factor"]
Expand All @@ -430,12 +442,26 @@ def _upsample_bilinear2d_input_wrangler(
def _upsample_bilinear2d_vec_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Wrangler for the signature difference between
# 'nn.functional.upsample_bilinear'
# and
# 'aten::upsample_bilinear2d.vec'
# https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html
if "size" in kwargs:
args.append(np.array(kwargs["size"], dtype=np.int64))
del kwargs["size"] # promote tensor type kwargs to args
else:
args.append(None)
if "align_corners" in kwargs:
args.append(kwargs["align_corners"])
del kwargs["align_corners"]
else:

Check warning on line 458 in onnxscript/tests/function_libs/torch_lib/ops_test_data.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tests/function_libs/torch_lib/ops_test_data.py#L457-L458

Added lines #L457 - L458 were not covered by tests
args.append(True) # Fill in the default value
if "scale_factor" in kwargs:
kwargs["scale_factors"] = [kwargs["scale_factor"]] * 2
del kwargs["scale_factor"] # adapt the function signature
args.append([kwargs["scale_factor"]] * 2)
del kwargs["scale_factor"] # promote tensor type kwargs to args
else:
args.append(None)
return args, kwargs


Expand Down

0 comments on commit 7d14542

Please sign in to comment.