Commit

Merge branch 'main' into bowbao/bn_fp16
justinchuby authored Jan 22, 2024
2 parents 28ee8ac + 4de6c80 commit e6327ee
Showing 12 changed files with 500 additions and 81 deletions.
16 changes: 13 additions & 3 deletions noxfile.py
@@ -29,6 +29,7 @@
ONNX = "onnx==1.14.1"
ONNX_RUNTIME = "onnxruntime==1.16.1"
PYTORCH = "torch==2.1.0"
TORCHVISON = "torchvision==0.16"
ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
"flatbuffers",
"coloredlogs",
@@ -52,6 +53,7 @@ def test(session):
session.install(
*COMMON_TEST_DEPENDENCIES,
PYTORCH,
TORCHVISON,
ONNX,
ONNX_RUNTIME,
)
@@ -78,7 +80,7 @@ def test_torch_nightly(session):
@nox.session(tags=["test-onnx-weekly"])
def test_onnx_weekly(session):
"""Test with ONNX weekly (preview) build."""
session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH)
session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON)
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
session.install(".", "--no-deps")
session.run("pip", "list")
@@ -89,7 +91,11 @@ def test_onnx_weekly(session):
def test_ort_nightly(session):
"""Test with ONNX Runtime nightly builds."""
session.install(
*COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES
*COMMON_TEST_DEPENDENCIES,
PYTORCH,
TORCHVISON,
ONNX,
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
)
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
session.install(".", "--no-deps")
@@ -101,7 +107,11 @@ def test_ort_nightly(session):
def test_experimental_torchlib_tracing(session):
"""Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on."""
session.install(
*COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES
*COMMON_TEST_DEPENDENCIES,
PYTORCH,
TORCHVISON,
ONNX,
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
)
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
session.install(".", "--no-deps")
7 changes: 7 additions & 0 deletions onnxscript/function_libs/torch_lib/graph_building.py
@@ -34,6 +34,7 @@
Sequence["TorchScriptTensor"],
Sequence[float],
Sequence[int],
complex,
str,
int,
float,
@@ -45,6 +46,7 @@
Sequence["TorchScriptTensor"],
Sequence[float],
Sequence[int],
complex,
str,
int,
float,
@@ -56,6 +58,7 @@
Sequence[torch.Value],
Sequence[float],
Sequence[int],
complex,
str,
int,
float,
@@ -654,6 +657,10 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
isinstance(val, float) for val in constant
):
constant_tensor = torch.tensor(constant, dtype=torch.float)
elif isinstance(constant, complex):
# NOTE: ONNX does not support complex64/complex128 tensors, so we
# convert them to float32/float64 tensors in the real representation.
constant_tensor = torch.view_as_real(torch.tensor(constant).resolve_conj())
else:
raise TypeError(
f"Constant input '{constant}' of type '{type(constant)}' is not supported"
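
For context on the new complex branch above, here is a minimal sketch in plain PyTorch (illustrative only, not part of this commit) of the real representation it produces: a Python complex scalar becomes a float tensor whose last dimension holds the real and imaginary parts, since ONNX has no complex tensor types.

import torch

# Minimal sketch of the real representation used by _add_constant_to_graph:
# a complex scalar maps to a float tensor with a trailing [real, imag] dimension.
constant = 3.0 + 4.0j
constant_tensor = torch.view_as_real(torch.tensor(constant).resolve_conj())
print(constant_tensor)        # tensor([3., 4.])
print(constant_tensor.dtype)  # torch.float32 (complex64 components are float32)
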
3 changes: 2 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/__init__.py
@@ -7,6 +7,7 @@
"prims",
"sparse",
"special",
"vision",
]

from . import core, fft, linalg, nested, nn, prims, sparse, special
from . import core, fft, linalg, nested, nn, prims, sparse, special, vision
41 changes: 40 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
@@ -17,6 +17,8 @@
from onnxscript import (
BFLOAT16,
BOOL,
COMPLEX64,
COMPLEX128,
DOUBLE,
FLOAT,
FLOAT16,
@@ -1412,6 +1414,15 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType:
raise NotImplementedError()


@torch_op("aten::cat", trace_only=True, complex=True)
def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
"""cat(Tensor[] tensors, int dim=0) -> Tensor"""
# The real representation appends a trailing dimension of size 2, so negative dims shift left by one
if dim < 0:
dim = dim - 1
return aten_cat(tensors, dim=dim)


@torch_op("aten::cat")
def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor:
"""cat(Tensor[] tensors, int dim=0) -> Tensor"""
@@ -7102,6 +7113,25 @@ def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> RealType:
return common_ops.cast_to(s, dtype=dtype)


@torch_op("aten::scalar_tensor", trace_only=True, complex=True)
def aten_scalar_tensor_complex(
s: Union[FLOAT, DOUBLE], dtype: int = COMPLEX64.dtype
) -> RealType:
"""scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
# NOTE: This overload is invoked when the input is originally complex;
# for real inputs, aten_scalar_tensor is invoked instead.
if dtype == COMPLEX128.dtype:
result = op.Cast(s, to=DOUBLE.dtype)
elif dtype == COMPLEX64.dtype:
result = op.Cast(s, to=FLOAT.dtype)
else:
# NOTE: No-op for non-complex dtypes. Reaching this branch likely indicates a bug in the caller.
result = s
return result


@torch_op("aten::scalar_tensor", trace_only=True)
def aten_scalar_tensor_sym_number(s: RealType, dtype: int = FLOAT.dtype) -> RealType:
"""scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
@@ -7532,6 +7562,15 @@ def aten_sspaddmm(
raise NotImplementedError()


@torch_op("aten::stack", trace_only=True, complex=True)
def aten_stack_complex(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrString:
"""stack(Tensor[] tensors, int dim=0) -> Tensor"""
# The real representation appends a trailing dimension of size 2, so negative dims shift left by one
if dim < 0:
dim = dim - 1
return aten_stack(tensors, dim)


@torch_op("aten::stack")
def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrString:
"""stack(Tensor[] tensors, int dim=0) -> Tensor"""
@@ -8100,7 +8139,7 @@ def aten_triplet_margin_loss(


@torch_op("aten::triu")
def aten_triu(self: TensorType, diagonal: int = 0) -> TensorType:
def aten_triu(self: TTensor, diagonal: int = 0) -> TTensor:
"""triu(Tensor self, int diagonal=0) -> Tensor"""

return op.Trilu(self, diagonal, upper=1)
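
A short sketch in plain PyTorch (illustrative only, not part of this commit) of why aten_cat_complex and aten_stack_complex shift negative dims by one: in the real representation every complex tensor carries an extra trailing dimension of size 2, so an axis counted from the end moves one position further left.

import torch

# Hypothetical illustration of the negative-dim shift for complex cat/stack.
a = torch.randn(2, 3, dtype=torch.complex64)
b = torch.randn(2, 3, dtype=torch.complex64)

expected = torch.cat([a, b], dim=-1)               # concat in the complex domain

a_real = torch.view_as_real(a)                     # shape (2, 3, 2)
b_real = torch.view_as_real(b)
result_real = torch.cat([a_real, b_real], dim=-2)  # same logical axis is now dim - 1

assert torch.equal(torch.view_as_real(expected), result_real)
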
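Likewise, a small check (illustrative, not part of the diff) of the dtype mapping in aten_scalar_tensor_complex: the real components of complex64 are float32 and those of complex128 are float64, which is why COMPLEX64 casts to FLOAT and COMPLEX128 casts to DOUBLE.

import torch

# The component dtype of each complex dtype, matching the Cast targets above.
assert torch.view_as_real(torch.tensor(1 + 1j, dtype=torch.complex64)).dtype == torch.float32
assert torch.view_as_real(torch.tensor(1 + 1j, dtype=torch.complex128)).dtype == torch.float64
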
183 changes: 129 additions & 54 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -2197,16 +2197,100 @@ def aten_unflatten_dense_tensors(
raise NotImplementedError()


def _get_upsample_align_corners_mode(align_corners: bool) -> str:
return "align_corners" if align_corners else "pytorch_half_pixel"


@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True)
def _aten_upsample_output_size(
self: TReal,
output_size: INT64,
mode: str,
coordinate_transformation_mode: str,
) -> TReal:
self_shape = op.Shape(self)
starts = op.Constant(value_ints=[0])
ends = op.Constant(value_ints=[2])
batch_channel = op.Slice(self_shape, starts, ends)
output_size = op.Concat(batch_channel, output_size, axis=0)
return op.Resize(
self,
None,
None,
output_size,
mode=mode,
coordinate_transformation_mode=coordinate_transformation_mode,
nearest_mode="floor",
)


@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True)
def _aten_upsample_scales(
self: TReal,
scale_factors: TFloat,
mode: str,
coordinate_transformation_mode: str,
) -> TReal:
scale_factors = op.Cast(scale_factors, to=FLOAT.dtype)
scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0)
return op.Resize(
self,
None,
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
None,
mode=mode,
coordinate_transformation_mode=coordinate_transformation_mode,
nearest_mode="floor",
)


@torch_op("aten::upsample_bicubic2d", trace_only=True)
def aten_upsample_bicubic2d(
self: TensorType,
self: TReal,
output_size: INT64,
align_corners: bool,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TensorType:
) -> TReal:
"""upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""

raise NotImplementedError()
# NOTE: Based on experimentation, scales_h and scales_w are always ignored by PyTorch here,
# except when align_corners is True, where the exact behavior is unclear.
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
return _aten_upsample_output_size(
self,
output_size,
mode="cubic",
coordinate_transformation_mode=coordinate_transformation_mode,
)


@torch_op("aten::upsample_bicubic2d.vec", trace_only=True)
def aten_upsample_bicubic2d_vec(
self: TReal,
output_size: INT64,
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> TReal:
"""upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor"""

coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
if scale_factors is not None:
result = _aten_upsample_scales(
self,
op.Constant(value_floats=scale_factors),
mode="cubic",
coordinate_transformation_mode=coordinate_transformation_mode,
)
else:
result = _aten_upsample_output_size(
self,
output_size,
mode="cubic",
coordinate_transformation_mode=coordinate_transformation_mode,
)

return result


def aten_upsample_bicubic2d_backward(
@@ -2225,67 +2309,50 @@ def aten_upsample_bicubic2d_backward(
@torch_op("aten::upsample_bilinear2d", trace_only=True)
def aten_upsample_bilinear2d(
self: TReal,
output_size: Optional[INT64] = None,
output_size: INT64,
align_corners: bool,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
align_corners: bool = True, # pylint: disable=unused-argument
) -> TReal:
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""

if output_size is not None:
result = _aten_upsample_bilinear2d_output_size(self, output_size)
else:
assert scales_h is not None
assert scales_h == scales_w
result = _aten_upsample_bilinear2d_scales(self, scales_h, scales_w)
return result


@torch_op("aten::upsample_bilinear2d", private=True)
def _aten_upsample_bilinear2d_output_size(
self: TReal,
output_size: INT64,
) -> TReal:
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""

self_shape = op.Shape(self)
starts = op.Constant(value_ints=[0])
ends = op.Constant(value_ints=[2])
batch_channel = op.Slice(self_shape, starts, ends)
output_size = op.Concat(batch_channel, output_size, axis=0)
return op.Resize(
# NOTE: Based on experimentation, scales_h and scales_w are always ignored by PyTorch here,
# except when align_corners is True, where the exact behavior is unclear.
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
return _aten_upsample_output_size(
self,
None,
None,
output_size,
coordinate_transformation_mode=coordinate_transformation_mode,
mode="linear",
coordinate_transformation_mode="align_corners",
)


@torch_op("aten::upsample_bilinear2d", private=True)
def _aten_upsample_bilinear2d_scales(
@torch_op("aten::upsample_bilinear2d.vec", trace_only=True)
def aten_upsample_bilinear2d_vec(
self: TReal,
scales_h: float,
scales_w: float,
output_size: Optional[INT64],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> TReal:
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
"""upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor"""

coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
if scale_factors is not None:
result = _aten_upsample_scales(
self,
op.Constant(value_floats=scale_factors),
mode="linear",
coordinate_transformation_mode=coordinate_transformation_mode,
)
else:
result = _aten_upsample_output_size(
self,
output_size,
mode="linear",
coordinate_transformation_mode=coordinate_transformation_mode,
)

neg_1 = op.Constant(value_ints=[-1])
scales = op.Concat(
op.Constant(value_floats=[1.0, 1.0]),
op.Reshape(op.Constant(value_float=scales_h), neg_1),
op.Reshape(op.Constant(value_float=scales_w), neg_1),
axis=0,
)
return op.Resize(
self,
None,
scales, # format should be: [1.0, 1.0, scale_h, scale_w]
None,
mode="linear",
coordinate_transformation_mode="align_corners",
)
return result


def aten_upsample_bilinear2d_backward(
Expand All @@ -2301,12 +2368,20 @@ def aten_upsample_bilinear2d_backward(
raise NotImplementedError()


@torch_op("aten::upsample_linear1d", trace_only=True)
def aten_upsample_linear1d(
self: TensorType, output_size: INT64, align_corners: bool, scales: Optional[float] = None
) -> TensorType:
self: TReal, output_size: INT64, align_corners: bool, scales: Optional[float] = None
) -> TReal:
"""upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor"""

raise NotImplementedError()
# FIXME(justinchuby): Support when scales is provided and align_corners is False
del scales
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
return _aten_upsample_output_size(
self,
output_size,
mode="linear",
coordinate_transformation_mode=coordinate_transformation_mode,
)


def aten_upsample_linear1d_backward(
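
To summarize what the shared upsample helpers feed to ONNX Resize, here is a rough sketch in plain PyTorch (illustrative only; the tensor shapes and sizes are assumptions, not part of this commit). The output_size path keeps the batch and channel dimensions and appends the requested spatial size, and align_corners selects the coordinate_transformation_mode ("align_corners" when True, "pytorch_half_pixel" otherwise).

import torch
import torch.nn.functional as F

# Sketch of the Resize `sizes` input built by _aten_upsample_output_size:
# keep [N, C] from the input shape, then append the requested spatial size.
x = torch.randn(1, 3, 8, 8)
output_size = [16, 16]
resize_sizes = list(x.shape[:2]) + output_size    # -> [1, 3, 16, 16]

# align_corners picks the coordinate_transformation_mode for Resize;
# PyTorch's interpolate produces the reference results the exported graph should match.
y_default = F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
y_aligned = F.interpolate(x, size=output_size, mode="bilinear", align_corners=True)
assert list(y_default.shape) == resize_sizes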