2024-05-30 Add FP8 PTQ #1877

Open · wants to merge 8 commits into base: develop
54 changes: 52 additions & 2 deletions docs/zh_cn/tutorials/quant/post_training_quantization.md
@@ -15,12 +15,13 @@

### 1. Quantization configuration concepts and interfaces

`Observer`: collects statistics on an OP's inputs or outputs and computes the quantization statistics derived from them, such as scale and zero_point. Each post-training quantization algorithm corresponds to one Observer. The `quant_bits` attribute of an Observer selects the quantized data type: `quant_bits = 8` means INT8 quantization, while `quant_bits = (4, 3)` means FP8 quantization. The available Observers are:
- `AVGObserver`: collects the mean value of the target Tensor as the quantization scale
- `MSEObserver`: collects the maximum absolute value and derives the quantization scale by minimizing the MSE error
- `EMDObserver`: collects the maximum absolute value and derives the quantization scale by minimizing the EMD error
- `HistObserver`: collects tensor values into a histogram and computes the quantization scale from a percentile
- `KLObserver`: computes the quantization scale by minimizing the Kullback-Leibler divergence between the distribution of floating-point values and the distribution of quantized values
- `AbsmaxObserver`: collects the maximum absolute value over the target Tensor as the quantization scale
- `AbsMaxChannelWiseWeightObserver`: collects the maximum absolute value per channel of the target weight as the quantization scales
- `MSEChannelWiseWeightObserver`: collects the maximum absolute value per channel of the target weight and derives the quantization scales by minimizing the MSE error
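To make the per-tensor vs. per-channel distinction concrete, here is a pure-numpy sketch of what the absmax-style observers compute; the helper names are illustrative and not part of the PaddleSlim API (`qmax=127` assumes signed INT8):

```python
import numpy as np

def absmax_scale(x, qmax=127.0):
    # Per-tensor scale: one value from the global maximum absolute value,
    # in the spirit of AbsmaxObserver for activations.
    return float(np.abs(x).max() / qmax)

def absmax_channelwise_scale(w, axis=0, qmax=127.0):
    # Per-channel scales: one value per slice along `axis`, in the spirit
    # of AbsMaxChannelWiseWeightObserver for weights.
    reduce_axes = tuple(i for i in range(w.ndim) if i != axis)
    return np.abs(w).max(axis=reduce_axes) / qmax

w = np.array([[1.0, -2.0], [0.5, 4.0]])
print(absmax_scale(w))                      # one scale for the whole tensor
print(absmax_channelwise_scale(w, axis=0))  # one scale per row
```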

@@ -44,7 +45,7 @@
| convert | `model`: the quantized model to be converted <br> `inplace`: when inplace=True, the model is quantized in place; when inplace=False, the original model is left unchanged and a quantized model is returned | Converts the model into ONNX form; only after this step can the quantized model be evaluated and exported as a static graph


## INT8 Quantization Usage Example
```python
import paddle
import paddleslim
@@ -91,3 +92,52 @@ for step, data in enumerate(dataloader):
# convert to quant model that can evaluate and export
model = ptq.convert(model, inplace=True)
```


## FP8 Quantization Usage Example
```python
import paddle
import paddleslim
from paddle.vision.models import mobilenet_v1
from paddle.quantization import QuantConfig
from paddle.quantization import PTQ
from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver, AbsmaxObserver, MSEChannelWiseWeightObserver, AbsMaxChannelWiseWeightObserver
# ColumnParallelLinear/RowParallelLinear are only needed when quantizing
# tensor-parallel models
from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear

# create the model
model = mobilenet_v1()

# define QuantConfig
q_config = QuantConfig(activation=None, weight=None)

# define act_quanter and weight_quanter
act_quanter = AbsmaxObserver(quant_bits=(4,3))
weight_quanter = AbsMaxChannelWiseWeightObserver(quant_bits=(4,3))

# map ColumnParallelLinear to QuantizedColumnParallelLinear
# (QuantizedColumnParallelLinear/QuantizedRowParallelLinear are assumed to be
# importable from your slim toolkit; the exact module path depends on version)
q_config.add_qat_layer_mapping(ColumnParallelLinear,
                               QuantizedColumnParallelLinear)
# map RowParallelLinear to QuantizedRowParallelLinear
q_config.add_qat_layer_mapping(RowParallelLinear,
                               QuantizedRowParallelLinear)
# make every layer whose type is paddle.nn.Linear, ColumnParallelLinear
# or RowParallelLinear quantizable
q_config.add_type_config(
    [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear],
    activation=act_quanter,
    weight=weight_quanter,
)


ptq = PTQ(q_config)
model = ptq.quantize(model, inplace=True)

# ptq sample
ptq_step = 100
for step, data in enumerate(dataloader):
pred = model(data)
if step == ptq_step:
break

# convert to quant model that can evaluate and export
model = ptq.convert(model, inplace=True)
```
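The `(4, 3)` and `(5, 2)` settings correspond to the float8_e4m3 and float8_e5m2 formats; their largest finite values (448 and 57344, which `UniformObserver` uses as `qmax`) can be derived from the exponent/mantissa split. The following helper is an illustrative sketch, not part of the API:

```python
def fp8_max(exp_bits, man_bits):
    """Largest finite value of an FP8 format with the given bit split."""
    bias = 2 ** (exp_bits - 1) - 1
    if (exp_bits, man_bits) == (4, 3):
        # OCP e4m3: only exponent=1111 with mantissa=111 encodes NaN, so the
        # top exponent is still usable but loses its highest mantissa step.
        top_exp = (2 ** exp_bits - 1) - bias    # 8
        top_man = 2.0 - 2.0 ** -(man_bits - 1)  # 1.75
    else:
        # IEEE-style e5m2: the all-ones exponent is reserved for inf/NaN.
        top_exp = (2 ** exp_bits - 2) - bias    # 15
        top_man = 2.0 - 2.0 ** -man_bits        # 1.75
    return top_man * 2.0 ** top_exp

print(fp8_max(4, 3))  # 448.0
print(fp8_max(5, 2))  # 57344.0
```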
7 changes: 5 additions & 2 deletions paddleslim/quant/observers/base_hist.py
@@ -68,7 +68,10 @@ def forward(self, inputs):
self._zero_point = None
self._min = None
self._max = None

"""" Cast inputs to 'float32' for numpy compatibility in _init_hists function, avoiding issues with types like bf16.
"""
dtype = inputs.dtype
inputs = inputs.cast('float32')
if self._hist_min is None or self._hist_max is None:
self._hist_min, self._hist_max = self._min_max(inputs)
self._hist = self._init_hists(inputs)
@@ -82,7 +85,7 @@ def forward(self, inputs):
self._upsample_bin_count, )
self._hist_min, self._hist_max = new_min, new_max
self._hist = new_hist
return inputs
return inputs.cast(dtype)

def _update_min_max_and_hist(self, tensor, origin_min, origin_max,
origin_hist, bins_count, upsample_bins_count):
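The cast-and-restore pattern introduced above can be mimicked outside Paddle. A numpy analogue (numpy has no bfloat16, so float16 stands in for the problematic dtype; the helper name is ours):

```python
import numpy as np

def histogram_float32(values, bins=32):
    # Remember the incoming dtype, do the numpy work in float32, and hand
    # the data back in its original dtype -- the same shape as the change
    # to BaseHistObserver.forward.
    orig_dtype = values.dtype
    as_f32 = values.astype(np.float32)
    hist, edges = np.histogram(as_f32, bins=bins)
    return hist, edges, as_f32.astype(orig_dtype)

x = np.array([0.1, 0.5, 0.9], dtype=np.float16)
hist, edges, restored = histogram_float32(x, bins=4)
print(restored.dtype)  # float16
```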
1 change: 0 additions & 1 deletion paddleslim/quant/observers/emd.py
@@ -67,7 +67,6 @@ def cal_min_max(self, inputs):
while s <= 1.0:
scale = s * abs_max_value
s += 0.02
bins = 2**(self._quant_bits - 1) - 1
quant_var = paddle.clip(
paddle.round(inputs / scale * self.qmax), -self.qmax - 1,
self.qmax)
2 changes: 2 additions & 0 deletions paddleslim/quant/observers/kl.py
@@ -127,6 +127,8 @@ def cal_kl_threshold(hist, bin_width, bits):
assert hist.ndim == 1
hist_bins = hist.shape[0]
starting_iter = int((hist_bins - 1) * 0.5)
if isinstance(bits, tuple):
    bits = bits[0] + bits[1]
quant_range = 2**(bits - 1) - 1

P_sum = np.sum(np.array(hist).ravel())
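The effect of this tuple handling can be checked in isolation; a minimal sketch (the function name is ours) of how an FP8 `(4, 3)` setting collapses to a 7-bit quantization range:

```python
def kl_quant_range(bits):
    # Mirrors the bits handling added to cal_kl_threshold: an FP8
    # (exponent, mantissa) tuple is collapsed to its total bit count so the
    # existing integer-style range formula can be reused.
    if isinstance(bits, tuple):
        bits = bits[0] + bits[1]
    return 2 ** (bits - 1) - 1

print(kl_quant_range(8))       # 127
print(kl_quant_range((4, 3)))  # 63
```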
26 changes: 19 additions & 7 deletions paddleslim/quant/observers/uniform.py
@@ -26,7 +26,7 @@ class UniformObserver(BaseObserver):
an integer value ensuring that zero is quantized without error.

Args:
quant_bits (int): The number of bits for quantization.
quant_bits (int or tuple): The number of bits for quantization; an int such as 8 for integer quantization, or an (exponent_bits, mantissa_bits) tuple such as (4, 3) for FP8.
sign (bool): Whether the quantized integer includes a sign.
symmetric (bool): Whether it is symmetric quantization.
In symmetric quantization, the range of floating point values is relaxed to be symmetric
@@ -56,12 +56,24 @@ def __init__(
def qmin_qmax(self):
""" Calculate the range of the quantized integer based on the specified
quant_bits, sign, and symmetric properties."""
if self._sign:
self._qmin = -2**(self.bit_length() - 1)
self._qmax = 2**(self.bit_length() - 1) - 1
else:
self._qmin = 0
self._qmax = 2**self.bit_length()
if isinstance(self._quant_bits, tuple):
    if self._quant_bits == (4, 3):
        # float8_e4m3: largest finite value is 448
        self._qmin = -448.0
        self._qmax = 448.0
    elif self._quant_bits == (5, 2):
        # float8_e5m2: largest finite value is 57344
        self._qmin = -57344.0
        self._qmax = 57344.0
    else:
        raise NotImplementedError(
            "Currently, only float8_e4m3 and float8_e5m2 formats are "
            "supported. Please set quant_bits to (4, 3) or (5, 2) for "
            "the corresponding format.")
else:
    if self._sign:
        self._qmin = -2**(self.bit_length() - 1)
        self._qmax = 2**(self.bit_length() - 1) - 1
    else:
        self._qmin = 0
        self._qmax = 2**self.bit_length()
return self._qmin, self._qmax

@abc.abstractmethod