
use torch.float8_e8m0fnu in mx_formats #1882


Closed
wants to merge 5 commits into from
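
The change, in short: the per-block MX scale used to travel through the stack as a raw `uint8` biased exponent, and this PR carries the same bytes as `torch.float8_e8m0fnu` instead. The sketch below is not part of the PR; it only illustrates that relationship, and it assumes a PyTorch build that ships `torch.float8_e8m0fnu` and supports eager-mode casting from it.

```python
# A minimal sketch, not part of this PR: the per-block MX scale is the e8m0
# biased exponent. It used to be stored as raw uint8 bits; the PR stores the
# same bytes as torch.float8_e8m0fnu. Assumes a PyTorch build with that dtype.
import torch

# Biased exponents as previously stored: the encoded value is 2^(bits - 127),
# so 126 -> 0.5, 127 -> 1.0, 128 -> 2.0.
scale_bits = torch.tensor([126, 127, 128], dtype=torch.uint8)

# Zero-copy reinterpretation into the new dtype; no values are converted.
scale_e8m0 = scale_bits.view(torch.float8_e8m0fnu)

# Upcast to float32 to inspect the decoded values (assumes eager-mode casting
# from e8m0 works; the TODO in mx_tensor.py suggests only triton/inductor lack it).
print(scale_e8m0.to(torch.float32))  # tensor([0.5000, 1.0000, 2.0000])

# The raw bits remain one view() away for code that still needs them.
print(scale_e8m0.view(torch.uint8))  # tensor([126, 127, 128], dtype=torch.uint8)
```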
14 changes: 8 additions & 6 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -18,7 +18,6 @@
 )
 from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
 from torchao.prototype.mx_formats.mx_tensor import (
-    E8M0_EXPONENT_NAN_VAL,
     MXTensor,
     ScaleCalculationMode,
     to_dtype,
@@ -117,8 +116,8 @@ def test_exponent_nan_in(elem_dtype):
     )
     block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
-    assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
-    assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
+    assert torch.all(torch.isnan(tensor_mx._scale_e8m0[0]))
+    assert not torch.any(torch.isnan(tensor_mx._scale_e8m0[1:]))


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -128,8 +127,11 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     """
     If block exponent value is NaN, the MX tensor block value is NaN
     """
-    scale_e8m0_bits = torch.tensor(
-        [E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
+    if pack_fp6 and elem_dtype not in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
+        pytest.skip("invalid configuration")
+
+    scale_e8m0 = torch.tensor(
+        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda"
     )

     block_size = 4
@@ -155,7 +157,7 @@
     block_size = 4
     use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
-        scale_e8m0_bits,
+        scale_e8m0,
         data_bits,
         elem_dtype,
         block_size,
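
A hedged aside on the test changes above: with the scale held as `torch.float8_e8m0fnu`, NaN is a value of the dtype rather than a magic byte, so the tests can use `torch.isnan` instead of comparing against `E8M0_EXPONENT_NAN_VAL`. A minimal check, assuming the e8m0 NaN encoding is the all-ones byte:

```python
import torch

# Mirrors the updated test: build the scale directly in e8m0 with a NaN entry.
scale = torch.tensor([float("nan"), 1.0], dtype=torch.float8_e8m0fnu)

print(torch.isnan(scale))       # tensor([ True, False])
print(scale.view(torch.uint8))  # tensor([255, 127], dtype=torch.uint8) -- 0xFF carries NaN (assumed encoding)
```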
3 changes: 3 additions & 0 deletions torchao/prototype/mx_formats/custom_cast.py
@@ -743,6 +743,7 @@ def triton_f4_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)
     assert TORCH_VERSION_AT_LEAST_2_4, "unsupported"
     new_shape = (*x.shape[:-1], x.shape[-1] * 2)
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
@@ -859,6 +860,7 @@ def triton_f6_e2m3_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)

     packed_mx_block_size = 3 * mx_block_size // 4

@@ -900,6 +902,7 @@ def triton_f6_e3m2_to_scaled_bf16(
     size is currently assumed to be 32.
     Output: a tensor of bfloat16 values, multiplied by the encoded scale
     """
+    s_e8m0 = s_e8m0.view(torch.uint8)

     packed_mx_block_size = 3 * mx_block_size // 4

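
For context on the three one-line additions above (a sketch, not part of the diff): the Triton wrappers now receive the scale as `torch.float8_e8m0fnu`, but the kernels still expect raw exponent bits, so each wrapper reinterprets the tensor as `uint8` at its entry point. Because `Tensor.view(dtype)` between same-sized dtypes reinterprets the existing storage, the added line costs no copy:

```python
import torch

s_e8m0 = torch.tensor([1.0, 2.0, 4.0], dtype=torch.float8_e8m0fnu)
s_bits = s_e8m0.view(torch.uint8)  # same reinterpretation the wrappers now start with

print(s_bits.data_ptr() == s_e8m0.data_ptr())  # True: same storage, no copy
print(s_bits)  # tensor([127, 128, 129], dtype=torch.uint8)
```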
9 changes: 8 additions & 1 deletion torchao/prototype/mx_formats/mx_tensor.py
@@ -271,10 +271,15 @@ def to_mx(
     else:
         raise AssertionError("unsupported")

+    scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
+
     return scale_e8m0_biased, data_lp


+# TODO(future PR): delete this function once casting from e8m0 to float works
+# in triton + torchinductor
 def get_fp_scale(scale_e8m0):
+    scale_e8m0 = scale_e8m0.view(torch.uint8)
     s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
     # TODO(later): it would be nice if there was a way to do the 2^x operation
     # in PyTorch without creating a tensor of twos
@@ -507,7 +512,9 @@ def __new__(
             dtype=orig_dtype,
             device=data_bits.device,
         )
-        assert scale_e8m0_bits.dtype == torch.uint8, "unsupported"
+        assert (
+            scale_e8m0_bits.dtype == torch.float8_e8m0fnu
+        ), f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}"
         assert len(scale_e8m0_bits.shape) == 1, "unsupported"
         assert data_bits.dtype in (
             torch.float8_e4m3fn,
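
A standalone re-derivation of what `get_fp_scale` computes, offered only as a sketch (it is not the repo's implementation, and it assumes the usual e8m0 bias of 127, which is what `E8M0_EXPONENT_BIAS` names): the scale is two raised to the unbiased exponent, and the function stays in `uint8` space because triton/inductor cannot yet cast from e8m0.

```python
import torch

E8M0_EXPONENT_BIAS = 127  # assumed value, matching the OCP MX e8m0 encoding

def fp_scale_from_e8m0(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # Reinterpret to raw bits, then compute 2^(bits - bias) in float32.
    bits = scale_e8m0.view(torch.uint8)
    s_offset = bits.to(torch.int16) - E8M0_EXPONENT_BIAS
    return torch.exp2(s_offset.to(torch.float32))

scale = torch.tensor([0.5, 1.0, 4.0], dtype=torch.float8_e8m0fnu)
print(fp_scale_from_e8m0(scale))  # tensor([0.5000, 1.0000, 4.0000])
```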