
Commit f44d1f7

MXFP4 and MXFP8 loading support (#832)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent 81c1f52 commit f44d1f7

11 files changed: +619 additions, -10 deletions

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from auto_round.experimental.qmodules.mx import MXFP4QuantLinear, MXFP8QuantLinear
from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear

auto_round/experimental/qmodules/fp4_utils.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

kE2M1ToFloat = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32)


def unpack_fp4_from_uint8(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    """
    Unpacks uint8 values into FP4. Each uint8 contains two FP4 values
    (low nibble first). The 4-bit indices are mapped to FP4 values using kE2M1ToFloat.
    """
    if a.device.type == "cuda":
        return _unpack_fp4_from_uint8_cuda(a, m, n, dtype)
    else:
        return _unpack_fp4_from_uint8_cpu(a, m, n, dtype)


@torch.compiler.disable()
def _unpack_fp4_from_uint8_cpu(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    return _unpack_fp4_from_uint8(a, m, n, dtype)


@torch.compile(fullgraph=True, dynamic=True)
def _unpack_fp4_from_uint8_cuda(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    return _unpack_fp4_from_uint8(a, m, n, dtype)


# reference: https://github.com/vllm-project/vllm/pull/16362
def _unpack_fp4_from_uint8(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    """
    Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values
    (i.e. the first four bits correspond to one fp4 value and the last four to the
    consecutive fp4 value). The bits represent an index, which is mapped to an fp4
    value.

    :param a: tensor to unpack
    :param m: original dim 0 size of the unpacked tensor
    :param n: original dim 1 size of the unpacked tensor
    :param dtype: dense dtype to cast the unpacked tensor to
    """
    assert a.dtype == torch.uint8, f"expected uint8, got {a.dtype}"

    # Vectorized nibble processing
    a_flat = a.flatten()
    high = (a_flat & 0xF0) >> 4  # Upper nibbles
    low = a_flat & 0x0F  # Lower nibbles

    # Combine nibbles for batch processing
    combined = torch.stack((low, high), dim=1).flatten()

    # Vectorized sign and magnitude extraction
    signs = (combined & 0x08).to(torch.bool)  # Sign bits
    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices

    # Device-aware lookup and sign application
    kE2M1 = kE2M1ToFloat.to(device=a.device)
    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

    # Reshape to final form
    return values.reshape(m, n).to(dtype=dtype)
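
As a quick sanity check of the nibble layout above, here is a minimal sketch that packs two E2M1 values into one byte by hand and recovers them with unpack_fp4_from_uint8 (only the function and import path introduced in this commit are used); the expected outputs follow from the kE2M1ToFloat table and the 0x08 sign bit.

import torch

from auto_round.experimental.qmodules.fp4_utils import unpack_fp4_from_uint8

# One byte stores two E2M1 values, low nibble first.
# Low nibble  0x3 -> sign bit clear, magnitude index 3 -> kE2M1ToFloat[3] = 1.5
# High nibble 0xE -> sign bit set,   magnitude index 6 -> -kE2M1ToFloat[6] = -4.0
packed = torch.tensor([[0xE3]], dtype=torch.uint8)  # a (1, 1) byte tensor packing a (1, 2) row

unpacked = unpack_fp4_from_uint8(packed, m=1, n=2, dtype=torch.float32)
print(unpacked)  # tensor([[ 1.5000, -4.0000]])
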
auto_round/experimental/qmodules/mx.py

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Union

import torch

from auto_round.data_type.utils import get_quant_func
from auto_round.experimental.qmodules.base import QModuleBase
from auto_round.experimental.qmodules.fp4_utils import unpack_fp4_from_uint8
from auto_round.logger import logger
from auto_round.schemes import QuantizationScheme

__all__ = ["MXFP4QuantLinear", "MXFP8QuantLinear"]

SUPPORTED_HIGHER_DTYPE = [torch.bfloat16, torch.float16, torch.float32]
E8M0_EXPONENT_BIAS = 127


def _mx_qdq(tensor: torch.Tensor, config: QuantizationScheme):
    qdq_func, _ = get_quant_func(dtype=config.act_data_type, bits=config.act_bits, sym=True)
    qdq_tensor, shared_exp, _ = qdq_func(tensor=tensor, bits=config.act_bits, group_size=config.act_group_size)
    return qdq_tensor


# https://github.com/pytorch/ao/blob/994a4ba6c869854fcaa6ca7e118fcbd75e6c28cc/torchao/prototype/mx_formats/mx_tensor.py#L337
def get_fp_scale(scale_e8m0):
    scale_e8m0 = scale_e8m0.view(torch.uint8)
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
    # TODO(later): handle this for float16 if we decide to support float16
    s_fp = torch.pow(two, s_offset)

    return s_fp


class MXQuantLinearBase(QModuleBase):
    """
    Base class for quantized linear layers using MXFP quantization schemes.
    """

    def __init__(
        self,
        in_features,
        out_features,
        config: QuantizationScheme,
        weight: Optional[torch.Tensor] = None,
        weight_scale: Optional[torch.Tensor] = None,
        bias: Union[torch.Tensor, bool, None] = None,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = 32
        self.config = config
        self.dtype = dtype
        self.pre_dequantized = False

        # Validate dtype
        assert (
            dtype in SUPPORTED_HIGHER_DTYPE
        ), f"Expected dtype to be one of {SUPPORTED_HIGHER_DTYPE}, but got {dtype}."

        # Initialize weights
        init_weight = self.initialize_weights(weight)
        self.register_buffer(self.weight_name, init_weight)

        # Initialize bias
        if bias is not None:
            if isinstance(bias, bool):
                bias = torch.zeros((out_features,), dtype=dtype)
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.register_parameter("bias", None)

        # Initialize weight scale
        init_weight_scale = (
            torch.empty((out_features, in_features // self.group_size), dtype=torch.uint8)
            if weight_scale is None
            else weight_scale
        )
        self.register_buffer("weight_scale", init_weight_scale)

    def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Initialize weights. This method should be overridden by subclasses.
        """
        raise NotImplementedError("Subclasses must implement `initialize_weights`.")

    @classmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        logger.warning_once("MXFP quantization is still in experimental stage, the inference speed might be slow.")
        return 0

    def dequant_mx_tensor(
        self, packed_data: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        scale_float = self._get_float_scale(scale).to(target_dtype)
        unpacked_data = self.unpack_data(packed_data)
        original_shape = unpacked_data.shape
        unpacked_data = unpacked_data.reshape(-1, self.group_size)
        scale_float = scale_float.reshape(-1, 1)
        data_float = unpacked_data.to(target_dtype)
        data_dequant = data_float * scale_float
        data_dequant = data_dequant.reshape(original_shape)
        return data_dequant

    def dequant_weight_online(self):
        if self.pre_dequantized:
            return self.weight
        dq_weight = self.dequant_mx_tensor(self.weight, self.weight_scale)
        return dq_weight

    def pre_dequantize(self):
        if self.pre_dequantized:
            return
        dequant_weight = self.dequant_weight_online()
        delattr(self, self.weight_name)
        del self.weight_scale
        self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False)
        self.pre_dequantized = True

    def qdq_input(self, activation: torch.Tensor):
        return _mx_qdq(activation, self.config)

    @classmethod
    def _get_float_scale(cls, scale_e8m0: torch.Tensor) -> torch.Tensor:
        return get_fp_scale(scale_e8m0)

    @torch.inference_mode()
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qdq_input = self.qdq_input(input)
        qdq_weight = self.dequant_weight_online()
        qdq_weight = qdq_weight.to(qdq_input.dtype)
        out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
        return out

    @classmethod
    def from_original(cls, config: Optional[QuantizationScheme], original_layer: torch.nn.Linear):
        """
        Create an `MXQuantLinear` layer from an original linear layer.
        """
        logger.warning_once("MXFP quantization is still in experimental stage, the inference speed might be slow.")
        qdq_linear = cls(
            in_features=original_layer.in_features,
            out_features=original_layer.out_features,
            config=config,
            bias=original_layer.bias,
            dtype=original_layer.weight.dtype,
        )
        return qdq_linear


class MXFP4QuantLinear(MXQuantLinearBase):
    """
    Quantized linear layer using the MXFP4 quantization scheme.
    """

    def __init__(self, *args, **kwargs):
        self.weight_name = "weight_packed"
        super().__init__(*args, **kwargs)

    def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
        weight_dtype = torch.uint8
        weight_in_features = self.in_features // 2
        return torch.zeros((self.out_features, weight_in_features), dtype=weight_dtype) if weight is None else weight

    def dequant_weight_online(self) -> torch.Tensor:
        if self.pre_dequantized:
            return self.weight
        dq_weight = self.dequant_mx_tensor(self.weight_packed, self.weight_scale)
        return dq_weight

    def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor:
        m, half_n = packed_data.shape
        unpacked_data = unpack_fp4_from_uint8(packed_data, m, half_n * 2, dtype=self.dtype)
        return unpacked_data


class MXFP8QuantLinear(MXQuantLinearBase):
    """
    Quantized linear layer using the MXFP8 quantization scheme.
    """

    def __init__(self, *args, **kwargs):
        self.weight_name = "weight"
        super().__init__(*args, **kwargs)

    def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
        weight_dtype = torch.float8_e4m3fn
        weight_in_features = self.in_features
        return torch.zeros((self.out_features, weight_in_features), dtype=weight_dtype) if weight is None else weight

    def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor:
        return packed_data.to(self.dtype)
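
For reference, the online dequantization in MXQuantLinearBase.dequant_mx_tensor boils down to: decode each uint8 scale as a power of two, 2 ** (byte - 127), and multiply every group of 32 unpacked weight values by its group scale. Below is a minimal, self-contained sketch of that arithmetic for the MXFP8 case, using random data rather than a real checkpoint; tensor names mirror the buffers registered above.

import torch

E8M0_EXPONENT_BIAS = 127
GROUP_SIZE = 32

out_features, in_features = 4, 64

# MXFP8 storage: an FP8 weight plus one uint8 E8M0 scale per group of 32 input features.
weight = torch.randn(out_features, in_features).to(torch.float8_e4m3fn)
weight_scale = torch.randint(118, 136, (out_features, in_features // GROUP_SIZE), dtype=torch.uint8)

# E8M0 decode: the byte is a biased exponent, so the float scale is 2 ** (byte - 127).
scale_fp = 2.0 ** (weight_scale.to(torch.float32) - E8M0_EXPONENT_BIAS)

# Same grouping as dequant_mx_tensor: reshape into rows of 32 values,
# multiply by the per-group scale, then restore the original shape.
dequant = (weight.float().reshape(-1, GROUP_SIZE) * scale_fp.reshape(-1, 1)).reshape(out_features, in_features)
print(dequant.shape)  # torch.Size([4, 64])

For MXFP4 the only extra step is expanding weight_packed with unpack_fp4_from_uint8 before the same per-group multiply.
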

auto_round/export/export_to_autoround/export.py

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,8 @@ class AutoRoundFormat(str, Enum):
     # Weight: FP8, per-channel, may be extended to per-tensor in future
     # Activation: FP8, per-tensor
     FP8_STATIC = "fp8_static"
+    MXFP8 = "mxfp8"
+    MXFP4 = "mxfp4"
     FP8 = "fp8"

auto_round/export/export_to_autoround/qlinear_fp.py

Lines changed: 6 additions & 2 deletions
@@ -90,15 +90,17 @@ def __init__(
         self.act_bits = kwargs.get("act_bits", None)

         weight_name = "weight" if self.bits == 8 and self.is_mx else "weight_packed"
+        weight_infeatures = infeatures if self.bits == 8 else infeatures // 2
+        weight_dtype = torch.float8_e4m3fn if self.bits == 8 else torch.uint8
         ## TODO check the dtype of weight_packed and weight_scale
         self.register_buffer(
             weight_name,
-            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
+            torch.zeros((outfeatures, weight_infeatures), dtype=weight_dtype),
         )
         self.register_buffer(
             "weight_scale",
             torch.zeros(
-                (math.ceil(infeatures / self.group_size), outfeatures),
+                (outfeatures, math.ceil(infeatures / self.group_size)),
                 dtype=torch.float16, ## TODO update to correct scale dtype for different bits
             ),
         )
@@ -156,11 +158,13 @@ def pack(self, linear, scales, zeros=None, g_idx=None, global_scale=None, input_
             final_scale = (scales + E8M0_EXPONENT_BIAS).clamp(0, E8M0_EXPONENT_NAN_VAL).to(torch.uint8)
         else:
             final_scale = scales.to(torch.float8_e4m3fn)
+
         self.weight_scale = final_scale
         # self.weight = get_compressed_weight(scaled_tensor, self.bits, self.data_type) ## TODO
         if self.bits == 8:
             compress_dtype = torch.float8_e4m3fn
             self.weight = scaled_tensor.to(compress_dtype)
+
         else:
             compress_dtype = torch.uint8
             self.weight_packed = pack_fp4_to_uint8(scaled_tensor)