
Commit 09e4d31

yiliu30, Copilot, and pre-commit-ci[bot] authored
Support loading for static quant weight fp8 act fp8 (#730)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1452d95 commit 09e4d31

10 files changed (+393, -66 lines)

auto_round/autoround.py

Lines changed: 6 additions & 4 deletions
@@ -19,6 +19,7 @@
 import time
 import traceback
 from dataclasses import asdict, fields
+from enum import Enum
 from typing import Any, Callable, Union

 import accelerate
@@ -30,6 +31,7 @@

 from auto_round.data_type import QUANT_FUNC_WITH_DTYPE
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size
+from auto_round.export.export_to_autoround import AutoRoundFormat
 from auto_round.export.export_to_gguf.config import GGUF_CONFIG, GGUF_INNER_CONFIG, ModelType
 from auto_round.low_cpu_mem.utils import get_layers_before_block
 from auto_round.schemes import QuantizationScheme, preset_name_to_scheme
@@ -857,8 +859,8 @@ def remove_duplicates(lst):
             format = "auto_round:auto_awq"
         elif is_nv_fp(self.data_type) or is_mx_fp(self.data_type):
             format = f"auto_round:{self.data_type}"
-        elif is_wfp8afp8(self): # staic wfp8afp8
-            format = "auto_round:fp8"
+        elif is_static_wfp8afp8(self): # staic wfp8afp8
+            format = f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"
         elif self.data_type == "fp" and self.bits == 8 and self.act_bits >= 16: # woq fp8
             format = "auto_round:fp8"
         elif self.act_bits < 16:
@@ -956,10 +958,10 @@ def _check_supported_format(self, format: str) -> bool:
                 )
                 format = "fake"
             else:
-                if not (format == "auto_round" or format == "auto_round:fp8"):
+                if not (format == "auto_round" or format == f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"):
                     logger.warning(
                         f"Currently only support to export auto_round or fake format for static W{self.bits}AFP8 model,"
-                        " change format to auto_round"
+                        f" change format {format} to auto_round"
                     )
                     format = "auto_round"
         if self.act_group_size != 0 and not self.act_dynamic and format == "auto_round:fp8":
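
A note on the branch added above, as a minimal sketch (not code from the repository): statically quantized FP8-weight/FP8-activation models now resolve to their own "auto_round:fp8_static" tag via AutoRoundFormat.TORCH_FP8_STATIC, while weight-only FP8 keeps "auto_round:fp8". The pick_format helper and its two boolean flags are illustrative stand-ins for is_static_wfp8afp8 and the weight-only FP8 condition.

from enum import Enum


class AutoRoundFormat(str, Enum):
    # Mirrors the enum added in export_to_autoround/export.py later in this commit.
    TORCH_FP8_STATIC = "fp8_static"


def pick_format(static_wfp8afp8: bool, weight_only_fp8: bool) -> str:
    # Simplified stand-in for the elif chain above: static W-FP8/A-FP8 now
    # resolves to its own tag instead of the generic "auto_round:fp8".
    if static_wfp8afp8:
        return f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"
    if weight_only_fp8:
        return "auto_round:fp8"
    return "auto_round"


assert pick_format(static_wfp8afp8=True, weight_only_fp8=False) == "auto_round:fp8_static"
assert pick_format(static_wfp8afp8=False, weight_only_fp8=True) == "auto_round:fp8"
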
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Optional, Union
+
+import torch
+
+__all__ = ["QModuleBase"]
+
+
+class QModuleBase(torch.nn.Module):
+    """
+    Base class used to describe the weight creation and forward pass
+    of different quantization schemes supported by Auto-Round.
+    The design is inspired by vLLM's CompressedTensorsScheme:
+    https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
+
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @classmethod
+    @abstractmethod
+    def from_original(cls, config, original_layer: torch.nn.Module):
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def get_min_capability(cls) -> int:
+        """
+        Get minimum device capability.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        """
+        Called after weight loading is complete for any cleanup that
+        needs to occur.
+        """
+        raise NotImplementedError
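
To make the contract of the new base class concrete, here is a minimal, hypothetical subclass (a sketch, not part of the commit). It implements the three abstract hooks from_original, get_min_capability, and process_weights_after_loading, which is the whole interface a quantized linear scheme needs to plug into Auto-Round; the import path follows the auto_round.experimental.qmodules.base module referenced later in this commit.

import torch

from auto_round.experimental.qmodules.base import QModuleBase


class PassthroughQuantLinear(QModuleBase):
    """Toy scheme that stores the weight unmodified; only the hooks matter here."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features), requires_grad=False)

    @classmethod
    def from_original(cls, config, original_layer: torch.nn.Module):
        # Build the quantized module from a plain nn.Linear.
        layer = cls(original_layer.in_features, original_layer.out_features)
        layer.weight.data.copy_(original_layer.weight.data)
        return layer

    @classmethod
    def get_min_capability(cls) -> int:
        return 0  # this toy scheme has no special device requirement

    def process_weights_after_loading(self, layer: torch.nn.Module):
        pass  # nothing to clean up after checkpoint loading

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x, self.weight)
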
Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Union
+
+import torch
+
+from auto_round.experimental.qmodules.base import QModuleBase
+from auto_round.utils import logger
+
+__all__ = ["WeightFP8ActFP8StaticQuantLinear"]
+
+
+def _quant_tensor_to_fp8_with_scale(tensor: torch.Tensor, scale: torch.Tensor):
+    FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max
+    qtensor = tensor / scale
+    clipped_qtensor = torch.clamp(qtensor, -FULL_RANGE, FULL_RANGE)
+    clipped_qtensor_fp8 = clipped_qtensor.to(torch.float8_e4m3fn)
+    return scale, clipped_qtensor_fp8
+
+
+class WeightFP8ActFP8StaticQuantLinear(QModuleBase):
+    hp_dtype = torch.bfloat16
+    fp8_dtype = torch.float8_e4m3fn
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        weight: Optional[torch.Tensor] = None,
+        weight_scale: Optional[torch.Tensor] = None,
+        bias: Union[torch.Tensor, bool, None] = None,
+        input_scale: Optional[torch.Tensor] = None,
+        dtype=torch.bfloat16,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        init_weight = torch.zeros((out_features, in_features), dtype=dtype) if weight is None else weight
+        self.weight = torch.nn.Parameter(init_weight, requires_grad=False)
+        self.dtype = dtype
+        if bias is not None:
+            if isinstance(bias, bool):
+                bias = torch.zeros((out_features,), dtype=dtype)
+            self.bias = torch.nn.Parameter(bias, requires_grad=False)
+        else:
+            self.register_parameter("bias", None)
+        init_weight_scale = torch.empty((out_features), dtype=dtype) if weight_scale is None else weight_scale
+        self.register_buffer("weight_scale", init_weight_scale.to(dtype))
+
+        init_input_scale = torch.zeros((1), dtype=dtype) if input_scale is None else input_scale
+        self.register_buffer("input_scale", init_input_scale.to(dtype))
+        self.pre_dequantized = False
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        """
+        Get minimum device capability.
+        """
+        # TODO: correct that config once we add fp8 op support.
+        logger.warning_once("FP8 ops are not yet supported. Using capability 0.")
+        return 0
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        pass
+
+    @classmethod
+    def from_original(cls, config, original_layer):
+        """
+        Create an `WeightFP8ActFP8StaticQuantLinear` layer from an original linear layer.
+        """
+        logger.warning_once(
+            "FP8 static quantization is still in experimental stage, the inference speed might be slow."
+        )
+        device = original_layer.weight.device
+        with torch.device(device):
+            qdq_linear = cls(
+                in_features=original_layer.in_features,
+                out_features=original_layer.out_features,
+                bias=original_layer.bias,
+            )
+        return qdq_linear
+
+    def dequant_weight_online(self):
+        if self.pre_dequantized:
+            return self.weight
+        qdq_weight = self.weight.to(self.dtype) * self.weight_scale.unsqueeze(1)
+        return qdq_weight
+
+    def pre_dequantize(self):
+        if self.pre_dequantized:
+            return
+        dequant_weight = self.dequant_weight_online()
+        del self.weight
+        del self.weight_scale
+        self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False)
+        self.pre_dequantized = True
+
+    def qdq_input(self, bf16_input: torch.Tensor):
+        input_scale, input_fp8 = _quant_tensor_to_fp8_with_scale(bf16_input, self.input_scale.data)
+        qdq_input_bf16 = input_fp8.to(self.dtype) * input_scale
+        return qdq_input_bf16
+
+    @torch.no_grad()
+    def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
+
+        qdq_input = self.qdq_input(bf16_input)
+        qdq_weight = self.dequant_weight_online()
+        out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
+        return out
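
A rough usage sketch of the new layer (an example, not from the commit; the module path in the import is hypothetical because the new file's name is not shown in this view). The layer is built from an FP8 weight, a per-channel weight scale, and a static per-tensor input scale, then called on a bf16 input; the forward pass QDQs the activation, dequantizes the weight, and falls through to torch.nn.functional.linear.

import torch

# Hypothetical import path -- the new file's location is not shown in this diff view.
from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear

out_features, in_features = 8, 16
fp8_weight = torch.randn(out_features, in_features).to(torch.float8_e4m3fn)
weight_scale = torch.full((out_features,), 0.02, dtype=torch.bfloat16)  # per-channel weight scale
input_scale = torch.tensor([0.05], dtype=torch.bfloat16)                # static per-tensor activation scale

layer = WeightFP8ActFP8StaticQuantLinear(
    in_features=in_features,
    out_features=out_features,
    weight=fp8_weight,
    weight_scale=weight_scale,
    bias=True,  # a bool creates a zero-initialized bias parameter
    input_scale=input_scale,
)

x = torch.randn(2, in_features, dtype=torch.bfloat16)
y = layer(x)        # QDQ the input, dequantize the weight, then F.linear
print(y.shape)      # torch.Size([2, 8])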

auto_round/export/export_to_autoround/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .export import save_quantized_as_autoround
+from .export import save_quantized_as_autoround, AutoRoundFormat

auto_round/export/export_to_autoround/export.py

Lines changed: 14 additions & 2 deletions
@@ -18,6 +18,7 @@
 import json
 import os
 from concurrent.futures import ThreadPoolExecutor
+from enum import Enum

 import threadpoolctl as tctl
 import torch
@@ -43,6 +44,12 @@
 )


+class AutoRoundFormat(str, Enum):
+    # Weight: FP8, per-channel, may be extended to per-tensor in future
+    # Activation: FP8, per-tensor
+    TORCH_FP8_STATIC = "fp8_static"
+
+
 def dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits=16):
     """
     Dynamically imports and returns the appropriate QuantLinear class based on the specified backend and parameters.
@@ -152,7 +159,7 @@ def pack_layer(layer_name, model, backend, device=None):

         return pack_layer(layer_name, model, backend, device)

-    if backend == "auto_round:fp8":
+    if backend == "auto_round:fp8" or backend == f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}":
         from auto_round.export.export_to_autoround.export_to_fp8 import pack_layer

         return pack_layer(layer_name, model, backend, device)
@@ -268,9 +275,14 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
         from auto_round.export.export_to_autoround.export_to_fp8 import save_quantized_as_autoround

         return save_quantized_as_autoround(output_dir, inplace=inplace, backend="auto_round", **kwargs)
+    from auto_round.autoround import AutoRoundFormat

     ##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source
-    if (kwargs.get("sym") is None or kwargs.get("sym")) and ("gptq" not in backend and "awq" not in backend):
+    if (
+        (kwargs.get("sym") is None or kwargs.get("sym"))
+        and ("gptq" not in backend and "awq" not in backend)
+        and (AutoRoundFormat.TORCH_FP8_STATIC.value not in backend)
+    ):
         backend = backend.replace("auto_round", "auto_round:auto_gptq")

     model = kwargs["model"]
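
A short sketch of the second change in this hunk (illustrative only, with the enum inlined rather than imported): because AutoRoundFormat subclasses str, its value can be substring-checked against the backend string, so the automatic switch to the auto_round:auto_gptq kernel is now skipped for fp8_static backends.

from enum import Enum
from typing import Optional


class AutoRoundFormat(str, Enum):
    TORCH_FP8_STATIC = "fp8_static"


def resolve_backend(backend: str, sym: Optional[bool] = None) -> str:
    # Mirrors the guarded replacement in save_quantized_as_autoround.
    if (
        (sym is None or sym)
        and ("gptq" not in backend and "awq" not in backend)
        and (AutoRoundFormat.TORCH_FP8_STATIC.value not in backend)
    ):
        backend = backend.replace("auto_round", "auto_round:auto_gptq")
    return backend


print(resolve_backend("auto_round"))             # auto_round:auto_gptq
print(resolve_backend("auto_round:fp8_static"))  # left unchanged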
