Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add op(upsample nearst1/2/3d) | feat(torchlib) #1246

Merged
merged 17 commits into from
Feb 5, 2024
76 changes: 44 additions & 32 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2201,7 +2201,16 @@ def _get_upsample_align_corners_mode(align_corners: bool) -> str:
return "align_corners" if align_corners else "pytorch_half_pixel"


@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True)
@torch_op(
(
"aten::upsample_bicubic2d",
"aten::upsample_bilinear2d",
"aten::upsample_nearest1d",
"aten::upsample_nearest2d",
"aten::upsample_nearest3d",
),
private=True,
)
def _aten_upsample_output_size(
self: TReal,
output_size: INT64,
Expand Down Expand Up @@ -2240,7 +2249,6 @@ def _aten_upsample_scales(
None,
mode=mode,
coordinate_transformation_mode=coordinate_transformation_mode,
nearest_mode="floor",
)


Expand Down Expand Up @@ -2396,12 +2404,33 @@ def aten_upsample_linear1d_backward(
raise NotImplementedError()


@torch_op("aten::upsample_nearest1d", trace_only=True)
def aten_upsample_nearest1d(
self: TensorType, output_size: INT64, scales: Optional[float] = None
) -> TensorType:
self: TReal, size: INT64, scale_factor: Optional[float] = None
) -> TReal:
"""upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor"""
if size is not None:
return _aten_upsample_output_size(self, size, "nearest", "asymmetric")
else:
return _aten_upsample_nearest1d_scales(self, scale_factor)

raise NotImplementedError()

@torch_op("aten::upsample_nearest1d", private=True)
def _aten_upsample_nearest1d_scales(
self: TReal,
scale_factors: TFloat,
) -> TReal:
scale_factors = op.Cast(scale_factors, to=FLOAT.dtype)
scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0)
return op.Resize(
self,
None,
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
None,
mode="nearest",
coordinate_transformation_mode="asymmetric",
nearest_mode="floor",
)


