Skip to content

Commit 3a7d6fd

Browse files
authored
Use IR types to define onnx_types (microsoft#1924)
- Use IR types to define onnx_types so that it is not dependent on onnx package version. - Also add INT4 and UINT4 types. - Make some helper functions private.
1 parent ec3b140 commit 3a7d6fd

File tree

2 files changed

+47
-41
lines changed

2 files changed

+47
-41
lines changed

onnxscript/onnx_types.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,27 @@
99
import onnx
1010
import onnx.helper
1111

12-
DType = onnx.TensorProto.DataType
12+
import onnxscript.ir
1313

14-
DimType = Union[int, str, type(None)]
14+
_DType = onnxscript.ir.DataType
15+
_DimType = Union[int, str, type(None)]
16+
_ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)]
1517

18+
_tensor_type_shape_cache: dict[_DType, TensorType] = {}
19+
tensor_type_registry: dict[_DType, TensorType] = {}
1620

17-
def check_dim(dim):
21+
22+
def _check_dim(dim):
1823
if not isinstance(dim, (int, str, type(None))):
1924
raise TypeError(f"Invalid dimension {dim}")
2025

2126

22-
ShapeType = Union[Tuple[DimType, ...], DimType, type(Ellipsis)]
23-
24-
25-
def check_shape(shape):
27+
def _check_shape(shape):
2628
if isinstance(shape, tuple):
2729
for dim in shape:
28-
check_dim(dim)
30+
_check_dim(dim)
2931
elif shape != Ellipsis:
30-
check_dim(shape)
31-
32-
33-
tensor_type_registry: dict[DType, TensorType] = {}
34-
_tensor_type_shape_cache: dict[DType, TensorType] = {}
32+
_check_dim(shape)
3533

3634

