
Commit 20c77e2

use torch.float8_e8m0fnu in mx_formats
Summary:

Switches our MX code to use the new `torch.float8_e8m0fnu` dtype directly where appropriate. This makes numerical debugging of scales easier, since the decoded scale values are visible when the tensor is printed.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 1c0c2ba
ghstack-comment-id: 2721705057
Pull Request resolved: #1882
1 parent c376285 commit 20c77e2
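The debugging benefit described in the summary is easiest to see with a quick illustration. The sketch below is not part of the diff; it assumes a PyTorch build that includes `torch.float8_e8m0fnu` (2.7+) and shows how the same scale bits read before and after the switch:

```python
import torch

# Raw biased-exponent bytes, as MX scales were stored before this change.
scale_bits = torch.tensor([126, 127, 128, 255], dtype=torch.uint8)
print(scale_bits)  # opaque: tensor([126, 127, 128, 255], dtype=torch.uint8)

# The same bits reinterpreted as e8m0 print as decoded powers of two,
# 2^(exponent - 127), and the all-ones encoding surfaces directly as NaN.
scale_e8m0 = scale_bits.view(torch.float8_e8m0fnu)
print(scale_e8m0)  # values: 0.5, 1.0, 2.0, nan
```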

File tree: 5 files changed, +25 −13 lines


test/prototype/mx_formats/test_mx_linear.py (2 additions & 2 deletions)

```diff
@@ -25,14 +25,14 @@
 )
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_7,
     is_sm_at_least_89,
     is_sm_at_least_100,
 )

 torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_5:
+if not TORCH_VERSION_AT_LEAST_2_7:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
```
test/prototype/mx_formats/test_mx_mm.py (2 additions & 2 deletions)

```diff
@@ -5,9 +5,9 @@
 from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
 from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_7:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
```
test/prototype/mx_formats/test_mx_tensor.py (10 additions & 8 deletions)

```diff
@@ -18,21 +18,20 @@
 )
 from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
 from torchao.prototype.mx_formats.mx_tensor import (
-    E8M0_EXPONENT_NAN_VAL,
     MXTensor,
     ScaleCalculationMode,
     to_dtype,
 )
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_7,
     is_sm_at_least_89,
     is_sm_at_least_100,
 )

 torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_7:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

@@ -118,8 +117,8 @@ def test_exponent_nan_in(elem_dtype):
     )
     block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
-    assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
-    assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
+    assert torch.all(torch.isnan(tensor_mx._scale_e8m0[0]))
+    assert not torch.any(torch.isnan(tensor_mx._scale_e8m0[1:]))


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -129,8 +128,11 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     """
     If block exponent value is NaN, the MX tensor block value is NaN
     """
-    scale_e8m0_bits = torch.tensor(
-        [E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
+    if pack_fp6 and elem_dtype not in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
+        pytest.skip("invalid configuration")
+
+    scale_e8m0 = torch.tensor(
+        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda"
     )

     block_size = 4
@@ -156,7 +158,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     block_size = 4
     use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
-        scale_e8m0_bits,
+        scale_e8m0,
         data_bits,
         elem_dtype,
         block_size,
```
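The rewritten assertions lean on `torch.isnan` understanding the e8m0 NaN encoding (all exponent bits set), so tests no longer need to compare raw bytes against `E8M0_EXPONENT_NAN_VAL`. A minimal standalone version of the check, assuming torch 2.7+:

```python
import torch

# e8m0 has a single special value: the all-ones exponent encodes NaN,
# and torch.isnan recognizes it natively on this dtype.
scale = torch.tensor([float("nan"), 1.0], dtype=torch.float8_e8m0fnu)
assert torch.isnan(scale[0])
assert not torch.isnan(scale[1])
```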

torchao/prototype/mx_formats/custom_cast.py (3 additions & 0 deletions)

```diff
@@ -743,6 +743,7 @@ def triton_f4_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)
     assert TORCH_VERSION_AT_LEAST_2_4, "unsupported"
     new_shape = (*x.shape[:-1], x.shape[-1] * 2)
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
@@ -859,6 +860,7 @@ def triton_f6_e2m3_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)

     packed_mx_block_size = 3 * mx_block_size // 4

@@ -900,6 +902,7 @@ def triton_f6_e3m2_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)

     packed_mx_block_size = 3 * mx_block_size // 4
```
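These triton kernels still consume raw exponent bytes, so the e8m0 scale is bitcast back to `uint8` at the kernel boundary. A small sketch of that round-trip (not from the diff; assumes torch 2.7+), relying on `Tensor.view(dtype)` being a zero-copy reinterpretation between two one-byte dtypes:

```python
import torch

s_e8m0 = torch.tensor([1.0, 2.0], dtype=torch.float8_e8m0fnu)
s_bits = s_e8m0.view(torch.uint8)  # biased exponents: tensor([127, 128])

# view() reinterprets the same storage, so no data is copied and the
# round-trip back to e8m0 is lossless.
assert s_bits.data_ptr() == s_e8m0.data_ptr()
assert s_bits.view(torch.float8_e8m0fnu).dtype == torch.float8_e8m0fnu
```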

torchao/prototype/mx_formats/mx_tensor.py (8 additions & 1 deletion)

```diff
@@ -240,10 +240,15 @@ def to_mx(
     else:
         raise AssertionError("unsupported")

+    scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
+
     return scale_e8m0_biased, data_lp


+# TODO(future PR): delete this function once casting from e8m0 to float works
+# in triton + torchinductor
 def get_fp_scale(scale_e8m0):
+    scale_e8m0 = scale_e8m0.view(torch.uint8)
     s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
     # TODO(later): it would be nice if there was a way to do the 2^x operation
     # in PyTorch without creating a tensor of twos
@@ -476,7 +481,9 @@ def __new__(
             dtype=orig_dtype,
             device=data_bits.device,
         )
-        assert scale_e8m0_bits.dtype == torch.uint8, "unsupported"
+        assert (
+            scale_e8m0_bits.dtype == torch.float8_e8m0fnu
+        ), f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
         assert len(scale_e8m0_bits.shape) == 1, "unsupported"
         assert data_bits.dtype in (
             torch.float8_e4m3fn,
```
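`get_fp_scale` decodes the biased exponent into a floating-point scale of `2^(exponent - 127)`. The sketch below re-derives it for illustration only: `get_fp_scale_sketch` is a hypothetical name, and it uses `torch.exp2` where the real function builds a tensor of twos (see the TODO in the diff above):

```python
import torch

E8M0_EXPONENT_BIAS = 127  # bias of the e8m0 format

def get_fp_scale_sketch(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # Bitcast e8m0 back to raw exponent bytes, remove the bias, and
    # decode the scale as a power of two.
    bits = scale_e8m0.view(torch.uint8)
    s_offset = bits.to(torch.int16) - E8M0_EXPONENT_BIAS
    return torch.exp2(s_offset.to(torch.float32))

scale = torch.tensor([0.5, 1.0, 4.0], dtype=torch.float8_e8m0fnu)
print(get_fp_scale_sketch(scale))  # tensor([0.5000, 1.0000, 4.0000])
```

Once casting from e8m0 to float works in triton and torchinductor, this helper becomes unnecessary, which is exactly what the new TODO in the diff records.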
