Commit d8cc836

Revert "[ONNX] Fix type annotations and enable type checking for all apis (pytorch#84091)"
This reverts commit 6446da1. Reverted pytorch#84091 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
1 parent b159a52 commit d8cc836

15 files changed (+68, -765 lines)
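For context, the reverted change (pytorch#84091) had wrapped ONNX exporter helpers in a runtime type-checking decorator; this revert strips those `@_beartype.beartype` decorators again. A minimal sketch of the general pattern, using the public third-party `beartype` package rather than PyTorch's internal wrapper (the `scale` function below is purely illustrative):

# Illustrative only: a tiny function checked at call time with the third-party
# `beartype` package; torch.onnx._internal._beartype wraps the same idea.
from beartype import beartype


@beartype
def scale(x: float, factor: int = 2) -> float:
    """Multiply a float by an integer factor; arguments are validated at call time."""
    return x * factor


print(scale(1.5))         # OK: arguments match the annotations
try:
    scale("1.5")          # Wrong argument type
except Exception as err:  # beartype raises a type-hint violation here
    print(f"rejected: {err}")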

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 12 additions & 17 deletions
@@ -3582,7 +3582,7 @@ def forward(self, x, k):
 
         x = torch.arange(1.0, 6.0, requires_grad=True)
         k = torch.tensor(3)
-        self.run_test(MyModuleDynamic(), (x, k))
+        self.run_test(MyModuleDynamic(), [x, k])
 
     @skipScriptTest()  # Python builtin apply of FunctionMeta object is currently not supported in Torchscript.
     @skipIfUnsupportedMinOpsetVersion(11)  # Clip op min is an input since opset 11.
@@ -7396,28 +7396,23 @@ def test_constant_pad(self):
         x = torch.randn(2, 2, 4, 4)
         self.run_test(model, x)
 
-    @common_utils.parametrize(
-        "pad",
-        [
-            common_utils.subtest([2, 4], name="scalar_list"),
-            common_utils.subtest(
-                [
-                    torch.tensor(2, dtype=torch.int64),
-                    torch.tensor(4, dtype=torch.int64),
-                ],
-                name="scalar_tensor_list",
-            ),
-        ],
-    )
-    @skipIfUnsupportedMinOpsetVersion(11)  # Dynamic padding is added in opset 11
-    def test_pad_types(self, pad):
+    # Dynamic padding is added in opset 11
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_pad_types(self):
         # Test for different pad integer types
         class Pad(torch.nn.Module):
             def forward(self, x, pad: List[int]):
                 return torch.nn.functional.pad(x, pad)
 
         x = torch.randn(2, 2, 4, 4)
-        self.run_test(Pad(), (x, pad))
+        y = pad = [2, 4]
+        self.run_test(Pad(), (x, y))
+
+        y = pad = [
+            torch.tensor(2, dtype=torch.int64),
+            torch.tensor(4, dtype=torch.int64),
+        ]
+        self.run_test(Pad(), (x, y))
 
     @skipIfUnsupportedMaxOpsetVersion(10)
     @skipScriptTest()  # TODO: the logic in symbolic_opset9 doesn't handle script
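As a standalone reference for what `test_pad_types` exercises, `torch.nn.functional.pad` interprets a two-element pad list as (left, right) padding for the last dimension; a minimal sketch outside the test harness:

import torch
import torch.nn.functional as F

x = torch.randn(2, 2, 4, 4)
# pad=[2, 4] pads the last dimension by 2 on the left and 4 on the right.
y = F.pad(x, [2, 4])
print(y.shape)  # torch.Size([2, 2, 4, 10])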

torch/onnx/_patch_torch.py

Lines changed: 1 addition & 12 deletions
@@ -10,17 +10,15 @@
 # Import utils to get _params_dict because it is a global that is accessed by c++ code
 from torch.onnx import _deprecation, utils
 from torch.onnx._globals import GLOBALS
-from torch.onnx._internal import _beartype
 
 _ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
 
 
 # TODO(#78694): Refactor the patching process to make it more transparent to users.
-@_beartype.beartype
 def _graph_op(
     g: _C.Graph,
     opname: str,
-    *raw_args: Union[torch.Tensor, torch._C.Value],
+    *raw_args: _C.Value,
     outputs: int = 1,
     **kwargs,
 ) -> Union[_C.Value, Tuple[_C.Value, ...]]:
@@ -78,7 +76,6 @@ def _graph_op(
     return tuple(n.outputs())
 
 
-@_beartype.beartype
 def _const_if_tensor(g: _C.Graph, arg):
     if arg is None:
         return arg
@@ -88,7 +85,6 @@ def _const_if_tensor(g: _C.Graph, arg):
 
 
 # Generate an ONNX ATen op node.
-@_beartype.beartype
 def _aten_op(g: _C.Graph, operator: str, *args, overload_name: str = "", **kwargs):
     return _graph_op(
         g,
@@ -100,7 +96,6 @@ def _aten_op(g: _C.Graph, operator: str, *args, overload_name: str = "", **kwargs):
     )
 
 
-@_beartype.beartype
 def _block_op(b: _C.Block, opname: str, *args, **kwargs):
     if "::" in opname:
         aten = False
@@ -120,7 +115,6 @@ def _block_op(b: _C.Block, opname: str, *args, **kwargs):
     return outputs
 
 
-@_beartype.beartype
 def _new_node(
     g: _C.Graph, namespace: str, op: str, outputs: int, *args, **kwargs
 ) -> _C.Node:
@@ -144,7 +138,6 @@ def _new_node(
     return node
 
 
-@_beartype.beartype
 def _is_onnx_list(value):
     return (
         not isinstance(value, torch._six.string_classes)
@@ -153,22 +146,19 @@ def _is_onnx_list(value):
     )
 
 
-@_beartype.beartype
 def _scalar(x: torch.Tensor):
     """Convert a scalar tensor into a Python value."""
     assert x.numel() == 1
     return x[0]
 
 
-@_beartype.beartype
 def _is_caffe2_aten_fallback():
     return (
         GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
         and _C_onnx._CAFFE2_ATEN_FALLBACK
     )
 
 
-@_beartype.beartype
 def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
     r"""Initializes the right attribute based on type of value."""
     m = _ATTR_PATTERN.match(key)
@@ -198,7 +188,6 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
 @_deprecation.deprecated(
     "1.13", "1.14", "Use 'g.op()' to create a constant node instead."
 )
-@_beartype.beartype
 def _graph_constant(
     g,
     value,
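`_graph_op` above is what backs `g.op(...)` inside symbolic functions (note the deprecation message pointing users to `g.op()` for constant nodes). A hedged sketch of that usage, with a hypothetical symbolic function registered over `aten::relu`:

import torch
from torch.onnx import register_custom_op_symbolic

# Hypothetical symbolic: `g` is the ONNX graph under construction, and
# g.op(...) (implemented by _graph_op) appends a node to it.
def relu_as_clip(g, self):
    # Per the deprecation note above, constants are created via g.op("Constant", ...).
    zero = g.op("Constant", value_t=torch.tensor(0.0))
    # From opset 11, Clip takes its min bound as an input rather than an attribute.
    return g.op("Clip", self, zero)

register_custom_op_symbolic("aten::relu", relu_as_clip, 11)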

torch/onnx/_type_utils.py

Lines changed: 0 additions & 10 deletions
@@ -8,7 +8,6 @@
 
 import torch
 from torch._C import _onnx as _C_onnx
-from torch.onnx._internal import _beartype
 
 ScalarName = Literal[
     "Byte",
@@ -81,7 +80,6 @@ class JitScalarType(enum.IntEnum):
     UNDEFINED = enum.auto()  # 16
 
     @classmethod
-    @_beartype.beartype
     def from_name(
         cls, name: Union[ScalarName, TorchName, Optional[str]]
     ) -> JitScalarType:
@@ -106,36 +104,30 @@ def from_name(
         raise ValueError(f"Unknown torch or scalar type: '{name}'")
 
     @classmethod
-    @_beartype.beartype
     def from_dtype(cls, dtype: torch.dtype) -> JitScalarType:
         """Convert a torch dtype to ScalarType."""
         if dtype not in _DTYPE_TO_SCALAR_TYPE:
             raise ValueError(f"Unknown dtype: {dtype}")
         return _DTYPE_TO_SCALAR_TYPE[dtype]
 
-    @_beartype.beartype
     def scalar_name(self) -> ScalarName:
         """Convert a ScalarType to a JIT scalar type name."""
         return _SCALAR_TYPE_TO_NAME[self]
 
-    @_beartype.beartype
     def torch_name(self) -> TorchName:
         """Convert a ScalarType to a torch type name."""
         return _SCALAR_TYPE_TO_TORCH_NAME[self]
 
-    @_beartype.beartype
     def dtype(self) -> torch.dtype:
         """Convert a ScalarType to a torch dtype."""
         return _SCALAR_TYPE_TO_DTYPE[self]
 
-    @_beartype.beartype
     def onnx_type(self) -> _C_onnx.TensorProtoDataType:
         """Convert a ScalarType to an ONNX data type."""
         if self not in _SCALAR_TYPE_TO_ONNX:
             raise ValueError(f"Scalar type {self} cannot be converted to ONNX")
         return _SCALAR_TYPE_TO_ONNX[self]
 
-    @_beartype.beartype
     def onnx_compatible(self) -> bool:
         """Return whether this ScalarType is compatible with ONNX."""
         return (
@@ -145,13 +137,11 @@ def onnx_compatible(self) -> bool:
         )
 
 
-@_beartype.beartype
 def valid_scalar_name(scalar_name: Union[ScalarName, str]) -> bool:
     """Return whether the given scalar name is a valid JIT scalar type name."""
    return scalar_name in _SCALAR_NAME_TO_TYPE
 
 
-@_beartype.beartype
 def valid_torch_name(torch_name: Union[TorchName, str]) -> bool:
     """Return whether the given torch name is a valid torch type name."""
     return torch_name in _TORCH_NAME_TO_SCALAR_TYPE
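For readers unfamiliar with `JitScalarType` (an internal helper of `torch.onnx`), a small usage sketch of the methods shown in this diff; the commented values are the ones expected for `torch.float32`:

import torch
from torch.onnx._type_utils import JitScalarType  # internal module shown above

scalar_type = JitScalarType.from_dtype(torch.float32)
print(scalar_type.scalar_name())      # "Float" (JIT scalar type name)
print(scalar_type.torch_name())       # "float" (torch type name)
print(scalar_type.onnx_type())        # TensorProtoDataType.FLOAT
print(scalar_type.onnx_compatible())  # True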
