mx_formats: move inference to the quantize_ API #1971
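
At a glance, the user-facing change: MX inference is now applied through the standard `quantize_` API with a dedicated `MXInferenceLinearConfig`, replacing the bespoke `swap_linear_with_mx_inference_linear` helper. A minimal before/after sketch, mirroring the updated tests and README in this diff (the bfloat16/CUDA setup is illustrative, not prescribed by the PR):

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXInferenceLinearConfig

m = torch.nn.Sequential(torch.nn.Linear(32, 32, dtype=torch.bfloat16)).cuda()

# Before this PR (removed in this diff):
# from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
# swap_linear_with_mx_inference_linear(m, config=MXLinearConfig(block_size=32))

# After this PR: inference quantization goes through quantize_
config = MXInferenceLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
quantize_(m, config=config)
```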

Merged (53 commits) on Mar 28, 2025
Commits: 53 commits by vkuzo, all titled "Update", Mar 21–28, 2025 (first af6ae2f, last 83e1e2e)
17 changes: 9 additions & 8 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -11,6 +11,7 @@
import torch.nn as nn

from torchao.prototype.mx_formats.config import (
MXInferenceLinearConfig,
MXLinearConfig,
MXLinearRecipeName,
)
@@ -23,7 +24,6 @@
from torchao.prototype.mx_formats.mx_linear import (
MXInferenceLinear,
MXLinear,
swap_linear_with_mx_inference_linear,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
@@ -294,8 +294,8 @@ def test_inference_linear(elem_dtype, bias, input_shape):
m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_inference_linear(m_mx, config=config)
config = MXInferenceLinearConfig(block_size=4, elem_dtype=elem_dtype)
quantize_(m_mx, config=config)

x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
y_ref = m(x)
@@ -319,8 +319,8 @@ def test_inference_compile_simple(elem_dtype):
m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_inference_linear(m_mx, config=config)
config = MXInferenceLinearConfig(block_size=4, elem_dtype=elem_dtype)
quantize_(m_mx, config=config)
m_mx = torch.compile(m_mx, fullgraph="true")

x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
@@ -346,7 +346,8 @@ def test_filter_fn():
assert type(m1[0]) == MXLinear
assert type(m1[1]) == torch.nn.Linear

swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn) # noqa: E501
config2 = MXInferenceLinearConfig(block_size=32)
quantize_(m2, config=config2, filter_fn=filter_fn) # noqa: E501
assert type(m2[0]) == MXInferenceLinear
assert type(m2[1]) == torch.nn.Linear

@@ -362,8 +363,8 @@ def test_training_print_str():

def test_inference_print_str():
m = nn.Sequential(nn.Linear(32, 32))
config = MXLinearConfig()
swap_linear_with_mx_inference_linear(m, config=config)
config = MXInferenceLinearConfig()
quantize_(m, config=config)
s = str(m)
assert "bl_sz=32" in s
assert "kernel=emulated" in s
13 changes: 9 additions & 4 deletions torchao/prototype/mx_formats/README.md
@@ -68,12 +68,17 @@ This is a module to do MX inference, weights are in MX and matmul is in high pre

```python
import torch
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXInferenceLinearConfig, MXGemmKernelChoice

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
swap_linear_with_mx_inference_linear(m, config=config)
gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
config = MXInferenceLinearConfig(
elem_dtype=torch.float8_e4m3fn,
block_size=32,
gemm_kernel_choice=gemm_kernel_choice,
)
quantize_(m, config=config)

# do inference (not shown)
```
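
The README snippet stops at "do inference (not shown)"; a minimal illustrative continuation (the input shape and the bfloat16/CUDA placement are assumptions, not part of the README):

```python
# run the quantized module; the cuBLAS MX kernel chosen above requires
# hardware support, otherwise an exception is raised
x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = m(x)
```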
4 changes: 3 additions & 1 deletion torchao/prototype/mx_formats/__init__.py
@@ -1,5 +1,6 @@
from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
MXInferenceLinearConfig,
MXLinearConfig,
MXLinearRecipeName,
)
@@ -9,7 +10,8 @@
import torchao.prototype.mx_formats.mx_linear # noqa: F401

__all__ = [
"MXLinearConfig",
"MXGemmKernelChoice",
"MXInferenceLinearConfig",
"MXLinearConfig",
"MXLinearRecipeName",
]
123 changes: 78 additions & 45 deletions torchao/prototype/mx_formats/config.py
@@ -13,6 +13,8 @@
from torchao.core.config import AOBaseConfig
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
DTYPE_TO_SHORT_STR,
SUPPORTED_ELEM_DTYPES,
)
@@ -41,6 +43,31 @@ class MXLinearRecipeName(Enum):
MXFP4_CUTLASS = "mxfp4_cutlass"


def _validate_elem_dtype(elem_dtype):
assert (
elem_dtype in SUPPORTED_ELEM_DTYPES
), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {elem_dtype}"


def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
if gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
assert (
block_size == 32
), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}"
valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
assert (
elem_dtype in valid_dtypes
), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
assert (
block_size == 32
), f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
valid_dtypes = [torch.float8_e4m3fn]
assert (
elem_dtype in valid_dtypes
), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"


@dataclass
class MXLinearConfig(AOBaseConfig):
# block size for scaling, default is 32 to match
@@ -68,53 +95,17 @@ class MXLinearConfig(AOBaseConfig):
# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel: bool = False

# If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton
# kernels (fused unpack/dequantize). Training not currently supported.
pack_fp6 = True if hasattr(torch.library, "custom_op") else False

def __post_init__(self):
# validate elem_dtype and its overrides
assert (
self.elem_dtype in SUPPORTED_ELEM_DTYPES
), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
_validate_elem_dtype(self.elem_dtype)
_validate_gemm_kernel_choice(
self.gemm_kernel_choice, self.block_size, self.elem_dtype
)
if self.elem_dtype_weight_override is not None:
assert (
self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES
), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
_validate_elem_dtype(self.elem_dtype_weight_override)
assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
if self.elem_dtype_grad_output_override is not None:
assert (
self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"

# validate that block size and elem_dtype matches kernel choice
if self.gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
assert (
self.block_size == 32
), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {self.block_size}"
valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
assert (
self.elem_dtype in valid_dtypes
), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}"
assert (
self.elem_dtype_weight_override is None
), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels"
assert (
self.elem_dtype_grad_output_override is None
), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
elif self.gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
assert (
self.block_size == 32
), f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {self.block_size}"
valid_dtypes = [torch.float8_e4m3fn]
assert (
self.elem_dtype in valid_dtypes
), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}"
assert (
self.elem_dtype_weight_override is None
), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels"
assert (
self.elem_dtype_grad_output_override is None
), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
_validate_elem_dtype(self.elem_dtype_grad_output_override)
assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"

@staticmethod
def from_recipe_name(
@@ -162,5 +153,47 @@ def short_str(self) -> str:
s += ", use_fp8_dim1_cast_triton_kernel=True"
if self.use_fp4_custom_triton_dequant_kernel:
s += ", use_fp4_custom_triton_dequant_kernel=True"
# TODO(future PR): split training from inference and add fp6 here
return s


@dataclass
class MXInferenceLinearConfig(AOBaseConfig):
# block size for scaling, default is 32 to match
# https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
# section 5.2
block_size: int = 32

# element dtype, used for activations, weights and gradients
elem_dtype: Any = torch.float8_e4m3fn
# TODO(future PR): support different elem_dtype for activations vs weights

# defines the gemm kernel choice, if the chosen kernel is not supported
# on the given hardware an exception will be thrown
gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED

# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel: bool = False

# If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton
# kernels (fused unpack/dequantize).
pack_fp6: bool = True

def __post_init__(self):
_validate_elem_dtype(self.elem_dtype)
_validate_gemm_kernel_choice(
self.gemm_kernel_choice, self.block_size, self.elem_dtype
)

def short_str(self) -> str:
"""
Returns a concise representation of the current config.
"""
s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}"
s += f", kernel={self.gemm_kernel_choice.value}"
if self.use_fp4_custom_triton_dequant_kernel:
s += ", use_fp4_custom_triton_dequant_kernel=True"
if self.elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2) and self.pack_fp6:
s += ", pack_fp6=True"
return s

# TODO(future PR): add a recipe to config API for inference
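
A small sketch of how the validation helpers above surface misconfiguration from `MXInferenceLinearConfig.__post_init__`; the specific values are illustrative:

```python
import torch

from torchao.prototype.mx_formats import MXGemmKernelChoice, MXInferenceLinearConfig

# OK: the cuBLAS MX gemm path requires block_size == 32 and float8_e4m3fn elements
cfg = MXInferenceLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)

# Rejected by _validate_gemm_kernel_choice: block_size != 32 with cuBLAS
try:
    MXInferenceLinearConfig(block_size=64, gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
except AssertionError as e:
    print(e)  # block_size must be 32 to use the cuBLAS MX gemm kernels, got 64
```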
56 changes: 10 additions & 46 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -13,7 +13,11 @@
import torch
import torch.nn.functional as F

from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
MXInferenceLinearConfig,
MXLinearConfig,
)
from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.quantization.transform_module import (
@@ -234,7 +238,7 @@ class MXInferenceLinear(torch.nn.Linear):
def from_float(
cls,
mod,
config: Optional[MXLinearConfig] = MXLinearConfig(),
config: Optional[MXInferenceLinearConfig] = MXInferenceLinearConfig(),
):
with torch.device("meta"):
super_kwargs = {
@@ -267,53 +271,13 @@ def extra_repr(self):
return s


def replace_with_custom_fn_if_matches_filter(
model, replacement_fn, filter_fn, cur_fqn=""
) -> None:
"""
For each `child` in `model`, replaces it with `replacement_fn(child)`
if `filter_fn(child)` is `True`
"""
name_to_child = dict(model.named_children())
for name, child in name_to_child.items():
if cur_fqn == "":
new_fqn = name
else:
new_fqn = f"{cur_fqn}.{name}"
if filter_fn(child, new_fqn):
new_child = replacement_fn(child)
setattr(model, name, new_child)
else:
replace_with_custom_fn_if_matches_filter(
child, replacement_fn, filter_fn, new_fqn
)


def _is_linear(mod, fqn):
return isinstance(mod, torch.nn.Linear)


@register_quantize_module_handler(MXLinearConfig)
def _mx_linear_transform(module: torch.nn.Module, config: MXLinearConfig):
return MXLinear.from_float(module, config=config)


def swap_linear_with_mx_inference_linear(
model,
*,
config: Optional[MXLinearConfig] = None,
filter_fn=None,
@register_quantize_module_handler(MXInferenceLinearConfig)
def _mx_inference_linear_transform(
module: torch.nn.Module, config: MXInferenceLinearConfig
):
if filter_fn is None:
combined_filter_fn = _is_linear
else:

def __fn(mod, fqn):
return _is_linear(mod, fqn) and filter_fn(mod, fqn)

combined_filter_fn = __fn
replace_with_custom_fn_if_matches_filter(
model,
lambda mod: MXInferenceLinear.from_float(mod, config=config),
combined_filter_fn,
)
return MXInferenceLinear.from_float(module, config=config)
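
The manual `replace_with_custom_fn_if_matches_filter` walk removed above is subsumed by `quantize_`, which traverses the model and applies the handler registered for `MXInferenceLinearConfig`; a user-supplied `filter_fn` narrows which modules are swapped. A hedged sketch mirroring `test_filter_fn` (the particular filter and layer sizes are assumptions for illustration):

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats import MXInferenceLinearConfig
from torchao.prototype.mx_formats.mx_linear import MXInferenceLinear
from torchao.quantization import quantize_

m = nn.Sequential(
    nn.Linear(32, 32, dtype=torch.bfloat16),
    nn.Linear(32, 128, dtype=torch.bfloat16),
).cuda()

def filter_fn(mod, fqn):
    # only swap Linear layers whose output width is 32
    return isinstance(mod, nn.Linear) and mod.out_features == 32

quantize_(m, config=MXInferenceLinearConfig(block_size=32), filter_fn=filter_fn)
assert type(m[0]) == MXInferenceLinear  # swapped
assert type(m[1]) == nn.Linear          # left untouched by the filter
```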