
Make custom FPx dtype conversion easier to use #354

Closed

Description

@gau-nernst

Referring to https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/custom_cast.py

Although it was originally meant for MX dtypes only (FP4 E2M1, FP6 E2M3, FP6 E3M2), expanding its functionality to support any custom FPx dtype would be useful for developing and experimenting with custom FPx kernels.

Case in point: FP6-LLM upstream added support for FP5 E2M2 (https://github.com/usyd-fsalab/fp6_llm). This is what I currently need to write to support FP32->FP5 E2M2:

# struct is from the Python stdlib; F32_EXP_BIAS, MBITS_F32, and
# _f32_to_f4_or_f6_unpacked() come from custom_cast.py

# define constants for F32 <-> F5_E2M2
F5_E2M2_MAX = 7.0  # (2 ** (0b11 - 0b01)) * (1 + 0.5 + 0.25)
F5_E2M2_MIN_NORMAL = 1.0  # (2 ** (0b01 - 0b01))
EBITS_F5_E2M2 = 2
MBITS_F5_E2M2 = 2
F5_E2M2_EXP_BIAS = 0b01
F5_E2M2_MAX_INT = (1 << 4) - 1
SIGN_MASK_F5_E2M2 = 1 << 4

MAGIC_ADDER_F5_E2M2 = (1 << (MBITS_F32 - EBITS_F5_E2M2)) - 1

DENORM_F32TOF5_E2M2_EXP = (
    # exp bias conversion between formats
    (F32_EXP_BIAS - F5_E2M2_EXP_BIAS)
    # mantissa length difference between formats
    + (MBITS_F32 - MBITS_F5_E2M2)
    # add one to encoded exponent for denormalized numbers
    + 1
)
DENORM_F32TOF5_E2M2_MASK_INT = DENORM_F32TOF5_E2M2_EXP << MBITS_F32
# reinterpret int32 as float32 in Python
# see https://stackoverflow.com/a/34446112/1058521
DENORM_F32TOF5_E2M2_MASK_FLOAT = struct.unpack("!f", struct.pack("!I", DENORM_F32TOF5_E2M2_MASK_INT))[0]


def f32_to_f5_e2m2_unpacked(x: Tensor):
    return _f32_to_f4_or_f6_unpacked(
        x,
        F5_E2M2_MAX,
        F5_E2M2_MIN_NORMAL,
        DENORM_F32TOF5_E2M2_MASK_FLOAT,
        DENORM_F32TOF5_E2M2_MASK_INT,
        EBITS_F5_E2M2,
        MBITS_F5_E2M2,
        F5_E2M2_EXP_BIAS,
        MAGIC_ADDER_F5_E2M2,
        F5_E2M2_MAX_INT,
        SIGN_MASK_F5_E2M2,
    )

Ideally, we shouldn't need to calculate all of these constants ourselves. We should only have to provide the number of E and M bits, and the constants should be calculated inside the function (or cached somewhere, though I think re-calculating them shouldn't take much time).
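For reference, deriving every constant from just (ebits, mbits) looks roughly like this. This is a sketch, not torchao's API: the function name fpx_constants and the dict keys are illustrative, and it assumes the MX-style convention that no exponent codes are reserved for inf/nan (true for FP4 E2M1, FP6 E2M3, FP6 E3M2, and FP5 E2M2):

```python
import struct

# IEEE-754 float32 parameters (same values custom_cast.py hard-codes)
MBITS_F32 = 23
F32_EXP_BIAS = 127

def fpx_constants(ebits: int, mbits: int) -> dict:
    """Derive all FPx conversion constants from the E/M bit counts.

    Assumes no inf/nan codes, as in the MX formats. Illustrative sketch.
    """
    exp_bias = (1 << (ebits - 1)) - 1
    # shift the biased float32 exponent into FPx denormal range
    denorm_exp = (F32_EXP_BIAS - exp_bias) + (MBITS_F32 - mbits) + 1
    denorm_mask_int = denorm_exp << MBITS_F32
    return {
        # largest exponent code times largest mantissa, e.g. 7.0 for E2M2
        "max": 2.0 ** ((1 << ebits) - 1 - exp_bias) * (2 - 2.0 ** -mbits),
        "min_normal": 2.0 ** (1 - exp_bias),
        "exp_bias": exp_bias,
        "max_int": (1 << (ebits + mbits)) - 1,
        "sign_mask": 1 << (ebits + mbits),
        "magic_adder": (1 << (MBITS_F32 - mbits)) - 1,
        "denorm_mask_int": denorm_mask_int,
        # reinterpret the int32 bit pattern as a float32 value
        "denorm_mask_float": struct.unpack(
            "!f", struct.pack("!I", denorm_mask_int)
        )[0],
    }
```

For E2M2 this reproduces the hand-written constants above (max 7.0, min normal 1.0, bias 0b01, max int 15, sign mask 1 << 4, and so on).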

The other direction (FPx->FP32) is a bit trickier when handling denormal FPx values, but it should still be possible to make it generic.
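The denormal case in that direction reduces to scaling the raw mantissa by a fixed exponent of 2^(1 - bias - mbits), with no implicit leading 1. A scalar reference decoder, generic over (ebits, mbits) — an illustrative sketch, not the tensorized torchao code:

```python
def fpx_bits_to_f32(bits: int, ebits: int, mbits: int) -> float:
    """Decode one FPx code (sign bit included) to a Python float.

    Scalar reference only; assumes no inf/nan codes, as in the MX formats.
    """
    exp_bias = (1 << (ebits - 1)) - 1
    sign = -1.0 if (bits >> (ebits + mbits)) & 1 else 1.0
    exp = (bits >> mbits) & ((1 << ebits) - 1)
    mant = bits & ((1 << mbits) - 1)
    if exp == 0:
        # denormal: no implicit leading 1, exponent pinned at (1 - bias)
        return sign * mant * 2.0 ** (1 - exp_bias - mbits)
    return sign * (1 + mant * 2.0 ** -mbits) * 2.0 ** (exp - exp_bias)
```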

Proposed changes

  • Change _f32_to_f4_or_f6_unpacked() and _f4_or_f6_unpacked_to_f32() to _f32_to_fpx_unpacked(x, n_ebits, n_mbits) and _fpx_unpacked_to_f32(x, n_ebits, n_mbits). (The packed format is out of scope and should be handled separately for each case.)
  • (Maybe) Move non-MX-specific code from custom_cast.py up a level, e.g. to prototype/fp_cast_utils.py (functions for packed FP4 and the custom Triton kernels should stay in custom_cast.py).
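As a sanity check for whichever generic signature lands, a brute-force oracle that enumerates every code and snaps to the nearest representable value is cheap to write and is generic over (ebits, mbits) by construction. Everything below is an illustrative sketch, not proposed API:

```python
def fpx_decode(code: int, ebits: int, mbits: int) -> float:
    """Decode one unsigned FPx code (no sign bit) to a Python float."""
    bias = (1 << (ebits - 1)) - 1
    exp, mant = code >> mbits, code & ((1 << mbits) - 1)
    if exp == 0:  # denormal: implicit leading 0, exponent pinned at 1 - bias
        return mant * 2.0 ** (1 - bias - mbits)
    return (1 + mant * 2.0 ** -mbits) * 2.0 ** (exp - bias)

def f32_to_fpx_ref(x: float, ebits: int, mbits: int) -> int:
    """Return the FPx code (sign bit included) nearest to x.

    Correctness oracle only: ties resolve to the smaller code, whereas a
    real kernel would use round-to-nearest-even. Out-of-range inputs clamp
    to the max representable value.
    """
    sign = (1 << (ebits + mbits)) if x < 0 else 0
    codes = range(1 << (ebits + mbits))
    best = min(codes, key=lambda c: abs(fpx_decode(c, ebits, mbits) - abs(x)))
    return sign | best
```

Comparing a tensorized _f32_to_fpx_unpacked(x, n_ebits, n_mbits) against this oracle over all finite inputs of interest would cover every (ebits, mbits) combination with one test.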

Tagging @vkuzo and @msaroufim for discussion and opinion.
