9
9
import onnx
10
10
import onnx .helper
11
11
12
- DType = onnx . TensorProto . DataType
12
+ import onnxscript . ir
13
13
14
- DimType = Union [int , str , type (None )]
14
+ _DType = onnxscript .ir .DataType
15
+ _DimType = Union [int , str , type (None )]
16
+ _ShapeType = Union [Tuple [_DimType , ...], _DimType , type (Ellipsis )]
15
17
18
+ _tensor_type_shape_cache : dict [_DType , TensorType ] = {}
19
+ tensor_type_registry : dict [_DType , TensorType ] = {}
16
20
17
- def check_dim (dim ):
21
+
22
+ def _check_dim (dim ):
18
23
if not isinstance (dim , (int , str , type (None ))):
19
24
raise TypeError (f"Invalid dimension { dim } " )
20
25
21
26
22
- ShapeType = Union [Tuple [DimType , ...], DimType , type (Ellipsis )]
23
-
24
-
25
- def check_shape (shape ):
27
+ def _check_shape (shape ):
26
28
if isinstance (shape , tuple ):
27
29
for dim in shape :
28
- check_dim (dim )
30
+ _check_dim (dim )
29
31
elif shape != Ellipsis :
30
- check_dim (shape )
31
-
32
-
33
- tensor_type_registry : dict [DType , TensorType ] = {}
34
- _tensor_type_shape_cache : dict [DType , TensorType ] = {}
32
+ _check_dim (shape )
35
33
36
34
37
35
class TensorType (abc .ABC ):
@@ -58,13 +56,13 @@ class TensorType(abc.ABC):
58
56
tensor: FLOAT[128, 1024]
59
57
"""
60
58
61
- dtype : ClassVar [DType ]
62
- shape : ClassVar [Optional [ShapeType ]]
59
+ dtype : ClassVar [_DType ]
60
+ shape : ClassVar [Optional [_ShapeType ]]
63
61
64
62
def __new__ (cls ):
65
63
raise NotImplementedError ("TensorTypes cannot be instantiated" )
66
64
67
- def __init_subclass__ (cls , dtype : DType , shape : Optional [ShapeType ] = None ):
65
+ def __init_subclass__ (cls , dtype : _DType , shape : Optional [_ShapeType ] = None ):
68
66
cls .dtype = dtype
69
67
cls .shape = shape
70
68
if shape is None :
@@ -76,9 +74,9 @@ def __init_subclass__(cls, dtype: DType, shape: Optional[ShapeType] = None):
76
74
)
77
75
tensor_type_registry [dtype ] = cls
78
76
else :
79
- check_shape (shape )
77
+ _check_shape (shape )
80
78
81
- def __class_getitem__ (cls , shape : Optional [ShapeType ]) -> type [TensorType ]:
79
+ def __class_getitem__ (cls , shape : Optional [_ShapeType ]) -> type [TensorType ]:
82
80
if cls .shape is not None :
83
81
raise ValueError ("Invalid usage: shape already specified." )
84
82
if shape is None :
@@ -108,83 +106,91 @@ def to_string(cls) -> str:
108
106
return f"tensor({ cls .__name__ .lower ()} )"
109
107
110
108
111
- class FLOAT (TensorType , dtype = onnx .TensorProto .FLOAT ):
109
+ class FLOAT (TensorType , dtype = onnxscript .ir .DataType .FLOAT ):
110
+ pass
111
+
112
+
113
+ class UINT8 (TensorType , dtype = onnxscript .ir .DataType .UINT8 ):
114
+ pass
115
+
116
+
117
+ class INT8 (TensorType , dtype = onnxscript .ir .DataType .INT8 ):
112
118
pass
113
119
114
120
115
- class UINT8 (TensorType , dtype = onnx . TensorProto . UINT8 ):
121
+ class UINT16 (TensorType , dtype = onnxscript . ir . DataType . UINT16 ):
116
122
pass
117
123
118
124
119
- class INT8 (TensorType , dtype = onnx . TensorProto . INT8 ):
125
+ class INT16 (TensorType , dtype = onnxscript . ir . DataType . INT16 ):
120
126
pass
121
127
122
128
123
- class UINT16 (TensorType , dtype = onnx . TensorProto . UINT16 ):
129
+ class INT32 (TensorType , dtype = onnxscript . ir . DataType . INT32 ):
124
130
pass
125
131
126
132
127
- class INT16 (TensorType , dtype = onnx . TensorProto . INT16 ):
133
+ class INT64 (TensorType , dtype = onnxscript . ir . DataType . INT64 ):
128
134
pass
129
135
130
136
131
- class INT32 (TensorType , dtype = onnx . TensorProto . INT32 ):
137
+ class STRING (TensorType , dtype = onnxscript . ir . DataType . STRING ):
132
138
pass
133
139
134
140
135
- class INT64 (TensorType , dtype = onnx . TensorProto . INT64 ):
141
+ class BOOL (TensorType , dtype = onnxscript . ir . DataType . BOOL ):
136
142
pass
137
143
138
144
139
- class STRING (TensorType , dtype = onnx . TensorProto . STRING ):
145
+ class FLOAT16 (TensorType , dtype = onnxscript . ir . DataType . FLOAT16 ):
140
146
pass
141
147
142
148
143
- class BOOL (TensorType , dtype = onnx . TensorProto . BOOL ):
149
+ class DOUBLE (TensorType , dtype = onnxscript . ir . DataType . DOUBLE ):
144
150
pass
145
151
146
152
147
- class FLOAT16 (TensorType , dtype = onnx . TensorProto . FLOAT16 ):
153
+ class UINT32 (TensorType , dtype = onnxscript . ir . DataType . UINT32 ):
148
154
pass
149
155
150
156
151
- class DOUBLE (TensorType , dtype = onnx . TensorProto . DOUBLE ):
157
+ class UINT64 (TensorType , dtype = onnxscript . ir . DataType . UINT64 ):
152
158
pass
153
159
154
160
155
- class UINT32 (TensorType , dtype = onnx . TensorProto . UINT32 ):
161
+ class COMPLEX64 (TensorType , dtype = onnxscript . ir . DataType . COMPLEX64 ):
156
162
pass
157
163
158
164
159
- class UINT64 (TensorType , dtype = onnx . TensorProto . UINT64 ):
165
+ class COMPLEX128 (TensorType , dtype = onnxscript . ir . DataType . COMPLEX128 ):
160
166
pass
161
167
162
168
163
- class COMPLEX64 (TensorType , dtype = onnx . TensorProto . COMPLEX64 ):
169
+ class BFLOAT16 (TensorType , dtype = onnxscript . ir . DataType . BFLOAT16 ):
164
170
pass
165
171
166
172
167
- class COMPLEX128 (TensorType , dtype = onnx . TensorProto . COMPLEX128 ):
173
+ class FLOAT8E4M3FN (TensorType , dtype = onnxscript . ir . DataType . FLOAT8E4M3FN ):
168
174
pass
169
175
170
176
171
- class BFLOAT16 (TensorType , dtype = onnx . TensorProto . BFLOAT16 ):
177
+ class FLOAT8E4M3FNUZ (TensorType , dtype = onnxscript . ir . DataType . FLOAT8E4M3FNUZ ):
172
178
pass
173
179
174
180
175
- class FLOAT8E4M3FN (TensorType , dtype = onnx . TensorProto . FLOAT8E4M3FN ):
181
+ class FLOAT8E5M2 (TensorType , dtype = onnxscript . ir . DataType . FLOAT8E5M2 ):
176
182
pass
177
183
178
184
179
- class FLOAT8E4M3FNUZ (TensorType , dtype = onnx . TensorProto . FLOAT8E4M3FNUZ ):
185
+ class FLOAT8E5M2FNUZ (TensorType , dtype = onnxscript . ir . DataType . FLOAT8E5M2FNUZ ):
180
186
pass
181
187
182
188
183
- class FLOAT8E5M2 (TensorType , dtype = onnx . TensorProto . FLOAT8E5M2 ):
189
+ class INT4 (TensorType , dtype = onnxscript . ir . DataType . INT4 ):
184
190
pass
185
191
186
192
187
- class FLOAT8E5M2FNUZ (TensorType , dtype = onnx . TensorProto . FLOAT8E5M2FNUZ ):
193
+ class UINT4 (TensorType , dtype = onnxscript . ir . DataType . UINT4 ):
188
194
pass
189
195
190
196
0 commit comments