def aten_upsample_nearest1d_backward(
Expand Down Expand Up @@ -2429,29 +2458,7 @@ def aten_upsample_nearest2d(
del scales_h
del scales_w

return _aten_upsample_nearest2d_onnx(self, size)


@torch_op("aten::upsample_nearest2d", private=True)
def _aten_upsample_nearest2d_onnx(
self: TReal,
size: INT64,
) -> TReal:
self_shape = op.Shape(self)
batch_channel = self_shape[:2] # type: ignore[index]
output_size = op.Concat(batch_channel, size, axis=0)

return op.Resize(
self,
None,
None,
output_size,
mode="nearest",
# NOTE(justinchuby): Both asymmetric and pytorch_half_pixel pass the test
# I used asymmetric because it aligns with the torch.onnx exporter
coordinate_transformation_mode="asymmetric",
nearest_mode="floor",
)
return _aten_upsample_output_size(self, size, "nearest", "asymmetric")


def aten_upsample_nearest2d_backward(
Expand All @@ -2466,16 +2473,21 @@ def aten_upsample_nearest2d_backward(
raise NotImplementedError()


@torch_op("aten::upsample_nearest3d", trace_only=True)
def aten_upsample_nearest3d(
self: TensorType,
output_size: INT64,
self: TReal,
size: INT64,
scales_d: Optional[float] = None,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TensorType:
) -> TReal:
"""upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor"""

raise NotImplementedError()
del scales_h
del scales_w
del scales_d

return _aten_upsample_output_size(self, size, "nearest", "asymmetric")


def aten_upsample_nearest3d_backward(
Expand Down
164 changes: 164 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,149 @@ def shape(size, rank, with_batch_channel=True):
)


def sample_inputs_upsample_nearest1d(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs

N, C = 2, 3
D = 4
SS = 3
L = 5

rank = 1

def shape(size, rank, with_batch_channel=True):
if with_batch_channel:
return tuple([N, C] + ([size] * rank))
return tuple([size] * rank)

make_arg = functools.partial(
torch_testing.make_tensor,
device=device,
dtype=dtype,
requires_grad=requires_grad,
low=-1,
high=1,
)

yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True)

yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(S, rank, False),
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(L, rank, False),
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
None, # output_size
(1.7,), # scaler
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
None, # if this is None, the scalar must be list
(0.6,),
)


def sample_inputs_upsample_nearest2d(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs

N, C = 2, 3
D = 4
SS = 3
L = 5

rank = 2

def shape(size, rank, with_batch_channel=True):
if with_batch_channel:
return tuple([N, C] + ([size] * rank))
return tuple([size] * rank)

make_arg = functools.partial(
torch_testing.make_tensor,
device=device,
dtype=dtype,
requires_grad=requires_grad,
low=-1,
high=1,
)

yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True)

yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(S, rank, False),
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(L, rank, False),
)
# ONNX don't support below cases: both output_size and scaler are not None
# yield opinfo_core.SampleInput(
# make_arg(shape(D, rank)),
# shape(L, rank, False),
# 1.7, # scaler
# )
# yield opinfo_core.SampleInput(
# make_arg(shape(D, rank)),
# shape(L, rank, False),
# 0.6,
# )


def sample_inputs_upsample_nearest3d(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs

N, C = 2, 3
D = 4
SS = 3
L = 5

rank = 3

def shape(size, rank, with_batch_channel=True):
if with_batch_channel:
return tuple([N, C] + ([size] * rank))
return tuple([size] * rank)

make_arg = functools.partial(
torch_testing.make_tensor,
device=device,
dtype=dtype,
requires_grad=requires_grad,
low=-1,
high=1,
)

yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True)

yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(S, rank, False),
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(L, rank, False),
)
# ONNX don't support below cases: both output_size and scaler are not None
# yield opinfo_core.SampleInput(
# make_arg(shape(D, rank)),
# shape(L, rank, False),
# 1.7, # scaler
# )
# yield opinfo_core.SampleInput(
# make_arg(shape(D, rank)),
# shape(L, rank, False),
# 0.6,
# )


def sample_inputs_upsample_trilinear3d(op_info, device, dtype, requires_grad, **kwargs):
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
del op_info
del kwargs
Expand Down Expand Up @@ -2117,6 +2260,27 @@ def __init__(self):
sample_inputs_func=sample_inputs_upsample_linear1d,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.upsample_nearest1d",
aten_name="upsample_nearest1d",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
sample_inputs_func=sample_inputs_upsample_nearest1d,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.upsample_nearest2d",
aten_name="upsample_nearest2d",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
sample_inputs_func=sample_inputs_upsample_nearest2d,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.upsample_nearest3d",
aten_name="upsample_nearest3d",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
sample_inputs_func=sample_inputs_upsample_nearest3d,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.upsample_trilinear3d",
aten_name="upsample_trilinear3d",
Expand Down
47 changes: 13 additions & 34 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,18 +415,6 @@ def _sum_input_wrangler(
return args, kwargs


def _upsample_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "scale_factor" in kwargs:
kwargs["scales_h"] = kwargs["scale_factor"]
kwargs["scales_w"] = kwargs["scale_factor"]
del kwargs["scale_factor"]
if "size" in kwargs:
kwargs["size"] = np.array(kwargs["size"], dtype=np.int64)
return args, kwargs


def _unflatten_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
Expand Down Expand Up @@ -2141,24 +2129,24 @@ def _where_input_wrangler(
reason="fixme: align_corners=False output mismatch when scales are provided",
),
TorchLibOpInfo(
"ops.aten.upsample_trilinear3d",
nn_ops.aten_upsample_trilinear3d,
"ops.aten.upsample_nearest1d",
nn_ops.aten_upsample_nearest1d,
trace_only=True,
),
TorchLibOpInfo(
"nn.functional.upsample_nearest2d",
"ops.aten.upsample_nearest2d",
nn_ops.aten_upsample_nearest2d,
input_wrangler=_upsample_input_wrangler,
trace_only=True,
)
.skip(
# Shape should be [N, C, H, W]
matcher=lambda sample: len(sample.input.shape) != 2 + 2,
reason="only test on 2d inputs",
)
.xfail(
matcher=lambda sample: "scale_factor" in sample.kwargs,
reason="fixme: the scale_factor tests",
),
TorchLibOpInfo(
"ops.aten.upsample_nearest3d",
nn_ops.aten_upsample_nearest3d,
trace_only=True,
),
TorchLibOpInfo(
"ops.aten.upsample_trilinear3d",
nn_ops.aten_upsample_trilinear3d,
trace_only=True,
),
TorchLibOpInfo("ones_like", core_ops.aten_ones_like, trace_only=True),
TorchLibOpInfo(
Expand Down Expand Up @@ -2376,15 +2364,6 @@ def _where_input_wrangler(
"nn.functional.celu",
("nn.functional.celu_type_promoted",),
)
ops_test_common.duplicate_opinfo(
OPS_DB,
"nn.functional.upsample_nearest",
(
"nn.functional.upsample_nearest1d",
"nn.functional.upsample_nearest2d",
"nn.functional.upsample_nearest3d",
),
)
ops_test_common.duplicate_opinfo(
OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",)
)
Expand Down
Loading