Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add op(upsample nearst1/2/3d) | feat(torchlib) #1246

Merged
merged 17 commits into from
Feb 5, 2024
22 changes: 13 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2396,12 +2396,12 @@
raise NotImplementedError()


@torch_op("aten::upsample_nearest1d", trace_only=True)
def aten_upsample_nearest1d(
self: TensorType, output_size: INT64, scales: Optional[float] = None
) -> TensorType:
self: TReal, size: INT64, scale_factor: Optional[float] = None
) -> TReal:
"""upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor"""

raise NotImplementedError()
return _aten_upsample_nearest2d_onnx(self, size)


def aten_upsample_nearest1d_backward(
Expand Down Expand Up @@ -2438,7 +2438,9 @@
size: INT64,
) -> TReal:
self_shape = op.Shape(self)
batch_channel = self_shape[:2] # type: ignore[index]
starts = op.Constant(value_ints=[0])
ends = op.Constant(value_ints=[2])
batch_channel = op.Slice(self_shape, starts, ends)
output_size = op.Concat(batch_channel, size, axis=0)

return op.Resize(
Expand Down Expand Up @@ -2466,16 +2468,18 @@
raise NotImplementedError()


@torch_op("aten::upsample_nearest3d", trace_only=True)
def aten_upsample_nearest3d(
self: TensorType,
output_size: INT64,
self: TReal,
size: INT64,
scales_d: Optional[float] = None,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TensorType:
) -> TReal:
"""upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor"""

raise NotImplementedError()
return _aten_upsample_nearest2d_onnx(self, size)

Check warning on line 2481 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L2481

Added line #L2481 was not covered by tests
xiaowuhu marked this conversation as resolved.
Show resolved Hide resolved



def aten_upsample_nearest3d_backward(
Expand Down
1 change: 1 addition & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@
),
kwargs=repr(cpu_sample.kwargs),
):
if i != 0: continue
Fixed Show fixed Hide fixed
test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample, dtype)

with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason):
Expand Down
18 changes: 18 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2141,6 +2141,15 @@ def _where_input_wrangler(
and sample.kwargs.get("scales") is not None,
reason="fixme: align_corners=False output mismatch when scales are provided",
),
TorchLibOpInfo(
"nn.functional.upsample_nearest1d",
nn_ops.aten_upsample_nearest1d,
trace_only=True,
).skip(
# size parameter must be existed
matcher=lambda sample: sample.kwargs.get("size", None) is None,
reason="aten_upsample_nearest1d takes size as input",
),
TorchLibOpInfo(
"nn.functional.upsample_nearest2d",
nn_ops.aten_upsample_nearest2d,
Expand All @@ -2156,6 +2165,15 @@ def _where_input_wrangler(
matcher=lambda sample: "scale_factor" in sample.kwargs,
reason="fixme: the scale_factor tests",
),
TorchLibOpInfo(
"nn.functional.upsample_nearest3d",
nn_ops.aten_upsample_nearest3d,
trace_only=True,
).skip(
# Shape should be [N, C, H, W]
matcher=lambda sample: len(sample.input.shape) != 2 + 3,
reason="only test on 2d inputs",
),
TorchLibOpInfo("ones_like", core_ops.aten_ones_like, trace_only=True),
TorchLibOpInfo(
"roll",
Expand Down
Loading