3735
class TensorType(abc.ABC):
@@ -58,13 +56,13 @@ class TensorType(abc.ABC):
5856
tensor: FLOAT[128, 1024]
5957
"""
6058

61-
dtype: ClassVar[DType]
62-
shape: ClassVar[Optional[ShapeType]]
59+
dtype: ClassVar[_DType]
60+
shape: ClassVar[Optional[_ShapeType]]
6361

6462
def __new__(cls):
6563
raise NotImplementedError("TensorTypes cannot be instantiated")
6664

67-
def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
65+
def __init_subclass__(cls, dtype: _DType, shape: Optional[_ShapeType] = None):
6866
cls.dtype = dtype
6967
cls.shape = shape
7068
if shape is None:
@@ -76,9 +74,9 @@ def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
7674
)
7775
tensor_type_registry[dtype] = cls
7876
else:
79-
check_shape(shape)
77+
_check_shape(shape)
8078

81-
def __class_getitem__(cls, shape: Optional[ShapeType]) -> type[TensorType]:
79+
def __class_getitem__(cls, shape: Optional[_ShapeType]) -> type[TensorType]:
8280
if cls.shape is not None:
8381
raise ValueError("Invalid usage: shape already specified.")
8482
if shape is None:
@@ -108,83 +106,91 @@ def to_string(cls) -> str:
108106
return f"tensor({cls.__name__.lower()})"
109107

110108

111-
class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
109+
class FLOAT(TensorType, dtype=onnxscript.ir.DataType.FLOAT):
110+
pass
111+
112+
113+
class UINT8(TensorType, dtype=onnxscript.ir.DataType.UINT8):
114+
pass
115+
116+
117+
class INT8(TensorType, dtype=onnxscript.ir.DataType.INT8):
112118
pass
113119

114120

115-
class UINT8(TensorType, dtype=onnx.TensorProto.UINT8):
121+
class UINT16(TensorType, dtype=onnxscript.ir.DataType.UINT16):
116122
pass
117123

118124

119-
class INT8(TensorType, dtype=onnx.TensorProto.INT8):
125+
class INT16(TensorType, dtype=onnxscript.ir.DataType.INT16):
120126
pass
121127

122128

123-
class UINT16(TensorType, dtype=onnx.TensorProto.UINT16):
129+
class INT32(TensorType, dtype=onnxscript.ir.DataType.INT32):
124130
pass
125131

126132

127-
class INT16(TensorType, dtype=onnx.TensorProto.INT16):
133+
class INT64(TensorType, dtype=onnxscript.ir.DataType.INT64):
128134
pass
129135

130136

131-
class INT32(TensorType, dtype=onnx.TensorProto.INT32):
137+
class STRING(TensorType, dtype=onnxscript.ir.DataType.STRING):
132138
pass
133139

134140

135-
class INT64(TensorType, dtype=onnx.TensorProto.INT64):
141+
class BOOL(TensorType, dtype=onnxscript.ir.DataType.BOOL):
136142
pass
137143

138144

139-
class STRING(TensorType, dtype=onnx.TensorProto.STRING):
145+
class FLOAT16(TensorType, dtype=onnxscript.ir.DataType.FLOAT16):
140146
pass
141147

142148

143-
class BOOL(TensorType, dtype=onnx.TensorProto.BOOL):
149+
class DOUBLE(TensorType, dtype=onnxscript.ir.DataType.DOUBLE):
144150
pass
145151

146152

147-
class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16):
153+
class UINT32(TensorType, dtype=onnxscript.ir.DataType.UINT32):
148154
pass
149155

150156

151-
class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE):
157+
class UINT64(TensorType, dtype=onnxscript.ir.DataType.UINT64):
152158
pass
153159

154160

155-
class UINT32(TensorType, dtype=onnx.TensorProto.UINT32):
161+
class COMPLEX64(TensorType, dtype=onnxscript.ir.DataType.COMPLEX64):
156162
pass
157163

158164

159-
class UINT64(TensorType, dtype=onnx.TensorProto.UINT64):
165+
class COMPLEX128(TensorType, dtype=onnxscript.ir.DataType.COMPLEX128):
160166
pass
161167

162168

163-
class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64):
169+
class BFLOAT16(TensorType, dtype=onnxscript.ir.DataType.BFLOAT16):
164170
pass
165171

166172

167-
class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128):
173+
class FLOAT8E4M3FN(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E4M3FN):
168174
pass
169175

170176

171-
class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16):
177+
class FLOAT8E4M3FNUZ(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E4M3FNUZ):
172178
pass
173179

174180

175-
class FLOAT8E4M3FN(TensorType, dtype=onnx.TensorProto.FLOAT8E4M3FN):
181+
class FLOAT8E5M2(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E5M2):
176182
pass
177183

178184

179-
class FLOAT8E4M3FNUZ(TensorType, dtype=onnx.TensorProto.FLOAT8E4M3FNUZ):
185+
class FLOAT8E5M2FNUZ(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E5M2FNUZ):
180186
pass
181187

182188

183-
class FLOAT8E5M2(TensorType, dtype=onnx.TensorProto.FLOAT8E5M2):
189+
class INT4(TensorType, dtype=onnxscript.ir.DataType.INT4):
184190
pass
185191

186192

187-
class FLOAT8E5M2FNUZ(TensorType, dtype=onnx.TensorProto.FLOAT8E5M2FNUZ):
193+
class UINT4(TensorType, dtype=onnxscript.ir.DataType.UINT4):
188194
pass
189195

190196

tests/onnx_types_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from parameterized import parameterized
1515

16-
from onnxscript.onnx_types import DOUBLE, FLOAT, DType, TensorType, tensor_type_registry
16+
from onnxscript.onnx_types import DOUBLE, FLOAT, TensorType, tensor_type_registry
1717

1818

1919
class TestOnnxTypes(unittest.TestCase):
@@ -26,7 +26,7 @@ def test_instantiation(self):
2626
FLOAT[...]()
2727

2828
@parameterized.expand(tensor_type_registry.items())
29-
def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]):
29+
def test_type_properties(self, dtype: int, tensor_type: type[TensorType]):
3030
self.assertEqual(tensor_type.dtype, dtype)
3131
self.assertIsNone(tensor_type.shape)
3232
self.assertEqual(tensor_type[...].shape, ...) # type: ignore[index]
@@ -35,7 +35,7 @@ def test_type_properties(self, dtype: DType, tensor_type: type[TensorType]):
3535
self.assertEqual(tensor_type[1, 2, 3].dtype, dtype) # type: ignore[index]
3636

3737
@parameterized.expand([(dtype,) for dtype in tensor_type_registry])
38-
def test_dtype_bound_to_subclass(self, dtype: DType):
38+
def test_dtype_bound_to_subclass(self, dtype: int):
3939
with self.assertRaises(ValueError):
4040
type(f"InvalidTensorTypeSubclass_{dtype}", (TensorType,), {}, dtype=dtype)
4141

0 commit comments

Comments
 (0)