diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index d018c5d51..bb767071e 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2197,16 +2197,86 @@ def aten_unflatten_dense_tensors( raise NotImplementedError() +@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bicubic2d.vec"), trace_only=True) def aten_upsample_bicubic2d( - self: TensorType, + self: TReal, output_size: INT64, align_corners: bool, - scales_h: Optional[float] = None, - scales_w: Optional[float] = None, -) -> TensorType: - """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + scale_factors: Optional[TFloat] = None, +) -> TReal: + """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + """ - raise NotImplementedError() + if output_size is not None: + result = _aten_upsample_output_size(self, output_size, align_corners, "cubic") + else: + result = _aten_upsample_scales(self, scale_factors, align_corners, "cubic") + return result + + +@torch_op("aten::upsample_bicubic2d", private=True) +def _aten_upsample_output_size( + self: TReal, + output_size: INT64, + align_corners: bool, + str_mode: str, +) -> TReal: + self_shape = op.Shape(self) + starts = op.Constant(value_ints=[0]) + ends = op.Constant(value_ints=[2]) + batch_channel = op.Slice(self_shape, starts, ends) + output_size = op.Concat(batch_channel, output_size, axis=0) + if align_corners: + result = op.Resize( + self, + None, + None, + output_size, + mode=str_mode, + coordinate_transformation_mode="align_corners", + ) + else: + result = op.Resize( + self, + None, + None, + output_size, + mode=str_mode, + coordinate_transformation_mode="pytorch_half_pixel", + ) + + return result + + +@torch_op("aten::upsample_bicubic2d", private=True) +def _aten_upsample_scales( + self: TReal, + scale_factors: TFloat, + align_corners: bool, + str_mode: str, +) -> TReal: + scale_factors = op.Cast(scale_factors, to=FLOAT.dtype) + scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0) + if align_corners: + result = op.Resize( + self, + None, + scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] + None, + mode=str_mode, + coordinate_transformation_mode="align_corners", + ) + else: + result = op.Resize( + self, + None, + scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] + None, + mode=str_mode, + coordinate_transformation_mode="pytorch_half_pixel", + ) + return result def aten_upsample_bicubic2d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 24d962c3d..0224b6cfa 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1409,6 +1409,57 @@ def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(t, args=(dimension, size, step)) +def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + align_corners_options = (True, False) + rank = 2 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + shape(L, rank, False), + align_corners, + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, # output_size + align_corners, + (1.7, 1.7), # scaler + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, # if this is None, the scalar must be list + align_corners, + (0.6, 0.6), + ) + + class _TestParamsMaxPoolEmptyStrideBase: # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 def __init__(self): @@ -1874,6 +1925,13 @@ def __init__(self): sample_inputs_func=sample_inputs_unfold, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_bicubic2d", + aten_name="upsample_bicubic2d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_bicubic2d, + supports_out=False, + ), opinfo_core.OpInfo( "nn.functional.max_pool1d_with_indices", aten_name="max_pool1d_with_indices", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 655e8809a..7aeba0d14 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -2122,6 +2122,11 @@ def _where_input_wrangler( input_wrangler=_upsample_bilinear2d_input_wrangler, trace_only=True, ), + TorchLibOpInfo( + "ops.aten.upsample_bicubic2d", + nn_ops.aten_upsample_bicubic2d, + trace_only=True, + ), TorchLibOpInfo( "nn.functional.upsample_nearest2d", nn_ops.aten_upsample_nearest2d,