Skip to content

Commit df26586

Browse files
authored
Update TorchTensor to use ml_dtypes (#2201)
Bring changes from pytorch/pytorch#151259 to correctly support bfloat16 and float8* types.
1 parent 04ed2b8 commit df26586

File tree

3 files changed: +66 −41 lines changed

docs/intermediate_representation/tensors.md

+57-33
Original file line numberDiff line numberDiff line change
@@ -192,64 +192,88 @@ To fully support arrays from other frameworks, it is usually a good idea to crea
192192
import ctypes
193193
from typing import Any
194194
195+
import numpy.typing as npt
195196
import torch
197+
196198
from onnxscript import ir
197199
198-
# Define utilities to convert PyTorch data types so users do not need to specify manually
199-
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
200-
torch.bfloat16: ir.DataType.BFLOAT16,
201-
torch.bool: ir.DataType.BOOL,
202-
torch.complex128: ir.DataType.COMPLEX128,
203-
torch.complex64: ir.DataType.COMPLEX64,
204-
torch.float16: ir.DataType.FLOAT16,
205-
torch.float32: ir.DataType.FLOAT,
206-
torch.float64: ir.DataType.DOUBLE,
207-
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
208-
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
209-
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
210-
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
211-
torch.int16: ir.DataType.INT16,
212-
torch.int32: ir.DataType.INT32,
213-
torch.int64: ir.DataType.INT64,
214-
torch.int8: ir.DataType.INT8,
215-
torch.uint8: ir.DataType.UINT8,
216-
}
217-
218-
219-
def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
220-
return _TORCH_DTYPE_TO_ONNX[dtype]
221200
222201
class TorchTensor(ir.Tensor):
223-
def __init__(self, tensor: torch.Tensor):
202+
def __init__(
203+
self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
204+
):
224205
# Pass the tensor as the raw data to ir.Tensor's constructor
225-
super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))
226206
227-
def __array__(self, dtype: Any = None) -> "np.ndarray":
228-
# numpy() calls __array__ in ir.Tensor
207+
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
208+
torch.bfloat16: ir.DataType.BFLOAT16,
209+
torch.bool: ir.DataType.BOOL,
210+
torch.complex128: ir.DataType.COMPLEX128,
211+
torch.complex64: ir.DataType.COMPLEX64,
212+
torch.float16: ir.DataType.FLOAT16,
213+
torch.float32: ir.DataType.FLOAT,
214+
torch.float64: ir.DataType.DOUBLE,
215+
torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
216+
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
217+
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
218+
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
219+
torch.int16: ir.DataType.INT16,
220+
torch.int32: ir.DataType.INT32,
221+
torch.int64: ir.DataType.INT64,
222+
torch.int8: ir.DataType.INT8,
223+
torch.uint8: ir.DataType.UINT8,
224+
torch.uint16: ir.DataType.UINT16,
225+
torch.uint32: ir.DataType.UINT32,
226+
torch.uint64: ir.DataType.UINT64,
227+
}
228+
super().__init__(
229+
tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
230+
)
231+
232+
def numpy(self) -> npt.NDArray:
233+
self.raw: torch.Tensor
229234
if self.dtype == ir.DataType.BFLOAT16:
230-
return self.raw.view(torch.uint16).__array__(dtype)
235+
return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
231236
if self.dtype in {
232237
ir.DataType.FLOAT8E4M3FN,
233238
ir.DataType.FLOAT8E4M3FNUZ,
234239
ir.DataType.FLOAT8E5M2,
235-
ir.DataType.FLOAT8E5M2FNUZ
240+
ir.DataType.FLOAT8E5M2FNUZ,
236241
}:
237-
return self.raw.view(torch.uint8).__array__(dtype)
238-
return self.raw.__array__(dtype)
242+
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
243+
244+
return self.raw.numpy(force=True)
245+
246+
def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
247+
del copy # Unused, but needed for the signature
248+
if dtype is None:
249+
return self.numpy()
250+
return self.numpy().__array__(dtype)
239251
240252
def tobytes(self) -> bytes:
241253
# Implement tobytes to support native PyTorch types so we can use types like bfloat16
242254
# Reading from memory directly is also more efficient because
243255
# it avoids copying to a NumPy array
244-
tensor = self.raw.detach().cpu().contiguous()
256+
import torch._subclasses.fake_tensor
257+
258+
with torch._subclasses.fake_tensor.unset_fake_temporarily(): # pylint: disable=protected-access
259+
# Disable any fake mode so calling detach() etc. will return a real tensor
260+
tensor = self.raw.detach().cpu().contiguous()
261+
262+
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): # pylint: disable=protected-access
263+
raise TypeError(
264+
f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
265+
"with a tensor backed by real data using ONNXProgram.apply_weights() "
266+
"or save the model without initializers by setting include_initializers=False."
267+
)
268+
245269
return bytes(
246270
(ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
247271
tensor.data_ptr()
248272
)
249273
)
250274
251275
# Test the implementation
252-
torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16)
276+
torch_tensor = torch.tensor([1, 2, 3], dtype=torch.bfloat16)
253277
tensor = TorchTensor(torch_tensor)
254278
print("tensor: ", tensor)
255279
print("numpy: ", tensor.numpy())

onnxscript/ir/tensor_adapters.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,15 @@ def numpy(self) -> npt.NDArray:
8181

8282
self.raw: torch.Tensor
8383
if self.dtype == ir.DataType.BFLOAT16:
84-
return self.raw.view(torch.uint16).numpy(force=True)
84+
return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
8585
if self.dtype in {
8686
ir.DataType.FLOAT8E4M3FN,
8787
ir.DataType.FLOAT8E4M3FNUZ,
8888
ir.DataType.FLOAT8E5M2,
8989
ir.DataType.FLOAT8E5M2FNUZ,
9090
}:
91-
# TODO: Use ml_dtypes
92-
return self.raw.view(torch.uint8).numpy(force=True)
91+
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
92+
9393
return self.raw.numpy(force=True)
9494

9595
def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:

onnxscript/ir/tensor_adapters_test.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import importlib.util
88
import unittest
99

10+
import ml_dtypes
1011
import numpy as np
1112
import parameterized
1213
import torch
@@ -25,17 +26,17 @@ def skip_if_no(module_name: str):
2526
class TorchTensorTest(unittest.TestCase):
2627
@parameterized.parameterized.expand(
2728
[
28-
(torch.bfloat16, np.uint16),
29+
(torch.bfloat16, ml_dtypes.bfloat16),
2930
(torch.bool, np.bool_),
3031
(torch.complex128, np.complex128),
3132
(torch.complex64, np.complex64),
3233
(torch.float16, np.float16),
3334
(torch.float32, np.float32),
3435
(torch.float64, np.float64),
35-
(torch.float8_e4m3fn, np.uint8),
36-
(torch.float8_e4m3fnuz, np.uint8),
37-
(torch.float8_e5m2, np.uint8),
38-
(torch.float8_e5m2fnuz, np.uint8),
36+
(torch.float8_e4m3fn, ml_dtypes.float8_e4m3fn),
37+
(torch.float8_e4m3fnuz, ml_dtypes.float8_e4m3fnuz),
38+
(torch.float8_e5m2, ml_dtypes.float8_e5m2),
39+
(torch.float8_e5m2fnuz, ml_dtypes.float8_e5m2fnuz),
3940
(torch.int16, np.int16),
4041
(torch.int32, np.int32),
4142
(torch.int64, np.int64),

0 commit comments

Comments
 (0)