# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch

from auto_round.data_type.utils import get_quant_func
from auto_round.experimental.qmodules.base import QModuleBase
from auto_round.experimental.qmodules.fp4_utils import unpack_fp4_from_uint8
from auto_round.logger import logger
from auto_round.schemes import QuantizationScheme

__all__ = ["MXFP4QuantLinear", "MXFP8QuantLinear"]

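# Higher-precision compute dtypes that de-quantized weights/activations may be cast to,
# and the exponent bias of the E8M0 block-scale format (a stored byte `e` decodes to 2.0 ** (e - 127)).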
SUPPORTED_HIGHER_DTYPE = [torch.bfloat16, torch.float16, torch.float32]
E8M0_EXPONENT_BIAS = 127


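# Fake-quantize (quantize then de-quantize) a tensor with the activation settings of the given
# quantization scheme; used to emulate MX activation quantization at inference time.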
def _mx_qdq(tensor: torch.Tensor, config: QuantizationScheme):
    qdq_func, _ = get_quant_func(dtype=config.act_data_type, bits=config.act_bits, sym=True)
    qdq_tensor, shared_exp, _ = qdq_func(tensor=tensor, bits=config.act_bits, group_size=config.act_group_size)
    return qdq_tensor


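# Decode E8M0-encoded block scales (biased 8-bit exponents) into floating-point scale factors.
# Adapted from torchao, see the link below.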
# https://github.com/pytorch/ao/blob/994a4ba6c869854fcaa6ca7e118fcbd75e6c28cc/torchao/prototype/mx_formats/mx_tensor.py#L337
def get_fp_scale(scale_e8m0):
    scale_e8m0 = scale_e8m0.view(torch.uint8)
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
    # TODO(later): handle this for float16 if we decide to support float16
    s_fp = torch.pow(two, s_offset)

    return s_fp


class MXQuantLinearBase(QModuleBase):
    """
    Base class for quantized linear layers using MXFP quantization schemes.
    """

    def __init__(
        self,
        in_features,
        out_features,
        config: QuantizationScheme,
        weight: Optional[torch.Tensor] = None,
        weight_scale: Optional[torch.Tensor] = None,
        bias: Union[torch.Tensor, bool, None] = None,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = 32
        self.config = config
        self.dtype = dtype
        self.pre_dequantized = False

        # Validate dtype
        assert (
            dtype in SUPPORTED_HIGHER_DTYPE
        ), f"Expected dtype to be one of {SUPPORTED_HIGHER_DTYPE}, but got {dtype}."

        # Initialize weights
        init_weight = self.initialize_weights(weight)
        self.register_buffer(self.weight_name, init_weight)

        # Initialize bias (`True` means create a zero bias, `False`/`None` means no bias)
        if isinstance(bias, bool):
            bias = torch.zeros((out_features,), dtype=dtype) if bias else None
        if bias is not None:
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.register_parameter("bias", None)

        # Initialize weight scale
        init_weight_scale = (
            torch.empty((out_features, in_features // self.group_size), dtype=torch.uint8)
            if weight_scale is None
            else weight_scale
        )
        self.register_buffer("weight_scale", init_weight_scale)

    def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Initialize weights. This method should be overridden by subclasses.
        """
        raise NotImplementedError("Subclasses must implement `initialize_weights`.")

    @classmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        logger.warning_once("MXFP quantization is still in experimental stage, the inference speed might be slow.")
        return 0

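    # De-quantize an MX-format tensor: decode the E8M0 block scales, unpack the packed values,
    # and multiply each group of `group_size` (32) elements by its shared scale.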
    def dequant_mx_tensor(
        self, packed_data: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        scale_float = self._get_float_scale(scale).to(target_dtype)
        unpacked_data = self.unpack_data(packed_data)
        original_shape = unpacked_data.shape
        unpacked_data = unpacked_data.reshape(-1, self.group_size)
        scale_float = scale_float.reshape(-1, 1)
        data_float = unpacked_data.to(target_dtype)
        data_dequant = data_float * scale_float
        data_dequant = data_dequant.reshape(original_shape)
        return data_dequant

    def dequant_weight_online(self):
        if self.pre_dequantized:
            return self.weight
        dq_weight = self.dequant_mx_tensor(self.weight, self.weight_scale)
        return dq_weight

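    # Materialize the de-quantized weight once as a regular parameter and drop the packed
    # weight/scale buffers, trading extra memory for faster subsequent forward passes.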
    def pre_dequantize(self):
        if self.pre_dequantized:
            return
        dequant_weight = self.dequant_weight_online()
        delattr(self, self.weight_name)
        del self.weight_scale
        self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False)
        self.pre_dequantized = True

    def qdq_input(self, activation: torch.Tensor):
        return _mx_qdq(activation, self.config)

    @classmethod
    def _get_float_scale(cls, scale_e8m0: torch.Tensor) -> torch.Tensor:
        return get_fp_scale(scale_e8m0)

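    # Reference QDQ forward path: fake-quantize the activation, de-quantize the weight on the fly,
    # then run a standard high-precision linear (no fused low-precision kernel is used).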
    @torch.inference_mode()
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qdq_input = self.qdq_input(input)
        qdq_weight = self.dequant_weight_online()
        qdq_weight = qdq_weight.to(qdq_input.dtype)
        out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
        return out

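    # Note: this only copies shapes, bias, and dtype from the original layer; the quantized
    # weight and scale buffers keep their initial values until populated (e.g. via load_state_dict).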
    @classmethod
    def from_original(cls, config: Optional[QuantizationScheme], original_layer: torch.nn.Linear):
        """
        Create a quantized linear layer (`cls`) from an original `torch.nn.Linear` layer.
        """
        logger.warning_once("MXFP quantization is still in experimental stage, the inference speed might be slow.")
        qdq_linear = cls(
            in_features=original_layer.in_features,
            out_features=original_layer.out_features,
            config=config,
            bias=original_layer.bias,
            dtype=original_layer.weight.dtype,
        )
        return qdq_linear


class MXFP4QuantLinear(MXQuantLinearBase):
    """
    Quantized linear layer using the MXFP4 quantization scheme.
    """

    def __init__(self, *args, **kwargs):
        self.weight_name = "weight_packed"
        super().__init__(*args, **kwargs)

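    # FP4 (E2M1) values are stored two per uint8, so the packed weight has in_features // 2 columns.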
    def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
        weight_dtype = torch.uint8
        weight_in_features = self.in_features // 2
        return torch.zeros((self.out_features, weight_in_features), dtype=weight_dtype) if weight is None else weight

    def dequant_weight_online(self) -> torch.Tensor:
        if self.pre_dequantized:
            return self.weight
        dq_weight = self.dequant_mx_tensor(self.weight_packed, self.weight_scale)
        return dq_weight

    def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor:
        m, half_n = packed_data.shape
        unpacked_data = unpack_fp4_from_uint8(packed_data, m, half_n * 2, dtype=self.dtype)
        return unpacked_data


class MXFP8QuantLinear(MXQuantLinearBase):
    """
    Quantized linear layer using the MXFP8 quantization scheme.
    """

    def __init__(self, *args, **kwargs):
        self.weight_name = "weight"
        super().__init__(*args, **kwargs)

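    # MXFP8 weights are stored directly as float8_e4m3fn (one value per element), so no bit packing
    # is needed and "unpacking" is just a cast to the compute dtype.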
    def initialize_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
        weight_dtype = torch.float8_e4m3fn
        weight_in_features = self.in_features
        return torch.zeros((self.out_features, weight_in_features), dtype=weight_dtype) if weight is None else weight

    def unpack_data(self, packed_data: torch.Tensor) -> torch.Tensor:
        return packed_data.to(self.dtype)
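

# Minimal usage sketch (illustrative, not part of the original source): assumes `scheme` is a
# QuantizationScheme whose act_data_type / act_bits / act_group_size describe an MX format.
#
#     fp_linear = torch.nn.Linear(64, 32, bias=True, dtype=torch.bfloat16)
#     q_linear = MXFP4QuantLinear.from_original(config=scheme, original_layer=fp_linear)
#     out = q_linear(torch.randn(4, 64, dtype=torch.bfloat16))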