
print MX config when printing MXLinear and MXInferenceLinear #1947

Merged: 19 commits, Mar 27, 2025
18 changes: 18 additions & 0 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -401,3 +401,21 @@ def test_filter_fn():
    swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn)  # noqa: E501
    assert type(m2[0]) == MXInferenceLinear
    assert type(m2[1]) == torch.nn.Linear


def test_training_print_str():
    m = nn.Sequential(nn.Linear(32, 32))
    config = MXLinearConfig()
    swap_linear_with_mx_linear(m, config=config)
    s = str(m)
    assert "bl_sz=32" in s
    assert "kernel=emulated" in s


def test_inference_print_str():
    m = nn.Sequential(nn.Linear(32, 32))
    config = MXLinearConfig()
    swap_linear_with_mx_inference_linear(m, config=config)
    s = str(m)
    assert "bl_sz=32" in s
    assert "kernel=emulated" in s
20 changes: 20 additions & 0 deletions torchao/prototype/mx_formats/config.py
@@ -12,6 +12,7 @@

from torchao.prototype.mx_formats.constants import (
    DTYPE_FP4,
    DTYPE_TO_SHORT_STR,
    SUPPORTED_ELEM_DTYPES,
)

@@ -143,3 +144,22 @@ def from_recipe_name(
            )
        else:
            raise AssertionError(f"unknown recipe_name {recipe_name}")

    def short_str(self) -> str:
        """
        Returns a concise representation of the current config.
        """
        s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}"
        if self.elem_dtype_weight_override is not None:
            s += (
                f", lp_w_override={DTYPE_TO_SHORT_STR[self.elem_dtype_weight_override]}"
            )
        if self.elem_dtype_grad_output_override is not None:
            s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
        s += f", kernel={self.gemm_kernel_choice.value}"
        if self.use_fp8_dim1_cast_triton_kernel:
            s += ", use_fp8_dim1_cast_triton_kernel=True"
        if self.use_fp4_custom_triton_dequant_kernel:
            s += ", use_fp4_custom_triton_dequant_kernel=True"
        # TODO(future PR): split training from inference and add fp6 here
        return s
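
A quick usage sketch of short_str() on a default config (the printed string is illustrative, following the field order above):

from torchao.prototype.mx_formats.config import MXLinearConfig

config = MXLinearConfig()
print(config.short_str())
# illustrative output: bl_sz=32, lp_dtype=f8e4m3, kernel=emulated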
8 changes: 8 additions & 0 deletions torchao/prototype/mx_formats/constants.py
@@ -22,6 +22,14 @@
    DTYPE_FP4,
]

DTYPE_TO_SHORT_STR = {
    torch.float8_e4m3fn: "f8e4m3",
    torch.float8_e5m2: "f8e5m2",
    DTYPE_FP6_E2M3: "f6e2m3",
    DTYPE_FP6_E3M2: "f6e3m2",
    DTYPE_FP4: "f4e2m1",
}

F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max # 57344.0
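
A minimal sketch of the new mapping in use (short_str() in config.py is the real consumer; the asserted values mirror the table above):

import torch
from torchao.prototype.mx_formats.constants import DTYPE_TO_SHORT_STR

assert DTYPE_TO_SHORT_STR[torch.float8_e4m3fn] == "f8e4m3"
assert DTYPE_TO_SHORT_STR[torch.float8_e5m2] == "f8e5m2"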

8 changes: 8 additions & 0 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -213,6 +213,10 @@ def forward(self, x):
        y = y + self.bias
        return y

    def extra_repr(self):
        s = f"{super().extra_repr()}, {self.config.short_str()}"
        return s


class MXInferenceLinear(torch.nn.Linear):
    """
@@ -255,6 +259,10 @@ def forward(self, x):
        y = F.linear(x, w_hp, self.bias)
        return y

    def extra_repr(self):
        s = f"{super().extra_repr()}, {self.config.short_str()}"
        return s
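
Why overriding extra_repr is enough: nn.Module.__repr__ calls each submodule's extra_repr() to fill the parenthesized part of its line, so appending config.short_str() to super().extra_repr() surfaces the MX config wherever the model is printed. A self-contained sketch (ToyMXLinear is a hypothetical stand-in, not part of this PR):

import torch.nn as nn

class ToyMXLinear(nn.Linear):
    def extra_repr(self):
        # nn.Linear.extra_repr() yields "in_features=..., out_features=..., bias=..."
        return f"{super().extra_repr()}, bl_sz=32"

print(ToyMXLinear(32, 32))
# ToyMXLinear(in_features=32, out_features=32, bias=True, bl_sz=32)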


def replace_with_custom_fn_if_matches_filter(
    model, replacement_fn, filter_fn, cur_fqn=""