From 3016daabb0b525d59f0131d89184aeb6d7d8ba80 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 21 Oct 2024 17:10:51 -0700 Subject: [PATCH] [IR] Support float4e2m1 (#1908) Support the float4e2m1 dtype from IRv11 (which is not yet released). This allows our tests to pass in the weekly-onnx CI. We use the ml_dtypes.float4_e2m1fn type for numpy conversion. Since ml_dtypes.float4_e2m1fn is only available in the latest ml_dtypes release, which has dropped support for Python 3.8, I used conditional logic to build the numpy dtype mapping table. --- onnxscript/ir/_core.py | 24 ++++++++++++++++++++--- onnxscript/ir/_core_test.py | 36 ++++++++++++++++++++++++++++++---- onnxscript/ir/_enums.py | 9 +++++++++ onnxscript/ir/_enums_test.py | 2 ++ onnxscript/ir/_type_casting.py | 15 ++++++++++++++ onnxscript/ir/serde.py | 3 +++ 6 files changed, 82 insertions(+), 7 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 25722d7ba..30d88cef9 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -70,6 +70,7 @@ _enums.DataType.FLOAT8E5M2FNUZ, _enums.DataType.INT4, _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, ) ) @@ -182,7 +183,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) When the dtype is not one of the numpy native dtypes, the value needs need to be: - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits. - - ``uint8`` for uint4. + - ``uint8`` for uint4 or float4. - ``uint8`` for 8-bit data types. - ``uint16`` for bfloat16 @@ -213,6 +214,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) raise TypeError( f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}." ) + if dtype == _enums.DataType.FLOAT4E2M1: + if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn): + raise TypeError( + f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}." 
+ ) return try: @@ -256,6 +262,8 @@ def _maybe_view_np_array_with_ml_dtypes( return array.view(ml_dtypes.int4) if dtype == _enums.DataType.UINT4: return array.view(ml_dtypes.uint4) + if dtype == _enums.DataType.FLOAT4E2M1: + return array.view(ml_dtypes.float4_e2m1fn) return array @@ -431,7 +439,11 @@ def tobytes(self) -> bytes: """ # TODO(justinchuby): Support DLPack array = self.numpy() - if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}: + if self.dtype in { + _enums.DataType.INT4, + _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, + }: # Pack the array into int4 array = _type_casting.pack_int4(array) else: @@ -609,7 +621,11 @@ def _load(self): ) # Handle the byte order correctly by always using little endian dt = np.dtype(self.dtype.numpy()).newbyteorder("<") - if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}: + if self.dtype in { + _enums.DataType.INT4, + _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, + }: # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values dt = np.dtype(np.uint8).newbyteorder("<") count = self.size // 2 + self.size % 2 @@ -622,6 +638,8 @@ def _load(self): self._array = _type_casting.unpack_int4(self._array, shape) elif self.dtype == _enums.DataType.UINT4: self._array = _type_casting.unpack_uint4(self._array, shape) + elif self.dtype == _enums.DataType.FLOAT4E2M1: + self._array = _type_casting.unpack_float4e2m1(self._array, shape) else: self._array = self._array.reshape(shape) diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 802bf39de..036139908 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -55,6 +55,7 @@ def test_init_requires_type_when_value_is_not_np_array(self): ("int4", np.int8, ir.DataType.INT4), ("int4_uint8", np.uint8, ir.DataType.INT4), ("uint4", np.uint8, ir.DataType.UINT4), + ("float4e2m1", np.uint8, ir.DataType.FLOAT4E2M1), ] ) def test_init_with_non_native_numpy_dtype(self, _: str, np_dtype, dtype: 
ir.DataType): @@ -131,34 +132,48 @@ def test_tobytes(self): tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT) self.assertEqual(tensor.tobytes(), array.tobytes()) - def test_tobtyes_returns_packed_data_for_int4(self): + def test_tobytes_returns_packed_data_for_int4(self): array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8) # Test odd sized array assert len(array) % 2 == 1 tensor = _core.Tensor(array, dtype=ir.DataType.INT4) self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01") - def test_tobtyes_returns_packed_data_for_int4_ml_dtypes(self): + def test_tobytes_returns_packed_data_for_int4_ml_dtypes(self): array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=ml_dtypes.int4) # Test odd sized array assert len(array) % 2 == 1 tensor = _core.Tensor(array, dtype=ir.DataType.INT4) self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01") - def test_tobtyes_returns_packed_data_for_uint4(self): + def test_tobytes_returns_packed_data_for_uint4(self): array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) # Test odd sized array assert len(array) % 2 == 1 tensor = _core.Tensor(array, dtype=ir.DataType.UINT4) self.assertEqual(tensor.tobytes(), b"\x10r\x0f") - def test_tobtyes_returns_packed_data_for_uint4_ml_dtypes(self): + def test_tobytes_returns_packed_data_for_uint4_ml_dtypes(self): array = np.array([0, 1, 2, 7, 15], dtype=ml_dtypes.uint4) # Test odd sized array assert len(array) % 2 == 1 tensor = _core.Tensor(array, dtype=ir.DataType.UINT4) self.assertEqual(tensor.tobytes(), b"\x10r\x0f") + def test_tobytes_returns_packed_data_for_float4e2m1(self): + array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) + # Test odd sized array + assert len(array) % 2 == 1 + tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1) + self.assertEqual(tensor.tobytes(), b"\x10r\x0f") + + def test_tobytes_returns_packed_data_for_float4e2m1_ml_dtypes(self): + array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) + # Test odd sized array + assert len(array) % 2 == 1 + tensor = _core.Tensor(array, 
dtype=ir.DataType.FLOAT4E2M1) + self.assertEqual(tensor.tobytes(), b"\x10r\x0f") + def test_metadata(self): array = np.random.rand(1, 2).astype(np.float32) tensor = _core.Tensor(array) @@ -444,6 +459,19 @@ def test_external_tensor_complex(self, _: str, np_dtype: np.dtype): # about permission errors del tensor + def test_external_tensor_float4e2m1(self): + expected_array = np.array([0, 1, 2, 7, 15]).view(ml_dtypes.float4_e2m1fn) + tensor_proto = ir.serde.serialize_tensor( + ir.Tensor(expected_array, dtype=ir.DataType.FLOAT4E2M1) + ) + with tempfile.TemporaryDirectory() as temp_dir: + _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") + tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) + np.testing.assert_array_equal(tensor.numpy(), expected_array) + # Close the mmap file by deleting the reference to tensor so Windows doesn't complain + # about permission errors + del tensor + def test_external_tensor_empty_tensor(self): expected_array = np.array([], dtype=np.float32) tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array)) diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py index d561ad58d..d0d8c1927 100644 --- a/onnxscript/ir/_enums.py +++ b/onnxscript/ir/_enums.py @@ -64,6 +64,7 @@ class DataType(enum.IntEnum): FLOAT8E5M2FNUZ = 20 UINT4 = 21 INT4 = 22 + FLOAT4E2M1 = 23 @classmethod def from_numpy(cls, dtype: np.dtype) -> DataType: @@ -121,6 +122,7 @@ def __str__(self) -> str: DataType.FLOAT8E5M2FNUZ: 1, DataType.UINT4: 0.5, DataType.INT4: 0.5, + DataType.FLOAT4E2M1: 0.5, } @@ -150,5 +152,12 @@ def __str__(self) -> str: np.dtype(ml_dtypes.uint4): DataType.UINT4, } +# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE +_NP_TYPE_TO_DATA_TYPE.update( + {np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1} + if hasattr(ml_dtypes, "float4_e2m1fn") + else {} +) + # ONNX DataType to Numpy dtype. 
_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()} diff --git a/onnxscript/ir/_enums_test.py b/onnxscript/ir/_enums_test.py index 661681920..0721aaa99 100644 --- a/onnxscript/ir/_enums_test.py +++ b/onnxscript/ir/_enums_test.py @@ -32,6 +32,8 @@ def test_enums_are_the_same_as_spec(self): self.assertEqual(_enums.DataType.FLOAT8E5M2FNUZ, onnx.TensorProto.FLOAT8E5M2FNUZ) self.assertEqual(_enums.DataType.UINT4, onnx.TensorProto.UINT4) self.assertEqual(_enums.DataType.INT4, onnx.TensorProto.INT4) + if hasattr(onnx.TensorProto, "FLOAT4E2M1"): + self.assertEqual(_enums.DataType.FLOAT4E2M1, onnx.TensorProto.FLOAT4E2M1) self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED) def test_from_numpy_takes_np_dtype_and_returns_data_type(self): diff --git a/onnxscript/ir/_type_casting.py b/onnxscript/ir/_type_casting.py index 3f3611000..20bab6903 100644 --- a/onnxscript/ir/_type_casting.py +++ b/onnxscript/ir/_type_casting.py @@ -89,3 +89,18 @@ def unpack_int4( """ unpacked = _unpack_uint4_as_uint8(data, dims) return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4) + + +def unpack_float4e2m1( + data: npt.NDArray[np.uint8], dims: Sequence[int] +) -> npt.NDArray[ml_dtypes.float4_e2m1fn]: + """Convert a packed float4e2m1 array to unpacked float4e2m1 array. + + Args: + data: A numpy array. + dims: The dimensions are used to reshape the unpacked buffer. + + Returns: + A numpy array of float4e2m1 reshaped to dims. 
+ """ + return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 41571bcd3..2d3a9849e 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -323,6 +323,8 @@ def numpy(self) -> np.ndarray: return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims) elif dtype == _enums.DataType.UINT4: return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims) + elif dtype == _enums.DataType.FLOAT4E2M1: + return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims) else: # Otherwise convert to the correct dtype and reshape # Note we cannot use view() here because the storage dtype may not be the same size as the target @@ -369,6 +371,7 @@ def tobytes(self) -> bytes: _enums.DataType.FLOAT8E5M2FNUZ, _enums.DataType.INT4, _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, }: # uint4 and int4 values are already packed, even when stored as int32 # so we don't need to pack them again