Description
Referring to https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/custom_cast.py.
Although it was originally meant only for the MX dtypes (FP4 E2M1, FP6 E2M3, FP6 E3M2), extending it to support any custom FPx dtype would be useful for developing and experimenting with custom FPx kernels.
Case in point: upstream FP6-LLM added support for FP5 E2M2 (https://github.com/usyd-fsalab/fp6_llm). This is what I would need to write to support FP32 -> FP5 E2M2:
```python
import struct

from torch import Tensor

# MBITS_F32, F32_EXP_BIAS, and _f32_to_f4_or_f6_unpacked() are the existing
# definitions in torchao/prototype/mx_formats/custom_cast.py

# define constants for F32 <-> F5_E2M2
F5_E2M2_MAX = 7.0  # (2 ** (0b11 - 0b01)) * (1 + 0.5 + 0.25)
F5_E2M2_MIN_NORMAL = 1.0  # (2 ** (0b01 - 0b01))
EBITS_F5_E2M2 = 2
MBITS_F5_E2M2 = 2
F5_E2M2_EXP_BIAS = 0b01
F5_E2M2_MAX_INT = (1 << 4) - 1
SIGN_MASK_F5_E2M2 = 1 << 4
# rounding bias for round-to-nearest-even, mirroring the existing MAGIC_ADDER_* constants
MAGIC_ADDER_F5_E2M2 = (1 << (MBITS_F32 - MBITS_F5_E2M2 - 1)) - 1

DENORM_F32TOF5_E2M2_EXP = (
    # exp bias conversion between formats
    (F32_EXP_BIAS - F5_E2M2_EXP_BIAS)
    # mantissa length difference between formats
    + (MBITS_F32 - MBITS_F5_E2M2)
    # add one to encoded exponent for denormalized numbers
    + 1
)
DENORM_F32TOF5_E2M2_MASK_INT = DENORM_F32TOF5_E2M2_EXP << MBITS_F32
# reinterpret int32 as float32 in Python
# see https://stackoverflow.com/a/34446112/1058521
DENORM_F32TOF5_E2M2_MASK_FLOAT = struct.unpack(
    "!f", struct.pack("!I", DENORM_F32TOF5_E2M2_MASK_INT)
)[0]


def f32_to_f5_e2m2_unpacked(x: Tensor):
    return _f32_to_f4_or_f6_unpacked(
        x,
        F5_E2M2_MAX,
        F5_E2M2_MIN_NORMAL,
        DENORM_F32TOF5_E2M2_MASK_FLOAT,
        DENORM_F32TOF5_E2M2_MASK_INT,
        EBITS_F5_E2M2,
        MBITS_F5_E2M2,
        F5_E2M2_EXP_BIAS,
        MAGIC_ADDER_F5_E2M2,
        F5_E2M2_MAX_INT,
        SIGN_MASK_F5_E2M2,
    )
```
Ideally, we shouldn't need to calculate all of these constants ourselves: we should only have to provide the number of exponent and mantissa bits, and the constants should be computed inside the function (or cached somewhere, though recomputing them shouldn't take much time).
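For illustration, here is a minimal sketch of how those constants could be derived from just the exponent and mantissa widths. The helper name `_fpx_constants` is hypothetical, and it assumes the MX-style convention that every exponent code encodes a finite value (no inf/nan):

```python
import struct

MBITS_F32, F32_EXP_BIAS = 23, 127


def _fpx_constants(ebits: int, mbits: int) -> dict:
    # hypothetical helper: derive the per-dtype constants from (ebits, mbits),
    # assuming all exponent codes encode finite values (MX-style, no inf/nan)
    exp_bias = (1 << (ebits - 1)) - 1
    max_normal = 2 ** ((1 << ebits) - 1 - exp_bias) * (2 - 2 ** (-mbits))
    min_normal = 2.0 ** (1 - exp_bias)
    max_int = (1 << (ebits + mbits)) - 1  # largest magnitude bit pattern (sign excluded)
    sign_mask = 1 << (ebits + mbits)
    magic_adder = (1 << (MBITS_F32 - mbits - 1)) - 1  # round-to-nearest-even bias
    denorm_exp = (F32_EXP_BIAS - exp_bias) + (MBITS_F32 - mbits) + 1
    denorm_mask_int = denorm_exp << MBITS_F32
    denorm_mask_float = struct.unpack("!f", struct.pack("!I", denorm_mask_int))[0]
    return {
        "max": max_normal,
        "min_normal": min_normal,
        "exp_bias": exp_bias,
        "max_int": max_int,
        "sign_mask": sign_mask,
        "magic_adder": magic_adder,
        "denorm_mask_int": denorm_mask_int,
        "denorm_mask_float": denorm_mask_float,
    }


# _fpx_constants(2, 2) reproduces the F5 E2M2 values above, e.g. max == 7.0
```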
The other direction (FPx -> FP32) is a bit trickier when handling denormal FPx values, but it should still be possible to make it generic.
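As a rough sketch (not the existing torchao implementation), a generic decode including the denormal case could look like this; `_fpx_unpacked_to_f32` is just the proposed name, and LSB-aligned uint8 storage with no inf/nan encodings is assumed:

```python
import torch
from torch import Tensor


def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
    # sketch: decode unpacked FPx values (one value per uint8, LSB-aligned,
    # no inf/nan encodings assumed) back to float32
    assert x.dtype == torch.uint8
    exp_bias = (1 << (ebits - 1)) - 1

    sign = 1.0 - 2.0 * ((x >> (ebits + mbits)) & 1).float()
    exp_field = ((x >> mbits) & ((1 << ebits) - 1)).float()
    mantissa = (x & ((1 << mbits) - 1)).float()

    # denormal (exp field == 0): (-1)^s * 2^(1 - bias) * m / 2^mbits
    denorm_val = mantissa * 2.0 ** (1 - exp_bias - mbits)
    # normal: (-1)^s * 2^(e - bias) * (1 + m / 2^mbits)
    norm_val = torch.exp2(exp_field - exp_bias) * (1.0 + mantissa * 2.0 ** (-mbits))

    return sign * torch.where(exp_field == 0, denorm_val, norm_val)
```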
Proposed changes
- Change `_f32_to_f4_or_f6_unpacked()` and `_f4_or_f6_unpacked_to_f32()` to `_f32_to_fpx_unpacked(x, n_ebits, n_mbits)` and `_fpx_unpacked_to_f32(x, n_ebits, n_mbits)` (see the usage sketch below). Packed formats are out of scope and should be handled separately for each case.
- (Maybe) Move the non-MX-specific parts of `custom_cast.py` up a level, e.g. to `prototype/fp_cast_utils.py` (MX-specific pieces, e.g. functions for packed FP4 and the custom Triton kernels, should stay in `custom_cast.py`).
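With the proposed signatures, the F5 E2M2 case above would reduce to something like the following (a usage sketch only; the exact argument names are up for discussion):

```python
import torch

x = torch.randn(16)
x_f5 = _f32_to_fpx_unpacked(x, n_ebits=2, n_mbits=2)   # proposed generic encode
x_rt = _fpx_unpacked_to_f32(x_f5, n_ebits=2, n_mbits=2)  # proposed generic decode
```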
Tagging @vkuzo and @msaroufim for discussion and opinion.