Skip to content
18 changes: 18 additions & 0 deletions test/prototype/mx_formats/test_mx_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,21 @@ def test_filter_fn():
swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn) # noqa: E501
assert type(m2[0]) == MXInferenceLinear
assert type(m2[1]) == torch.nn.Linear


def test_training_print_str():
m = nn.Sequential(nn.Linear(32, 32))
config = MXLinearConfig()
swap_linear_with_mx_linear(m, config=config)
s = str(m)
assert "bl_sz=32" in s
assert "kernel=emulated" in s


def test_inference_print_str():
m = nn.Sequential(nn.Linear(32, 32))
config = MXLinearConfig()
swap_linear_with_mx_inference_linear(m, config=config)
s = str(m)
assert "bl_sz=32" in s
assert "kernel=emulated" in s
20 changes: 20 additions & 0 deletions torchao/prototype/mx_formats/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_TO_SHORT_STR,
SUPPORTED_ELEM_DTYPES,
)

Expand Down Expand Up @@ -143,3 +144,22 @@ def from_recipe_name(
)
else:
raise AssertionError(f"unknown recipe_name {recipe_name}")

def short_str(self) -> str:
"""
Returns a concise representation of the current config.
"""
s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}"
if self.elem_dtype_weight_override is not None:
s += (
f", lp_w_override={DTYPE_TO_SHORT_STR[self.elem_dtype_weight_override]}"
)
if self.elem_dtype_grad_output_override is not None:
s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
s += f", kernel={self.gemm_kernel_choice.value}"
if self.use_fp8_dim1_cast_triton_kernel:
s += ", use_fp8_dim1_cast_triton_kernel=True"
if self.use_fp4_custom_triton_dequant_kernel:
s += ", use_fp4_custom_triton_dequant_kernel=True"
# TODO(future PR): split training from inference and add fp6 here
return s
8 changes: 8 additions & 0 deletions torchao/prototype/mx_formats/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@
DTYPE_FP4,
]

DTYPE_TO_SHORT_STR = {
torch.float8_e4m3fn: "f8e4m3",
torch.float8_e5m2: "f8e5m2",
DTYPE_FP6_E2M3: "f6e2m3",
DTYPE_FP6_E3M2: "f6e3m2",
DTYPE_FP4: "f4e2m1",
}

F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max # 57344.0

Expand Down
8 changes: 8 additions & 0 deletions torchao/prototype/mx_formats/mx_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ def forward(self, x):
y = y + self.bias
return y

def extra_repr(self):
s = f"{super().extra_repr()}, {self.config.short_str()}"
return s


class MXInferenceLinear(torch.nn.Linear):
"""
Expand Down Expand Up @@ -255,6 +259,10 @@ def forward(self, x):
y = F.linear(x, w_hp, self.bias)
return y

def extra_repr(self):
s = f"{super().extra_repr()}, {self.config.short_str()}"
return s


def replace_with_custom_fn_if_matches_filter(
model, replacement_fn, filter_fn, cur_fqn=""
Expand Down
Loading