@@ -192,64 +192,88 @@ To fully support arrays from other frameworks, it is usually a good idea to crea
192
192
import ctypes
193
193
from typing import Any
194
194
195
+ import numpy.typing as npt
195
196
import torch
197
+
196
198
from onnxscript import ir
197
199
198
- # Define utilities to convert PyTorch data types so users do not need to specify manually
199
- _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
200
- torch.bfloat16: ir.DataType.BFLOAT16,
201
- torch.bool: ir.DataType.BOOL,
202
- torch.complex128: ir.DataType.COMPLEX128,
203
- torch.complex64: ir.DataType.COMPLEX64,
204
- torch.float16: ir.DataType.FLOAT16,
205
- torch.float32: ir.DataType.FLOAT,
206
- torch.float64: ir.DataType.DOUBLE,
207
- torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
208
- torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
209
- torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
210
- torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
211
- torch.int16: ir.DataType.INT16,
212
- torch.int32: ir.DataType.INT32,
213
- torch.int64: ir.DataType.INT64,
214
- torch.int8: ir.DataType.INT8,
215
- torch.uint8: ir.DataType.UINT8,
216
- }
217
-
218
-
219
- def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
220
- return _TORCH_DTYPE_TO_ONNX[dtype]
221
200
222
201
class TorchTensor(ir.Tensor):
223
- def __init__(self, tensor: torch.Tensor):
202
+ def __init__(
203
+ self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
204
+ ):
224
205
# Pass the tensor as the raw data to ir.Tensor's constructor
225
- super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))
226
206
227
- def __array__(self, dtype: Any = None) -> "np.ndarray":
228
- # numpy() calls __array__ in ir.Tensor
207
+ _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
208
+ torch.bfloat16: ir.DataType.BFLOAT16,
209
+ torch.bool: ir.DataType.BOOL,
210
+ torch.complex128: ir.DataType.COMPLEX128,
211
+ torch.complex64: ir.DataType.COMPLEX64,
212
+ torch.float16: ir.DataType.FLOAT16,
213
+ torch.float32: ir.DataType.FLOAT,
214
+ torch.float64: ir.DataType.DOUBLE,
215
+ torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
216
+ torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
217
+ torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
218
+ torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
219
+ torch.int16: ir.DataType.INT16,
220
+ torch.int32: ir.DataType.INT32,
221
+ torch.int64: ir.DataType.INT64,
222
+ torch.int8: ir.DataType.INT8,
223
+ torch.uint8: ir.DataType.UINT8,
224
+ torch.uint16: ir.DataType.UINT16,
225
+ torch.uint32: ir.DataType.UINT32,
226
+ torch.uint64: ir.DataType.UINT64,
227
+ }
228
+ super().__init__(
229
+ tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
230
+ )
231
+
232
+ def numpy(self) -> npt.NDArray:
233
+ self.raw: torch.Tensor
229
234
if self.dtype == ir.DataType.BFLOAT16:
230
- return self.raw.view(torch.uint16).__array__( dtype)
235
+ return self.raw.view(torch.uint16).numpy(force=True).view(self. dtype.numpy() )
231
236
if self.dtype in {
232
237
ir.DataType.FLOAT8E4M3FN,
233
238
ir.DataType.FLOAT8E4M3FNUZ,
234
239
ir.DataType.FLOAT8E5M2,
235
- ir.DataType.FLOAT8E5M2FNUZ
240
+ ir.DataType.FLOAT8E5M2FNUZ,
236
241
}:
237
- return self.raw.view(torch.uint8).__array__(dtype)
238
- return self.raw.__array__(dtype)
242
+ return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
243
+
244
+ return self.raw.numpy(force=True)
245
+
246
+ def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
247
+ del copy # Unused, but needed for the signature
248
+ if dtype is None:
249
+ return self.numpy()
250
+ return self.numpy().__array__(dtype)
239
251
240
252
def tobytes(self) -> bytes:
241
253
# Implement tobytes to support native PyTorch types so we can use types like bfloat16
242
254
# Reading from memory directly is also more efficient because
243
255
# it avoids copying to a NumPy array
244
- tensor = self.raw.detach().cpu().contiguous()
256
+ import torch._subclasses.fake_tensor
257
+
258
+ with torch._subclasses.fake_tensor.unset_fake_temporarily(): # pylint: disable=protected-access
259
+ # Disable any fake mode so calling detach() etc. will return a real tensor
260
+ tensor = self.raw.detach().cpu().contiguous()
261
+
262
+ if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): # pylint: disable=protected-access
263
+ raise TypeError(
264
+ f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
265
+ "with a tensor backed by real data using ONNXProgram.apply_weights() "
266
+ "or save the model without initializers by setting include_initializers=False."
267
+ )
268
+
245
269
return bytes(
246
270
(ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
247
271
tensor.data_ptr()
248
272
)
249
273
)
250
274
251
275
# Test the implementation
252
- torch_tensor = torch.tensor([1,2, 3], dtype=torch.bfloat16)
276
+ torch_tensor = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
253
277
tensor = TorchTensor(torch_tensor)
254
278
print("tensor: ", tensor)
255
279
print("numpy: ", tensor.numpy())
0 commit comments