# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Union

import torch

from auto_round.experimental.qmodules.base import QModuleBase
from auto_round.utils import logger

__all__ = ["WeightFP8ActFP8StaticQuantLinear"]


def _quant_tensor_to_fp8_with_scale(tensor: torch.Tensor, scale: torch.Tensor):
    """Quantize ``tensor`` to float8_e4m3fn using a pre-computed ``scale``."""
    FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max
    qtensor = tensor / scale
    clipped_qtensor = torch.clamp(qtensor, -FULL_RANGE, FULL_RANGE)
    clipped_qtensor_fp8 = clipped_qtensor.to(torch.float8_e4m3fn)
    return scale, clipped_qtensor_fp8


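# Illustrative sketch (not part of the module API): round-trip a tensor through the
# helper above. The max-abs, per-tensor scale used here is an assumption for
# demonstration; real scales come from calibration.
def _fp8_roundtrip_example(tensor: torch.Tensor) -> torch.Tensor:
    full_range = torch.finfo(torch.float8_e4m3fn).max
    # Map the tensor's largest magnitude onto the fp8 (e4m3) maximum value.
    scale = tensor.abs().max() / full_range
    scale, tensor_fp8 = _quant_tensor_to_fp8_with_scale(tensor, scale)
    # Dequantize back; the result matches the input up to fp8 rounding error.
    return tensor_fp8.to(tensor.dtype) * scale

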
class WeightFP8ActFP8StaticQuantLinear(QModuleBase):
    hp_dtype = torch.bfloat16
    fp8_dtype = torch.float8_e4m3fn

    def __init__(
        self,
        in_features,
        out_features,
        weight: Optional[torch.Tensor] = None,
        weight_scale: Optional[torch.Tensor] = None,
        bias: Union[torch.Tensor, bool, None] = None,
        input_scale: Optional[torch.Tensor] = None,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        init_weight = torch.zeros((out_features, in_features), dtype=dtype) if weight is None else weight
        self.weight = torch.nn.Parameter(init_weight, requires_grad=False)
        self.dtype = dtype
        if bias is not None:
            if isinstance(bias, bool):
                bias = torch.zeros((out_features,), dtype=dtype)
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.register_parameter("bias", None)
        # Per-output-channel weight scale, broadcast over in_features during dequantization.
        init_weight_scale = torch.empty((out_features,), dtype=dtype) if weight_scale is None else weight_scale
        self.register_buffer("weight_scale", init_weight_scale.to(dtype))

        # Per-tensor static scale for the activations.
        init_input_scale = torch.zeros((1,), dtype=dtype) if input_scale is None else input_scale
        self.register_buffer("input_scale", init_input_scale.to(dtype))
        self.pre_dequantized = False

    @classmethod
    def get_min_capability(cls) -> int:
        """
        Get the minimum device capability.
        """
        # TODO: correct this value once fp8 op support is added.
        logger.warning_once("FP8 ops are not yet supported. Using capability 0.")
        return 0

    def process_weights_after_loading(self, layer: torch.nn.Module):
        pass

    @classmethod
    def from_original(cls, config, original_layer):
        """
        Create a `WeightFP8ActFP8StaticQuantLinear` layer from an original linear layer.
        """
        logger.warning_once(
            "FP8 static quantization is still in the experimental stage; inference speed may be slow."
        )
        device = original_layer.weight.device
        with torch.device(device):
            qdq_linear = cls(
                in_features=original_layer.in_features,
                out_features=original_layer.out_features,
                bias=original_layer.bias,
            )
        return qdq_linear

    def dequant_weight_online(self):
        if self.pre_dequantized:
            return self.weight
        # Cast the fp8 weight back to the high-precision dtype and apply the
        # per-output-channel scale: (out, in) * (out, 1).
        qdq_weight = self.weight.to(self.dtype) * self.weight_scale.unsqueeze(1)
        return qdq_weight

    def pre_dequantize(self):
        if self.pre_dequantized:
            return
        dequant_weight = self.dequant_weight_online()
        del self.weight
        del self.weight_scale
        self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False)
        self.pre_dequantized = True

    def qdq_input(self, bf16_input: torch.Tensor):
        # Quantize the activation with the static input scale, then dequantize
        # back so the subsequent matmul runs in the high-precision dtype.
        input_scale, input_fp8 = _quant_tensor_to_fp8_with_scale(bf16_input, self.input_scale.data)
        qdq_input_bf16 = input_fp8.to(self.dtype) * input_scale
        return qdq_input_bf16

    @torch.no_grad()
    def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
        qdq_input = self.qdq_input(bf16_input)
        qdq_weight = self.dequant_weight_online()
        out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
        return out
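

if __name__ == "__main__":
    # Minimal usage sketch, assuming a torch build with float8_e4m3fn support.
    # The scales below are hand-picked for illustration; in practice they are
    # produced by calibration and loaded from a quantized checkpoint.
    in_features, out_features = 16, 8
    bf16_weight = torch.randn(out_features, in_features, dtype=torch.bfloat16)
    # Per-output-channel weight scale: map each row's max magnitude to the fp8 max.
    weight_scale = bf16_weight.abs().amax(dim=1) / torch.finfo(torch.float8_e4m3fn).max
    _, fp8_weight = _quant_tensor_to_fp8_with_scale(bf16_weight, weight_scale.unsqueeze(1))
    layer = WeightFP8ActFP8StaticQuantLinear(
        in_features=in_features,
        out_features=out_features,
        weight=fp8_weight,
        weight_scale=weight_scale,
        bias=True,
        input_scale=torch.tensor([0.05], dtype=torch.bfloat16),
    )
    x = torch.randn(2, in_features, dtype=torch.bfloat16)
    print(layer(x).shape)  # torch.Size([2, 8])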