From f12a0f4b27bc034c646a2e96a4f20c07b87b2413 Mon Sep 17 00:00:00 2001
From: Ti-Tai Wang
Date: Thu, 21 Dec 2023 11:46:49 -0800
Subject: [PATCH 01/13] Fully support complex dtype | feat(graph_building)
 (#1238)

Prior to this PR, the graph did not support complex-dtype attributes, and
the dim offset of cat/stack was not accounted for with complex dtypes.

Adds:
(1) Support for complex constants (converted to the real representation)
(2) Complex support for cat/stack

NOTE: `aten_scalar_tensor_complex` only supports complex input; when the
input is originally real and must be converted to complex,
`aten_scalar_tensor` is invoked instead, as exercised in the tests.

---------

Co-authored-by: Justin Chu
---
 .../function_libs/torch_lib/graph_building.py |  7 ++++
 .../function_libs/torch_lib/ops/core.py       | 41 ++++++++++++++++++-
 .../function_libs/torch_lib/extra_opinfo.py   | 22 ++++++++++
 .../torch_lib/ops_test_common.py              |  2 +
 .../function_libs/torch_lib/ops_test_data.py  | 11 +++++
 5 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index ad3e72f37..aa318d523 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -34,6 +34,7 @@
     Sequence["TorchScriptTensor"],
     Sequence[float],
     Sequence[int],
+    complex,
     str,
     int,
     float,
@@ -45,6 +46,7 @@
     Sequence["TorchScriptTensor"],
     Sequence[float],
     Sequence[int],
+    complex,
     str,
     int,
     float,
@@ -56,6 +58,7 @@
     Sequence[torch.Value],
     Sequence[float],
     Sequence[int],
+    complex,
     str,
     int,
     float,
@@ -654,6 +657,10 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
             isinstance(val, float) for val in constant
         ):
             constant_tensor = torch.tensor(constant, dtype=torch.float)
+        elif isinstance(constant, complex):
+            # NOTE: ONNX doesn't support tensor of complex64/complex128, so we
+            # convert them to float32/float64 with real representation.
+            constant_tensor = torch.view_as_real(torch.tensor(constant).resolve_conj())
         else:
             raise TypeError(
                 f"Constant input '{constant}' of type '{type(constant)}' is not supported"
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index ab37f8739..2786d6f80 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -17,6 +17,8 @@ from onnxscript import (
     BFLOAT16,
     BOOL,
+    COMPLEX64,
+    COMPLEX128,
     DOUBLE,
     FLOAT,
     FLOAT16,
@@ -1412,6 +1414,15 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::cat", trace_only=True, complex=True)
+def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
+    """cat(Tensor[] tensors, int dim=0) -> Tensor"""
+    # Real representation unsqueezes the last dimension
+    if dim < 0:
+        dim = dim - 1
+    return aten_cat(tensors, dim=dim)
+
+
 @torch_op("aten::cat")
 def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
     """cat(Tensor[] tensors, int dim=0) -> Tensor"""
@@ -7090,6 +7101,25 @@ def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> RealType:
     return common_ops.cast_to(s, dtype=dtype)
 
 
+@torch_op("aten::scalar_tensor", trace_only=True, complex=True)
+def aten_scalar_tensor_complex(
+    s: Union[FLOAT, DOUBLE], dtype: int = COMPLEX64.dtype
+) -> RealType:
+    """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
+    # NOTE: When the input is originally complex, this function is invoked.
+    # On the other hand, when the input is originally real, aten_scalar_tensor
+    # is invoked instead.
+    if dtype == COMPLEX128.dtype:
+        result = op.Cast(s, to=DOUBLE.dtype)
+    elif dtype == COMPLEX64.dtype:
+        result = op.Cast(s, to=FLOAT.dtype)
+    else:
+        # NOTE: No-op for non-complex dtypes; reaching this branch with a
+        # non-complex dtype is potentially a bug in the caller.
+        result = s
+    return result
+
+
 @torch_op("aten::scalar_tensor", trace_only=True)
 def aten_scalar_tensor_sym_number(s: RealType, dtype: int = FLOAT.dtype) -> RealType:
     """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
@@ -7520,6 +7550,15 @@ def aten_sspaddmm(
     raise NotImplementedError()
 
 
+@torch_op("aten::stack", trace_only=True, complex=True)
+def aten_stack_complex(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrString:
+    """stack(Tensor[] tensors, int dim=0) -> Tensor"""
+    # Real representation unsqueezes the last dimension
+    if dim < 0:
+        dim = dim - 1
+    return aten_stack(tensors, dim)
+
+
 @torch_op("aten::stack")
 def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrString:
     """stack(Tensor[] tensors, int dim=0) -> Tensor"""
@@ -8088,7 +8127,7 @@ def aten_triplet_margin_loss(
 
 
 @torch_op("aten::triu")
-def aten_triu(self: TensorType, diagonal: int = 0) -> TensorType:
+def aten_triu(self: TTensor, diagonal: int = 0) -> TTensor:
     """triu(Tensor self, int diagonal=0) -> Tensor"""
 
     return op.Trilu(self, diagonal, upper=1)
diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 60b9eb4f8..24d962c3d 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -20,6 +20,20 @@
 M = 10
 
 
+def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+    del device
+    del requires_grad
+    # Not including a scalar tensor in vals because meta tests start failing due to
+    # lack of meta support for _local_scalar_dense
+    # torch.tensor(2, device=device)
+    vals = (-5j, 0j, 1j)
+
+    for item in vals:
+        yield opinfo_core.SampleInput(item, dtype=dtype)
+
+
 def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
@@ -1881,4 +1895,12 @@ def __init__(self):
         sample_inputs_func=sample_inputs_max_pool3d_with_indices,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.scalar_tensor",
+        aten_name="scalar_tensor",
+        dtypes=common_dtype.complex_types(),
+        sample_inputs_func=sample_inputs_scalar_tensor,
+        supports_autograd=False,
+        supports_out=False,
+    ),
 ]
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
index bc0b08e98..172d183f2 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py
@@ -263,6 +263,8 @@ def convert_tensor_to_numpy(input: Any) -> Any:
         # from complex to real representation
         input = torch.view_as_real(input)
         return input.detach().cpu().numpy()
+    if isinstance(input, complex):
+        return torch.view_as_real(torch.tensor(input)).detach().cpu().numpy()
     if isinstance(input, (tuple, list)):
         if len(input) == 0:
             return np.array((), dtype=np.int64)
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 53d79a1c0..655e8809a 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -715,6 +715,10 @@ def _where_input_wrangler( matcher=lambda sample: sample.input[0].equal(torch.tensor([])), reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), + TorchLibOpInfo("cat", core_ops.aten_cat_complex, trace_only=True, complex=True).skip( + matcher=lambda sample: sample.input[0].equal(torch.tensor([])), + reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", + ), TorchLibOpInfo("ceil", core_ops.aten_ceil), TorchLibOpInfo( "chunk", @@ -1427,6 +1431,12 @@ def _where_input_wrangler( trace_only=True, complex=True, ), + TorchLibOpInfo( + "ops.aten.scalar_tensor", + core_ops.aten_scalar_tensor_complex, + trace_only=True, + complex=True, + ), TorchLibOpInfo( "scatter_add", core_ops.aten_scatter_add, @@ -1541,6 +1551,7 @@ def _where_input_wrangler( reason="this Aten overload only support one tensor as input by design", ), TorchLibOpInfo("stack", core_ops.aten_stack), + TorchLibOpInfo("stack", core_ops.aten_stack_complex, complex=True, trace_only=True), TorchLibOpInfo("sub", core_ops.aten_sub), TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True, trace_only=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB From 471aa16f4753f3116213c5ae84ef5a8d40173035 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 27 Dec 2023 09:29:35 -0800 Subject: [PATCH 02/13] chore(deps): bump onnx-weekly from 1.16.0.dev20231211 to 1.16.0.dev20231225 in /requirements/ci (#1243) --- requirements/ci/requirements-onnx-weekly.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index ab99a52ec..3dbea5df6 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.16.0.dev20231211 +onnx-weekly==1.16.0.dev20231225 From 172a81745c10467e04d9b4fdfa7829ce3321b6ce Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 27 Dec 2023 09:49:25 -0800 Subject: [PATCH 03/13] chore(deps): bump ruff from 0.1.8 to 0.1.9 in /requirements/lintrunner (#1240) --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index a3e457cb4..27066d436 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.1.8 +ruff==0.1.9 # MYPY mypy==1.7.1 types-PyYAML==6.0.12.11 From 1fa1ed64f50bcbdd570604a3d47ecc1037c1b3c5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Jan 2024 19:59:01 -0800 Subject: [PATCH 04/13] chore(deps): bump onnx-weekly from 1.16.0.dev20231225 to 1.16.0.dev20240101 in /requirements/ci (#1244) --- requirements/ci/requirements-onnx-weekly.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index 
3dbea5df6..aeed888ac 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.16.0.dev20231225 +onnx-weekly==1.16.0.dev20240101 From 1231cc0d09ad8143725f02eb202b6819d48a888c Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 3 Jan 2024 13:04:12 +0800 Subject: [PATCH 05/13] AddOp(upsample_bicubic2d) | feat(torchlib) (#1208) Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/nn.py | 82 +++++++++++++++++-- .../function_libs/torch_lib/extra_opinfo.py | 58 +++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 5 ++ 3 files changed, 139 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index d018c5d51..bb767071e 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2197,16 +2197,86 @@ def aten_unflatten_dense_tensors( raise NotImplementedError() +@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bicubic2d.vec"), trace_only=True) def aten_upsample_bicubic2d( - self: TensorType, + self: TReal, output_size: INT64, align_corners: bool, - scales_h: Optional[float] = None, - scales_w: Optional[float] = None, -) -> TensorType: - """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + scale_factors: Optional[TFloat] = None, +) -> TReal: + """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + """ - raise NotImplementedError() + if output_size is not None: + result = _aten_upsample_output_size(self, output_size, align_corners, "cubic") + else: + result = _aten_upsample_scales(self, scale_factors, align_corners, "cubic") + return result + + +@torch_op("aten::upsample_bicubic2d", private=True) +def _aten_upsample_output_size( + self: TReal, + output_size: INT64, + align_corners: bool, + str_mode: str, +) -> TReal: + self_shape = op.Shape(self) + starts = op.Constant(value_ints=[0]) + ends = op.Constant(value_ints=[2]) + batch_channel = op.Slice(self_shape, starts, ends) + output_size = op.Concat(batch_channel, output_size, axis=0) + if align_corners: + result = op.Resize( + self, + None, + None, + output_size, + mode=str_mode, + coordinate_transformation_mode="align_corners", + ) + else: + result = op.Resize( + self, + None, + None, + output_size, + mode=str_mode, + coordinate_transformation_mode="pytorch_half_pixel", + ) + + return result + + +@torch_op("aten::upsample_bicubic2d", private=True) +def _aten_upsample_scales( + self: TReal, + scale_factors: TFloat, + align_corners: bool, + str_mode: str, +) -> TReal: + scale_factors = op.Cast(scale_factors, to=FLOAT.dtype) + scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0) + if align_corners: + result = op.Resize( + self, + None, + scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] + None, + mode=str_mode, + coordinate_transformation_mode="align_corners", + ) + else: + result = op.Resize( + self, + None, + scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] + None, + mode=str_mode, + coordinate_transformation_mode="pytorch_half_pixel", + ) + return result def aten_upsample_bicubic2d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py 
b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 24d962c3d..0224b6cfa 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1409,6 +1409,57 @@ def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(t, args=(dimension, size, step)) +def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + align_corners_options = (True, False) + rank = 2 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + shape(L, rank, False), + align_corners, + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, # output_size + align_corners, + (1.7, 1.7), # scaler + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, # if this is None, the scalar must be list + align_corners, + (0.6, 0.6), + ) + + class _TestParamsMaxPoolEmptyStrideBase: # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 def __init__(self): @@ -1874,6 +1925,13 @@ def __init__(self): sample_inputs_func=sample_inputs_unfold, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_bicubic2d", + aten_name="upsample_bicubic2d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_bicubic2d, + supports_out=False, + ), opinfo_core.OpInfo( "nn.functional.max_pool1d_with_indices", aten_name="max_pool1d_with_indices", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 655e8809a..7aeba0d14 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -2122,6 +2122,11 @@ def _where_input_wrangler( input_wrangler=_upsample_bilinear2d_input_wrangler, trace_only=True, ), + TorchLibOpInfo( + "ops.aten.upsample_bicubic2d", + nn_ops.aten_upsample_bicubic2d, + trace_only=True, + ), TorchLibOpInfo( "nn.functional.upsample_nearest2d", nn_ops.aten_upsample_nearest2d, From 7bb78bc6b991ed3071397a192ab820eca0bdef26 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Jan 2024 11:06:12 -0500 Subject: [PATCH 06/13] chore(deps): bump ruff from 0.1.9 to 0.1.11 in /requirements/lintrunner (#1248) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [ruff](https://github.com/astral-sh/ruff) from 0.1.9 to 0.1.11.
Changelog

Sourced from ruff's changelog.

0.1.11

Preview features

  • [pylint] Implement super-without-brackets (W0245) (#9257)

Bug fixes

  • Check path string properly in python -m ruff invocations (#9367)

Documentation

  • Tweak relative-imports message (#9365)
  • Add fix safety note for yield-in-for-loop (#9364)

0.1.10

Preview features

  • Improve dummy_implementations preview style formatting (#9240)
  • Normalise Hex and unicode escape sequences in strings (#9280)
  • Parenthesize long type annotations in annotated assignments (#9210)
  • Parenthesize multi-context managers in with statements (#9222)
  • [flake8-pyi] Implement generator-return-from-iter-method (PYI058) (#9313)
  • [pylint] Implement empty-comment (PLR2044) (#9174)
  • [refurb] Implement bit-count (FURB161) (#9265)
  • [ruff] Add never-union rule to detect redundant typing.NoReturn and typing.Never (#9217)

CLI

  • Add paths to TOML parse errors (#9358)
  • Add row and column numbers to formatter parse errors (#9321)
  • Improve responsiveness when invoked via Python (#9315)
  • Short rule messages should not end with a period (#9345)

Configuration

  • Respect runtime-required decorators on functions (#9317)

Bug fixes

  • Avoid asyncio-dangling-task for nonlocal and global bindings (#9263)
  • Escape trailing placeholders in rule documentation (#9301)
  • Fix continuation detection following multi-line strings (#9332)
  • Fix scoping for generators in named expressions in classes (#9248)
  • Port from obsolete wsl crate to is-wsl (#9356)
  • Remove special pre-visit for module docstrings (#9261)
  • Respect __str__ definitions from super classes (#9338)
  • Respect unused-noqa via per-file-ignores (#9300)
  • Respect attribute chains when resolving builtin call paths (#9309)

... (truncated)

Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index 27066d436..20e15b0d4 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.1.9 +ruff==0.1.11 # MYPY mypy==1.7.1 types-PyYAML==6.0.12.11 From 56d6d62ae3c248f9c68c895a6fce418f08fa8521 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Fri, 12 Jan 2024 11:38:02 -0800 Subject: [PATCH 07/13] Add 'aten_upsample_bilinear2d_vec' for unet (#1249) --- onnxscript/function_libs/torch_lib/ops/nn.py | 14 +++++++++++++- .../function_libs/torch_lib/ops_test_data.py | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index bb767071e..e9ee699aa 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2306,11 +2306,23 @@ def aten_upsample_bilinear2d( result = _aten_upsample_bilinear2d_output_size(self, output_size) else: assert scales_h is not None - assert scales_h == scales_w + assert scales_h == scales_w, f"scale_h({scales_h}) != scale_w({scales_w})" result = _aten_upsample_bilinear2d_scales(self, scales_h, scales_w) return result +@torch_op("aten::upsample_bilinear2d.vec", trace_only=True) +def aten_upsample_bilinear2d_vec( + self: TReal, + output_size: Optional[INT64] = None, + align_corners: bool = True, + scale_factors: Optional[Sequence[float]] = None, +) -> TReal: + scales_h = scale_factors[0] if scale_factors is not None else None + scales_w = scale_factors[1] if scale_factors is not None else None + return aten_upsample_bilinear2d(self, output_size, scales_h, scales_w, align_corners) + + @torch_op("aten::upsample_bilinear2d", private=True) def _aten_upsample_bilinear2d_output_size( self: TReal, diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 7aeba0d14..cabc13268 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -427,6 +427,18 @@ def _upsample_bilinear2d_input_wrangler( return args, kwargs +def _upsample_bilinear2d_vec_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if "size" in kwargs: + args.append(np.array(kwargs["size"], dtype=np.int64)) + del kwargs["size"] # promote tensor type kwargs to args + if "scale_factor" in kwargs: + kwargs["scale_factors"] = [kwargs["scale_factor"]] * 2 + del kwargs["scale_factor"] # adapt the function signature + return args, kwargs + + def _upsample_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -2122,6 +2134,12 @@ def _where_input_wrangler( input_wrangler=_upsample_bilinear2d_input_wrangler, trace_only=True, ), + TorchLibOpInfo( + "nn.functional.upsample_bilinear2d", + nn_ops.aten_upsample_bilinear2d_vec, + input_wrangler=_upsample_bilinear2d_vec_input_wrangler, + trace_only=True, + ), TorchLibOpInfo( "ops.aten.upsample_bicubic2d", nn_ops.aten_upsample_bicubic2d, From 3b2291c898bc054fb03e2093246dd768953526d4 Mon Sep 17 00:00:00 2001 From: 
"dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jan 2024 16:38:04 -0800 Subject: [PATCH 08/13] chore(deps): bump ruff from 0.1.11 to 0.1.13 in /requirements/lintrunner (#1252) --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index 20e15b0d4..b72cf4132 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.1.11 +ruff==0.1.13 # MYPY mypy==1.7.1 types-PyYAML==6.0.12.11 From 5b9c31854123eb727f48c0d8d45d67c65d5dc1de Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 16 Jan 2024 11:40:02 -0800 Subject: [PATCH 09/13] Implement torchvision operator `nms` | feat(torchlib) (#1253) - Create the scaffold and tests to support torchvision ops. - Implement `torchvision::nms` --- noxfile.py | 16 ++++++++-- .../function_libs/torch_lib/ops/__init__.py | 3 +- .../function_libs/torch_lib/ops/vision.py | 24 +++++++++++++++ .../function_libs/torch_lib/extra_opinfo.py | 29 +++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 2 ++ .../ci/requirements-pytorch-nightly.txt | 1 + 6 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 onnxscript/function_libs/torch_lib/ops/vision.py diff --git a/noxfile.py b/noxfile.py index 450ca94c4..f3aa7656a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -29,6 +29,7 @@ ONNX = "onnx==1.14.1" ONNX_RUNTIME = "onnxruntime==1.16.1" PYTORCH = "torch==2.1.0" +TORCHVISON = "torchvision==0.16" ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = ( "flatbuffers", "coloredlogs", @@ -52,6 +53,7 @@ def test(session): session.install( *COMMON_TEST_DEPENDENCIES, PYTORCH, + TORCHVISON, ONNX, ONNX_RUNTIME, ) @@ -78,7 +80,7 @@ def test_torch_nightly(session): @nox.session(tags=["test-onnx-weekly"]) def test_onnx_weekly(session): """Test with ONNX weekly (preview) build.""" - session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH) + session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON) session.install("-r", "requirements/ci/requirements-onnx-weekly.txt") session.install(".", "--no-deps") session.run("pip", "list") @@ -89,7 +91,11 @@ def test_onnx_weekly(session): def test_ort_nightly(session): """Test with ONNX Runtime nightly builds.""" session.install( - *COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES + *COMMON_TEST_DEPENDENCIES, + PYTORCH, + TORCHVISON, + ONNX, + *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES, ) session.install("-r", "requirements/ci/requirements-ort-nightly.txt") session.install(".", "--no-deps") @@ -101,7 +107,11 @@ def test_ort_nightly(session): def test_experimental_torchlib_tracing(session): """Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on.""" session.install( - *COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES + *COMMON_TEST_DEPENDENCIES, + PYTORCH, + TORCHVISON, + ONNX, + *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES, ) session.install("-r", "requirements/ci/requirements-ort-nightly.txt") session.install(".", "--no-deps") diff --git a/onnxscript/function_libs/torch_lib/ops/__init__.py b/onnxscript/function_libs/torch_lib/ops/__init__.py index 9e3c552d8..5a1cfd76c 100644 --- a/onnxscript/function_libs/torch_lib/ops/__init__.py +++ b/onnxscript/function_libs/torch_lib/ops/__init__.py @@ -7,6 +7,7 @@ "prims", "sparse", "special", + 
"vision", ] -from . import core, fft, linalg, nested, nn, prims, sparse, special +from . import core, fft, linalg, nested, nn, prims, sparse, special, vision diff --git a/onnxscript/function_libs/torch_lib/ops/vision.py b/onnxscript/function_libs/torch_lib/ops/vision.py new file mode 100644 index 000000000..a2f5cb0ee --- /dev/null +++ b/onnxscript/function_libs/torch_lib/ops/vision.py @@ -0,0 +1,24 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" +"""torchvision operators.""" +from __future__ import annotations + +from onnxscript.function_libs.torch_lib.registration import torch_op +from onnxscript.onnx_opset import opset18 as op +from onnxscript.onnx_types import FLOAT, INT64 + +_INT64_MAX = 0x7FFFFFFFFFFFFFFF + + +@torch_op("torchvision::nms") +def torchvision_nms(boxes: FLOAT, scores: FLOAT, iou_threshold: float) -> INT64: + # boxes: [num_batches, spatial_dimension, 4] + boxes = op.Unsqueeze(boxes, [0]) + # scores: [num_batches, num_classes, spatial_dimension] + scores = op.Unsqueeze(scores, [0, 1]) + # nms_out: [num_selected_indices, 3] where each column is [batch_index, class_index, box_index] + nms_out = op.NonMaxSuppression(boxes, scores, _INT64_MAX, iou_threshold) + return op.Reshape(op.Slice(nms_out, axes=[1], starts=[2], ends=[3]), [-1]) diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 0224b6cfa..a2243cb08 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -8,6 +8,7 @@ from typing import Any, List import torch +import torchvision from torch import testing as torch_testing from torch.testing._internal import ( common_device_type, @@ -997,6 +998,27 @@ def sample_inputs__native_batch_norm_legit_no_stats( ) +def sample_inputs_non_max_suppression(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + boxes = torch.tensor( + [ + [0.0, 0.0, 10.0, 10.0], + [10.0, 10.0, 20.0, 20.0], + [32.0, 32.0, 40.0, 52.0], + ], + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + scores = torch.tensor( + [0.8, 0.4, 0.6], device=device, dtype=dtype, requires_grad=requires_grad + ) + + for iou_threshold in (0.3, 0.5, 0.7, 0.9): + yield opinfo_core.SampleInput(boxes, args=(scores, iou_threshold)) + + def sample_inputs_normal_tensor_float(op_info, device, dtype, requires_grad, **kwargs): del op_info del requires_grad @@ -1961,4 +1983,11 @@ def __init__(self): supports_autograd=False, supports_out=False, ), + opinfo_core.OpInfo( + "torchvision.ops.nms", + op=torchvision.ops.nms, + dtypes=common_dtype.floating_types(), + sample_inputs_func=sample_inputs_non_max_suppression, + supports_out=False, + ), ] diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index cabc13268..836a678e8 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -52,6 +52,7 @@ from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops from onnxscript.function_libs.torch_lib.ops import nn as nn_ops from onnxscript.function_libs.torch_lib.ops import special as 
special_ops +from onnxscript.function_libs.torch_lib.ops import vision as vision_ops from onnxscript.tests.function_libs.torch_lib import extra_opinfo, ops_test_common # Create a copy of the op_db to modify @@ -2300,6 +2301,7 @@ def _where_input_wrangler( reason="this Aten overload only support when correction attribute exists", ), TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True), + TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms), ) ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims")) diff --git a/requirements/ci/requirements-pytorch-nightly.txt b/requirements/ci/requirements-pytorch-nightly.txt index bd46fa2e3..e2a0b33f0 100644 --- a/requirements/ci/requirements-pytorch-nightly.txt +++ b/requirements/ci/requirements-pytorch-nightly.txt @@ -1,3 +1,4 @@ --index-url=https://download.pytorch.org/whl/nightly/cpu --pre torch +torchvision From bec23adc815406e6103dff8463e3386a1be155e7 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Tue, 16 Jan 2024 15:23:20 -0800 Subject: [PATCH 10/13] Fix upsample_bilinear to respect align_corner argument (#1254) Fixes https://github.com/microsoft/onnxscript/issues/1159#issuecomment-1888089502 which indeed turns out to be a problem uncovered by PyTorch CI https://github.com/pytorch/pytorch/actions/runs/7508784822/job/20445196351?pr=117314. > Fixes `align_corner` default value. The default value from pytorch signature is `False` https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html#torch.nn.Upsample. > That said, it shouldn't matter since `align_corner` in aten signature in `native_functions.yaml` is a required argument, so in practice this function will never be invoked w/o `align_corner`. Above is outdated. The case is more complicated. https://github.com/microsoft/onnxscript/pull/1254#discussion_r1453948494. In short this PR fixes the torchlib op signature to match with aten spec, and updates input wrangler for test case to bridge from sample test inputs for function `torch.nn.functional.upsample_bilinear`. --------- Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/nn.py | 30 ++++++++++++------- .../function_libs/torch_lib/ops_test_data.py | 30 +++++++++++++++++-- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index e9ee699aa..9a105482c 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2295,38 +2295,45 @@ def aten_upsample_bicubic2d_backward( @torch_op("aten::upsample_bilinear2d", trace_only=True) def aten_upsample_bilinear2d( self: TReal, - output_size: Optional[INT64] = None, + output_size: INT64, + align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None, - align_corners: bool = True, # pylint: disable=unused-argument ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor""" + coordinate_transformation_mode = "align_corners" if align_corners else "pytorch_half_pixel" if output_size is not None: - result = _aten_upsample_bilinear2d_output_size(self, output_size) + result = _aten_upsample_bilinear2d_output_size( + self, output_size, coordinate_transformation_mode + ) else: assert scales_h is not None assert scales_h == scales_w, f"scale_h({scales_h}) != scale_w({scales_w})" - result = _aten_upsample_bilinear2d_scales(self, scales_h, scales_w) + result = _aten_upsample_bilinear2d_scales( + self, scales_h, scales_w, coordinate_transformation_mode + ) return result @torch_op("aten::upsample_bilinear2d.vec", trace_only=True) def aten_upsample_bilinear2d_vec( self: TReal, - output_size: Optional[INT64] = None, - align_corners: bool = True, - scale_factors: Optional[Sequence[float]] = None, + output_size: Optional[INT64], + align_corners: bool, + scale_factors: Optional[Sequence[float]], ) -> TReal: + """upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" scales_h = scale_factors[0] if scale_factors is not None else None scales_w = scale_factors[1] if scale_factors is not None else None - return aten_upsample_bilinear2d(self, output_size, scales_h, scales_w, align_corners) + return aten_upsample_bilinear2d(self, output_size, align_corners, scales_h, scales_w) @torch_op("aten::upsample_bilinear2d", private=True) def _aten_upsample_bilinear2d_output_size( self: TReal, output_size: INT64, + coordinate_transformation_mode: str, ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" @@ -2341,7 +2348,8 @@ def _aten_upsample_bilinear2d_output_size( None, output_size, mode="linear", - coordinate_transformation_mode="align_corners", + coordinate_transformation_mode=coordinate_transformation_mode, + nearest_mode="floor", ) @@ -2350,6 +2358,7 @@ def _aten_upsample_bilinear2d_scales( self: TReal, scales_h: float, scales_w: float, + coordinate_transformation_mode: str, ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor""" @@ -2366,7 +2375,8 @@ def _aten_upsample_bilinear2d_scales( scales, # format should be: [1.0, 1.0, scale_h, scale_w] None, mode="linear", - coordinate_transformation_mode="align_corners", + coordinate_transformation_mode=coordinate_transformation_mode, + nearest_mode="floor", ) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 836a678e8..67efe8047 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -418,9 +418,21 @@ def _sum_input_wrangler( def _upsample_bilinear2d_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: + # Wrangler for the signature difference between + # 'nn.functional.upsample_bilinear' + # and + # 'aten::upsample_bilinear2d' + # https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html if "size" in kwargs: args.append(np.array(kwargs["size"], dtype=np.int64)) del kwargs["size"] # promote tensor type kwargs to args + else: + args.append(None) + if "align_corners" in kwargs: + args.append(kwargs["align_corners"]) + del kwargs["align_corners"] + else: + args.append(True) # Fill in the default value if "scale_factor" in kwargs: kwargs["scales_h"] = kwargs["scale_factor"] kwargs["scales_w"] = kwargs["scale_factor"] @@ -431,12 +443,26 @@ def _upsample_bilinear2d_input_wrangler( def _upsample_bilinear2d_vec_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: + # Wrangler for the signature difference between + # 'nn.functional.upsample_bilinear' + # and + # 'aten::upsample_bilinear2d.vec' + # https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html if "size" in kwargs: args.append(np.array(kwargs["size"], dtype=np.int64)) del kwargs["size"] # promote tensor type kwargs to args + else: + args.append(None) + if "align_corners" in kwargs: + args.append(kwargs["align_corners"]) + del kwargs["align_corners"] + else: + args.append(True) # Fill in the default value if "scale_factor" in kwargs: - kwargs["scale_factors"] = [kwargs["scale_factor"]] * 2 - del kwargs["scale_factor"] # adapt the function signature + args.append([kwargs["scale_factor"]] * 2) + del kwargs["scale_factor"] # promote tensor type kwargs to args + else: + args.append(None) return args, kwargs From 4a85d3ff638cdbeb342b5d732a1187b4e1b21fd3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 17 Jan 2024 17:13:19 -0800 Subject: [PATCH 11/13] Refactor and fix upsample 2d | fix(torchlib) (#1255) - Refactor upsample 2d functions to use a common set of logics - Add additional tests (`align_corners=False`; test the `aten::upsample_bicubic2d.vec` overload) - Fix implementation for `aten::upsample_bicubic2d` by isolating the `aten::upsample_bicubic2d.vec` overload because it has a different signature. 
- xfail `align_corners=False` and when scale_w and scale_h are specified --- onnxscript/function_libs/torch_lib/ops/nn.py | 213 ++++++++---------- .../function_libs/torch_lib/extra_opinfo.py | 99 +++++++- .../function_libs/torch_lib/ops_test_data.py | 77 ++----- 3 files changed, 201 insertions(+), 188 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 9a105482c..731d12370 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2197,85 +2197,99 @@ def aten_unflatten_dense_tensors( raise NotImplementedError() -@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bicubic2d.vec"), trace_only=True) -def aten_upsample_bicubic2d( - self: TReal, - output_size: INT64, - align_corners: bool, - scale_factors: Optional[TFloat] = None, -) -> TReal: - """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor - """ +def _get_upsample_align_corners_mode(align_corners: bool) -> str: + return "align_corners" if align_corners else "pytorch_half_pixel" - if output_size is not None: - result = _aten_upsample_output_size(self, output_size, align_corners, "cubic") - else: - result = _aten_upsample_scales(self, scale_factors, align_corners, "cubic") - return result - -@torch_op("aten::upsample_bicubic2d", private=True) +@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True) def _aten_upsample_output_size( self: TReal, output_size: INT64, - align_corners: bool, - str_mode: str, + mode: str, + coordinate_transformation_mode: str, ) -> TReal: self_shape = op.Shape(self) starts = op.Constant(value_ints=[0]) ends = op.Constant(value_ints=[2]) batch_channel = op.Slice(self_shape, starts, ends) output_size = op.Concat(batch_channel, output_size, axis=0) - if align_corners: - result = op.Resize( - self, - None, - None, - output_size, - mode=str_mode, - coordinate_transformation_mode="align_corners", - ) - else: - result = op.Resize( - self, - None, - None, - output_size, - mode=str_mode, - coordinate_transformation_mode="pytorch_half_pixel", - ) - - return result + return op.Resize( + self, + None, + None, + output_size, + mode=mode, + coordinate_transformation_mode=coordinate_transformation_mode, + nearest_mode="floor", + ) -@torch_op("aten::upsample_bicubic2d", private=True) +@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True) def _aten_upsample_scales( self: TReal, scale_factors: TFloat, - align_corners: bool, - str_mode: str, + mode: str, + coordinate_transformation_mode: str, ) -> TReal: scale_factors = op.Cast(scale_factors, to=FLOAT.dtype) scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0) - if align_corners: - result = op.Resize( + return op.Resize( + self, + None, + scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] + None, + mode=mode, + coordinate_transformation_mode=coordinate_transformation_mode, + nearest_mode="floor", + ) + + +@torch_op("aten::upsample_bicubic2d", trace_only=True) +def aten_upsample_bicubic2d( + self: TReal, + output_size: INT64, + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> TReal: + """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor""" + + # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch, + # unless when align_corners is True, in which case we do not know what is going on. + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + mode="cubic", + coordinate_transformation_mode=coordinate_transformation_mode, + ) + + +@torch_op("aten::upsample_bicubic2d.vec", trace_only=True) +def aten_upsample_bicubic2d_vec( + self: TReal, + output_size: INT64, + align_corners: bool, + scale_factors: Optional[Sequence[float]], +) -> TReal: + """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" + + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + if scale_factors is not None: + result = _aten_upsample_scales( self, - None, - scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] - None, - mode=str_mode, - coordinate_transformation_mode="align_corners", + op.Constant(value_floats=scale_factors), + mode="cubic", + coordinate_transformation_mode=coordinate_transformation_mode, ) else: - result = op.Resize( + result = _aten_upsample_output_size( self, - None, - scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] - None, - mode=str_mode, - coordinate_transformation_mode="pytorch_half_pixel", + output_size, + mode="cubic", + coordinate_transformation_mode=coordinate_transformation_mode, ) + return result @@ -2302,18 +2316,15 @@ def aten_upsample_bilinear2d( ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" - coordinate_transformation_mode = "align_corners" if align_corners else "pytorch_half_pixel" - if output_size is not None: - result = _aten_upsample_bilinear2d_output_size( - self, output_size, coordinate_transformation_mode - ) - else: - assert scales_h is not None - assert scales_h == scales_w, f"scale_h({scales_h}) != scale_w({scales_w})" - result = _aten_upsample_bilinear2d_scales( - self, scales_h, scales_w, coordinate_transformation_mode - ) - return result + # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch, + # unless when align_corners is True, in which case we do not know what is going on. + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + coordinate_transformation_mode=coordinate_transformation_mode, + mode="linear", + ) @torch_op("aten::upsample_bilinear2d.vec", trace_only=True) @@ -2324,60 +2335,24 @@ def aten_upsample_bilinear2d_vec( scale_factors: Optional[Sequence[float]], ) -> TReal: """upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" - scales_h = scale_factors[0] if scale_factors is not None else None - scales_w = scale_factors[1] if scale_factors is not None else None - return aten_upsample_bilinear2d(self, output_size, align_corners, scales_h, scales_w) - -@torch_op("aten::upsample_bilinear2d", private=True) -def _aten_upsample_bilinear2d_output_size( - self: TReal, - output_size: INT64, - coordinate_transformation_mode: str, -) -> TReal: - """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor""" - - self_shape = op.Shape(self) - starts = op.Constant(value_ints=[0]) - ends = op.Constant(value_ints=[2]) - batch_channel = op.Slice(self_shape, starts, ends) - output_size = op.Concat(batch_channel, output_size, axis=0) - return op.Resize( - self, - None, - None, - output_size, - mode="linear", - coordinate_transformation_mode=coordinate_transformation_mode, - nearest_mode="floor", - ) - - -@torch_op("aten::upsample_bilinear2d", private=True) -def _aten_upsample_bilinear2d_scales( - self: TReal, - scales_h: float, - scales_w: float, - coordinate_transformation_mode: str, -) -> TReal: - """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + if scale_factors is not None: + result = _aten_upsample_scales( + self, + op.Constant(value_floats=scale_factors), + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) + else: + result = _aten_upsample_output_size( + self, + output_size, + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) - neg_1 = op.Constant(value_ints=[-1]) - scales = op.Concat( - op.Constant(value_floats=[1.0, 1.0]), - op.Reshape(op.Constant(value_float=scales_h), neg_1), - op.Reshape(op.Constant(value_float=scales_w), neg_1), - axis=0, - ) - return op.Resize( - self, - None, - scales, # format should be: [1.0, 1.0, scale_h, scale_w] - None, - mode="linear", - coordinate_transformation_mode=coordinate_transformation_mode, - nearest_mode="floor", - ) + return result def aten_upsample_bilinear2d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index a2243cb08..26920953a 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1431,7 +1431,7 @@ def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(t, args=(dimension, size, step)) -def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_upsample_2d(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs @@ -1470,15 +1470,77 @@ def shape(size, rank, with_batch_channel=True): ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), - None, # output_size - align_corners, - (1.7, 1.7), # scaler + args=(shape(L, rank, False), align_corners), + kwargs=dict(scales_h=0.6, scales_w=4.2), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=(shape(L, rank, False), align_corners), + kwargs=dict(scales_h=4.2, scales_w=0.6), + ) + + +def sample_inputs_upsample_2d_vec(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + align_corners_options = (True, False) + rank = 2 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True, None) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners, None ) yield 
opinfo_core.SampleInput( make_arg(shape(D, rank)), - None, # if this is None, the scalar must be list + shape(L, rank, False), align_corners, - (0.6, 0.6), + None, + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=( + None, # output_size + align_corners, + ), + kwargs=dict(scale_factors=(1.7, 1.7)), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=( + None, # if this is None, the scalar must be list + align_corners, + ), + kwargs=dict(scale_factors=(0.6, 0.6)), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + args=( + None, # if this is None, the scalar must be list + align_corners, + ), + kwargs=dict(scale_factors=(0.6, 4.2)), ) @@ -1948,10 +2010,31 @@ def __init__(self): supports_out=False, ), opinfo_core.OpInfo( - "ops.aten.upsample_bicubic2d", + "ops.aten.upsample_bicubic2d.default", aten_name="upsample_bicubic2d", dtypes=common_dtype.floating_types_and(torch.bfloat16), - sample_inputs_func=sample_inputs_upsample_bicubic2d, + sample_inputs_func=sample_inputs_upsample_2d, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.upsample_bicubic2d.vec", + aten_name="upsample_bicubic2d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_2d_vec, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.upsample_bilinear2d.default", + aten_name="upsample_bilinear2d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_2d, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.upsample_bilinear2d.vec", + aten_name="upsample_bilinear2d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_2d_vec, supports_out=False, ), opinfo_core.OpInfo( diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 67efe8047..368a05170 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -415,57 +415,6 @@ def _sum_input_wrangler( return args, kwargs -def _upsample_bilinear2d_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - # Wrangler for the signature difference between - # 'nn.functional.upsample_bilinear' - # and - # 'aten::upsample_bilinear2d' - # https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html - if "size" in kwargs: - args.append(np.array(kwargs["size"], dtype=np.int64)) - del kwargs["size"] # promote tensor type kwargs to args - else: - args.append(None) - if "align_corners" in kwargs: - args.append(kwargs["align_corners"]) - del kwargs["align_corners"] - else: - args.append(True) # Fill in the default value - if "scale_factor" in kwargs: - kwargs["scales_h"] = kwargs["scale_factor"] - kwargs["scales_w"] = kwargs["scale_factor"] - del kwargs["scale_factor"] # adapt the function signature - return args, kwargs - - -def _upsample_bilinear2d_vec_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - # Wrangler for the signature difference between - # 'nn.functional.upsample_bilinear' - # and - # 'aten::upsample_bilinear2d.vec' - # https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html - if "size" in kwargs: - args.append(np.array(kwargs["size"], dtype=np.int64)) - del kwargs["size"] # promote tensor type kwargs to args - else: - args.append(None) - if 
"align_corners" in kwargs: - args.append(kwargs["align_corners"]) - del kwargs["align_corners"] - else: - args.append(True) # Fill in the default value - if "scale_factor" in kwargs: - args.append([kwargs["scale_factor"]] * 2) - del kwargs["scale_factor"] # promote tensor type kwargs to args - else: - args.append(None) - return args, kwargs - - def _upsample_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -2156,21 +2105,32 @@ def _where_input_wrangler( test_class_name="TestOutputConsistencyEager", ), TorchLibOpInfo( - "nn.functional.upsample_bilinear2d", + "ops.aten.upsample_bilinear2d.default", nn_ops.aten_upsample_bilinear2d, - input_wrangler=_upsample_bilinear2d_input_wrangler, trace_only=True, + ).xfail( + matcher=lambda sample: sample.args[1] is False + and sample.kwargs.get("scales_h") is not None, + reason="fixme: align_corners=False output mismatch when scales are provided", ), TorchLibOpInfo( - "nn.functional.upsample_bilinear2d", + "ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec, - input_wrangler=_upsample_bilinear2d_vec_input_wrangler, trace_only=True, ), TorchLibOpInfo( - "ops.aten.upsample_bicubic2d", + "ops.aten.upsample_bicubic2d.default", nn_ops.aten_upsample_bicubic2d, trace_only=True, + ).xfail( + matcher=lambda sample: sample.args[1] is False + and sample.kwargs.get("scales_h") is not None, + reason="fixme: align_corners=False output mismatch when scales are provided", + ), + TorchLibOpInfo( + "ops.aten.upsample_bicubic2d.vec", + nn_ops.aten_upsample_bicubic2d_vec, + trace_only=True, ), TorchLibOpInfo( "nn.functional.upsample_nearest2d", @@ -2403,11 +2363,6 @@ def _where_input_wrangler( "nn.functional.celu", ("nn.functional.celu_type_promoted",), ) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.upsample_bilinear", - ("nn.functional.upsample_bilinear2d",), -) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.upsample_nearest", From 87852317fa8eb9e01f835608a381ed1075e22477 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 18 Jan 2024 09:32:26 +0800 Subject: [PATCH 12/13] Add op(upsample linear1d) | feat(torchlib) (#1245) Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/nn.py | 16 ++++-- .../function_libs/torch_lib/extra_opinfo.py | 56 ++++++++++++++++--- .../function_libs/torch_lib/ops_test_data.py | 9 +++ 3 files changed, 70 insertions(+), 11 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 731d12370..50f62701d 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2368,12 +2368,20 @@ def aten_upsample_bilinear2d_backward( raise NotImplementedError() +@torch_op("aten::upsample_linear1d", trace_only=True) def aten_upsample_linear1d( - self: TensorType, output_size: INT64, align_corners: bool, scales: Optional[float] = None -) -> TensorType: + self: TReal, output_size: INT64, align_corners: bool, scales: Optional[float] = None +) -> TReal: """upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? 
scales=None) -> Tensor""" - - raise NotImplementedError() + # FIXME(justinchuby): Support when scales is provided and align_corners is False + del scales + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) def aten_upsample_linear1d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 26920953a..c274df2be 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1464,9 +1464,7 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(S, rank, False), align_corners ) yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), - shape(L, rank, False), - align_corners, + make_arg(shape(D, rank)), shape(L, rank, False), align_corners ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), @@ -1513,10 +1511,7 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(S, rank, False), align_corners, None ) yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), - shape(L, rank, False), - align_corners, - None, + make_arg(shape(D, rank)), shape(L, rank, False), align_corners, None ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), @@ -1544,6 +1539,46 @@ def shape(size, rank, with_batch_channel=True): ) +def sample_inputs_upsample_linear1d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + align_corners_options = (True, False) + rank = 1 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(L, rank, False), align_corners + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(L, rank, False), align_corners, scales=4.2 + ) + + class _TestParamsMaxPoolEmptyStrideBase: # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 def __init__(self): @@ -2037,6 +2072,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_2d_vec, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_linear1d", + aten_name="upsample_linear1d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_linear1d, + supports_out=False, + ), opinfo_core.OpInfo( "nn.functional.max_pool1d_with_indices", aten_name="max_pool1d_with_indices", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 368a05170..4b5d70c22 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -2132,6 +2132,15 @@ def _where_input_wrangler( nn_ops.aten_upsample_bicubic2d_vec, trace_only=True, ), + 
TorchLibOpInfo( + "ops.aten.upsample_linear1d", + nn_ops.aten_upsample_linear1d, + trace_only=True, + ).xfail( + matcher=lambda sample: sample.args[1] is False + and sample.kwargs.get("scales") is not None, + reason="fixme: align_corners=False output mismatch when scales are provided", + ), TorchLibOpInfo( "nn.functional.upsample_nearest2d", nn_ops.aten_upsample_nearest2d, From 4de6c80bfbf15027e4a0e8a7886559f75ccc5c28 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Jan 2024 15:15:33 -0800 Subject: [PATCH 13/13] chore(deps): bump ruff from 0.1.13 to 0.1.14 in /requirements/lintrunner (#1257) --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index b72cf4132..9906dfb72 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.1.13 +ruff==0.1.14 # MYPY mypy==1.7.1 types-PyYAML==6.0.12.11
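Note on the real-representation convention used by PATCH 01/13: ONNX has no complex64/complex128 tensor type, so torch_lib stores a complex tensor as a float tensor with a trailing axis of size 2 holding the (real, imaginary) parts. That trailing axis is also why `aten_cat_complex` and `aten_stack_complex` shift negative `dim` values by one. A minimal sketch of the convention, using only public PyTorch APIs (the tensor values are arbitrary examples, not taken from the test suite):

```python
import torch

# A complex64 tensor of shape (2,) becomes a float32 tensor of shape (2, 2);
# the trailing axis holds (real, imag) pairs.
z = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.complex64)
z_real = torch.view_as_real(z.resolve_conj())
assert z_real.shape == (2, 2) and z_real.dtype == torch.float32

# Concatenating complex tensors along dim=-1 matches concatenating their real
# representations along dim=-2, because the (real, imag) axis is appended
# last -- the `dim = dim - 1` adjustment in aten_cat_complex.
a = torch.tensor([[1 + 1j]], dtype=torch.complex64)
b = torch.tensor([[2 + 2j]], dtype=torch.complex64)
expected = torch.view_as_real(torch.cat([a, b], dim=-1))
actual = torch.cat([torch.view_as_real(a), torch.view_as_real(b)], dim=-2)
assert torch.equal(expected, actual)
```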
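Note on the shape adaptation in PATCH 09/13: `torchvision.ops.nms` takes unbatched `boxes` of shape `[N, 4]` and `scores` of shape `[N]` and returns the indices of the kept boxes, while ONNX `NonMaxSuppression` expects `[num_batches, N, 4]` boxes and `[num_batches, num_classes, N]` scores and returns `[num_selected, 3]` rows of (batch_index, class_index, box_index); hence the Unsqueeze, Slice, and Reshape in `torchvision_nms`. A small eager-side usage sketch (box coordinates and scores are arbitrary examples):

```python
import torch
import torchvision

# Boxes in (x1, y1, x2, y2) format, one confidence score per box.
boxes = torch.tensor(
    [
        [0.0, 0.0, 10.0, 10.0],    # box 0
        [1.0, 1.0, 11.0, 11.0],    # box 1; IoU with box 0 is ~0.68
        [32.0, 32.0, 40.0, 52.0],  # box 2; disjoint from the others
    ]
)
scores = torch.tensor([0.8, 0.4, 0.6])

# Box 1 overlaps box 0 above the threshold with a lower score, so it is
# suppressed; kept indices are returned sorted by decreasing score.
keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
assert keep.tolist() == [0, 2]
```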
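Note on the align_corners mapping that the upsample patches (05, 07, 10, 11, 12) converge on, made explicit by `_get_upsample_align_corners_mode`: PyTorch's `align_corners=True` corresponds to ONNX Resize with `coordinate_transformation_mode="align_corners"`, and `align_corners=False` to `"pytorch_half_pixel"`. A quick eager-side check that the two modes genuinely sample different grids (the input and output sizes here are arbitrary):

```python
import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# These two calls are what the exported Resize must reproduce with
# coordinate_transformation_mode="align_corners" / "pytorch_half_pixel".
y_corners = F.interpolate(x, size=(5, 5), mode="bilinear", align_corners=True)
y_half_pixel = F.interpolate(x, size=(5, 5), mode="bilinear", align_corners=False)

# Same output shape, different sampling grids, hence different values.
assert y_corners.shape == y_half_pixel.shape == (1, 1, 5, 5)
assert not torch.allclose(y_corners, y_half_pixel)
```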