diff --git a/llm/README.md b/llm/README.md index ff31eb345328..e0c4cb95d04e 100644 --- a/llm/README.md +++ b/llm/README.md @@ -224,13 +224,13 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo ```shell # PTQ 量化启动命令参考 -python run_finetune.py ./config/llama/ptq_argument.json +python run_finetune.py ./config/llama/ptq_argument.json # GPTQ 量化启动命令参考 -python run_finetune.py ./config/llama/ptq_argument.json +python run_finetune.py ./config/llama/ptq_argument.json # W8A8C8(INT)量化启动命令参考 -python run_finetune.py ./config/llama/ptq_c8_argument.json +python run_finetune.py ./config/llama/ptq_c8_argument.json # W8A8(FP8)量化启动命令参考 python run_finetune.py ./config/llama/fp8_ptq_argument.json diff --git a/llm/docs/quantization.md b/llm/docs/quantization.md index 6e45c8c781ac..5689e30c0b81 100644 --- a/llm/docs/quantization.md +++ b/llm/docs/quantization.md @@ -1,4 +1,4 @@ -# 大模型量化教程 +# 大模型量化教程 ## 1.算法介绍 @@ -111,8 +111,8 @@ python run_finetune.py ./config/llama/ceval_quant_argument.json - `use_fp8`: 是否使用 FP8 量化,默认为空字符串。输入`"WA"`(不区分大小写)则将权重和激活的8位量化转换为 FP8量化。 - `fp8_type`: FP8量化类型,长度应与`use_fp8`相同。默认为`["e4m3","e4m3"]`。 - `do_ptq`: 是否进行 PTQ 量化,默认为 False。 -- `weight_quant_method`: 权重量化方式,现可选 groupwise 或者 abs_max_channel_wise。 -- `act_quant_method`: 激活量化方式,现可选 avg 或者 abs_max。 +- `weight_quant_method`: 权重量化方式,INT8量化可选 groupwise 或者 abs_max_channel_wise,FP8量化可选 abs_max 或 avg。 +- `act_quant_method`: 激活量化方式,INT8量化可选 avg 或者 abs_max,FP8量化可选 abs_max 或 avg。 - `cachekv_quant_method`: kvcache 量化方式,现可选 abs_max_headwise, avg_headwise。 - `ptq_step`: PTQ 量化步数,也即模型前向次数,默认为32。 - `shift`: 是否在 PTQ 量化前进行[Shift 策略](https://arxiv.org/abs/2304.09145),默认为 False。使用 Shift 策略需要设`do_ptq`为 True。 diff --git a/llm/experimental/ceval/default/eval.py b/llm/experimental/ceval/default/eval.py index 5980ecbbf16b..386784af2d89 100644 --- a/llm/experimental/ceval/default/eval.py +++ b/llm/experimental/ceval/default/eval.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,7 +36,6 @@ def run_eval_one_time(args, evaluator, take): subject_list = [val_file.replace("_val.csv", "") for val_file in filenames] accuracy, summary = {}, {} - # run_date = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(time.time())) output_dir = args.output_dir save_result_dir = os.path.join(output_dir, f"take{take}") if not os.path.exists(save_result_dir): @@ -44,9 +43,6 @@ def run_eval_one_time(args, evaluator, take): all_answers = {} for index, subject_name in enumerate(subject_list): - # print( - # f"{index/len(subject_list)} Inference starts at {run_date} on {args.model_name_or_path} with subject of {subject_name}!" - # ) val_file_path = os.path.join(val_path, f"{subject_name}_val.csv") dev_file_path = os.path.join(dev_path, f"{subject_name}_dev.csv") test_file_path = os.path.join(test_path, f"{subject_name}_test.csv") @@ -54,7 +50,6 @@ def run_eval_one_time(args, evaluator, take): val_df = pd.read_csv(val_file_path) if args.do_test is False else pd.read_csv(test_file_path) dev_df = pd.read_csv(dev_file_path) if args.few_shot else None - # import pdb;pdb.set_trace() correct_ratio, answers = evaluator.eval_subject( subject_name, val_df, diff --git a/llm/experimental/ceval/default/evaluator.py b/llm/experimental/ceval/default/evaluator.py index 47eff428b9ba..db66e7e4db76 100644 --- a/llm/experimental/ceval/default/evaluator.py +++ b/llm/experimental/ceval/default/evaluator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/llm/experimental/ceval/default/model_evaluator.py b/llm/experimental/ceval/default/model_evaluator.py index 3f1243b7a10a..4b0dd0d3d5c4 100644 --- a/llm/experimental/ceval/default/model_evaluator.py +++ b/llm/experimental/ceval/default/model_evaluator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/llm/experimental/layers/cache_kv.py b/llm/experimental/layers/cache_kv.py index 06200f1591b7..e159ae8f5096 100644 --- a/llm/experimental/layers/cache_kv.py +++ b/llm/experimental/layers/cache_kv.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -166,13 +166,7 @@ def forward( def _smooth(self, x, y, use_smooth_x): # For ShiftSmooth - # smooth_shape = y.shape[2:] self.dtype = y.dtype - # if not hasattr(self, "smooth_weight"): - # self.smooth_weight = self.create_parameter( - # shape=smooth_shape, - # attr=ParamAttr(initializer=Constant(value=1.)), - # dtype=self.dtype) smooth_y = y smooth_y = paddle.divide(smooth_y, self.smooth_weight) diff --git a/llm/experimental/layers/custom_attention.py b/llm/experimental/layers/custom_attention.py index 4c7a12016fbb..c40c815b3346 100644 --- a/llm/experimental/layers/custom_attention.py +++ b/llm/experimental/layers/custom_attention.py @@ -58,7 +58,6 @@ def forward( **kwargs ): """forward""" - # import pdb;pdb.set_trace() if self.enable_fake_quant: self.collect_kv_quant_policy(q, k, v, **kwargs) perm = [0, 2, 1, 3] # [1, 2, 0, 3] if self.sequence_parallel else [0, 2, 1, 3] diff --git a/llm/experimental/observer/abs_max_headwise.py b/llm/experimental/observer/abs_max_headwise.py index ce76e14b1d76..500fbfa1ff55 100644 --- a/llm/experimental/observer/abs_max_headwise.py +++ b/llm/experimental/observer/abs_max_headwise.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +14,6 @@ import numpy as np import paddle - -# from paddleslim.quant.observers.channel_wise import ChannelWiseObserver from experimental.observer.channel_wise import ChannelWiseObserver from paddle.quantization.factory import ObserverFactory diff --git a/llm/experimental/observer/avg_headwise.py b/llm/experimental/observer/avg_headwise.py index 36c97f99ebfa..a25fbd770019 100644 --- a/llm/experimental/observer/avg_headwise.py +++ b/llm/experimental/observer/avg_headwise.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/llm/experimental/observer/channel_wise.py b/llm/experimental/observer/channel_wise.py index 3a7ab973d014..883a74a8f9b0 100644 --- a/llm/experimental/observer/channel_wise.py +++ b/llm/experimental/observer/channel_wise.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,13 +14,8 @@ from typing import Dict -# import numpy as np import paddle - -# from paddle.quantization.factory import ObserverFactory from experimental.layers.cache_kv import CacheKVMatMul - -# from paddleslim.quant.observers.mse import MSEObserverLayer from paddleslim.quant.observers.uniform import UniformObserver CHANNEL_AXIS: Dict[type, int] = { diff --git a/llm/experimental/observer/mse.py b/llm/experimental/observer/mse.py deleted file mode 100644 index e9a5b86366c1..000000000000 --- a/llm/experimental/observer/mse.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# import numpy as np -import paddle -from paddle.nn.quant.format import LinearDequanter, LinearQuanter -from paddle.quantization.factory import ObserverFactory -from paddleslim.quant.observers.uniform import UniformObserver - - -class MSEObserver(ObserverFactory): - r""" - It collects maximum absolute values of target tensor. - Args: - bit_length(int, optional): Number of bits to represent an quantized integer in binary. - dtype(str, optional): The data type of input tensor. - name (str, optional): This parameter is used by developers to print debugging information. \ - For details, please refer to :ref:`api_guide_Name`. Default is None. - Examples: - .. code-block:: python - from paddle.quantization import QuantConfig - from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver - quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99) - q_config = QuantConfig(activation=quanter, weight=quanter) - """ - - def __init__(self, quant_bits=8): - super(MSEObserver, self).__init__(quant_bits=quant_bits) - - def _get_class(self): - return MSEObserverLayer - - -class MSEObserverLayer(UniformObserver): - def __init__(self, layer, quant_bits=8, moving_avg=False): - super(MSEObserverLayer, self).__init__(quant_bits=quant_bits) - self.quant_bits = quant_bits - self.calibration_loss = float("inf") - self.qmin, self.qmax = self.qmin_qmax - - self._current_iters = 0 - self._range_update_factor_min = 0.001 - - self._moving_avg = moving_avg - self._max = None - self.observer_enabled = True - - def forward(self, inputs): - """Calculate forward pass.""" - self._scale = None - self._zero_point = None - self._min = None - self._max = None - - if self.observer_enabled: - self._max = self.cal_abs_max(inputs) - - return inputs - - def cal_abs_max(self, inputs): - self._current_iters += 1 - # abs_max_value = float(paddle.max(paddle.abs(inputs.flatten()))) - abs_max_value = paddle.max(paddle.mean(paddle.abs(inputs.flatten()), axis=0)) # average over batch - abs_max_value = 
1e-8 if abs_max_value == 0.0 else abs_max_value - s = 0.3 - scale_mse = abs_max_value - while s <= 1.0: - scale = s * abs_max_value - s += 0.02 - - quant_var_func = LinearQuanter(scale, 0.0, bit_length=self.quant_bits) - dequant_var_func = LinearDequanter(scale, 0.0, bit_length=self.quant_bits) - - quant_var = quant_var_func(inputs) - quant_dequant_var = dequant_var_func(quant_var) - - mse_loss = ((inputs - quant_dequant_var) ** 2).mean() - if mse_loss <= self.calibration_loss: - self.calibration_loss = mse_loss - scale_mse = scale - - # import pdb;pdb.set_trace() - if self._moving_avg and self._max is not None: - update_factor = 1.0 / self._current_iters - update_factor = max(update_factor, self._range_update_factor_min) - scale_mse = self._max * (1 - update_factor) + scale_mse * update_factor - - return scale_mse - - def cal_thresholds(self): - """Compute thresholds for MAX function.""" - if self._scale is not None: - self._zero_point = 0 - return - self._scale = self._max - self._zero_point = 0 - - def min_value(self) -> float: - return 0 - - def max_value(self) -> float: - return self._max - - def bit_length(self): - """Return the bit length of quantized data.""" - return self._quant_bits - - def quant_axis(self): - """Return quantization axis.""" - return -1 - - def scales(self): - """Return output scales.""" - if self._scale is None: - self.cal_thresholds() - return self._scale - - def zero_points(self): - """Return output zero points.""" - if self._zero_point is None: - self.cal_thresholds() - return self._zero_point diff --git a/llm/experimental/observer/static.py b/llm/experimental/observer/static.py deleted file mode 100644 index 518cf6645063..000000000000 --- a/llm/experimental/observer/static.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# import numpy as np -import paddle -from paddle.quantization.factory import ObserverFactory -from paddleslim.quant.observers.uniform import UniformObserver - - -class StaticObserver(ObserverFactory): - r""" - It collects maximum absolute values of target tensor. - Args: - bit_length(int, optional): Number of bits to represent an quantized integer in binary. - dtype(str, optional): The data type of input tensor. - name (str, optional): This parameter is used by developers to print debugging information. \ - For details, please refer to :ref:`api_guide_Name`. Default is None. - Examples: - .. 
code-block:: python - from paddle.quantization import QuantConfig - from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver - quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99) - q_config = QuantConfig(activation=quanter, weight=quanter) - """ - - def __init__(self, quant_bits=8): - super(StaticObserver, self).__init__(quant_bits=quant_bits) - - def _get_class(self): - return StaticObserverLayer - - -class StaticObserverLayer(UniformObserver): - def __init__(self, layer, quant_bits=8, static_val=448): - super(StaticObserverLayer, self).__init__(quant_bits=quant_bits) - self._quant_bits = quant_bits - self._avg_list = [] - self.static_val = static_val - - def forward(self, inputs): - """Calculate forward pass.""" - - self._scale = None - self._zero_point = None - self._min, self._max = self.cal_min_max(inputs) - return inputs - - def cal_min_max(self, inputs): - max_val = paddle.to_tensor(self.static_val).astype(inputs.dtype).to(inputs.place) - return -max_val, max_val - - def cal_thresholds(self): - """Compute thresholds for MAX function.""" - - self._scale, self._zero_point = self._max, 0 - - def min_value(self) -> float: - return self._min - - def max_value(self) -> float: - return self._max - - def bit_length(self): - """Return the bit length of quantized data.""" - return self._quant_bits - - def quant_axis(self): - """Return quantization axis.""" - return -1 - - def scales(self): - """Return output scales.""" - if self._scale is None: - self.cal_thresholds() - return self._scale - - def zero_points(self): - """Return output zero points.""" - if self._zero_point is None: - self.cal_thresholds() - return self._zero_point diff --git a/llm/utils/quant.py b/llm/utils/quant.py index 00e65edf34d9..b9520e72edc3 100644 --- a/llm/utils/quant.py +++ b/llm/utils/quant.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -213,21 +213,25 @@ def prepare_qconfig(args): """ Prepare qconfig """ - - weight_observer = WEIGHT_OBSERVER.get(args.weight_quant_method, None) - if weight_observer is None: - weight_observer = FP8_OBSERVER.get(args.weight_quant_method, None) - - act_observer = ACT_OBSERVER.get(args.act_quant_method, None) - if act_observer is None: - act_observer = FP8_OBSERVER.get(args.act_quant_method, None) - - cachekv_observer = CACHEKV_OBSERVER.get(args.cachekv_quant_method, None) - if cachekv_observer is None: - cachekv_observer = FP8_OBSERVER.get(args.cachekv_quant_method, None) - args.quant_type = args.quant_type.lower() args.use_fp8 = args.use_fp8.lower() + + weight_observer = ( + WEIGHT_OBSERVER.get(args.weight_quant_method, None) + if "w" not in args.use_fp8 + else FP8_OBSERVER.get(args.weight_quant_method, None) + ) + act_observer = ( + ACT_OBSERVER.get(args.act_quant_method, None) + if "a" not in args.use_fp8 + else FP8_OBSERVER.get(args.act_quant_method, None) + ) + cachekv_observer = ( + CACHEKV_OBSERVER.get(args.cachekv_quant_method, None) + if "c" not in args.use_fp8 + else FP8_OBSERVER.get(args.cachekv_quant_method, None) + ) + if "c8" in args.quant_type: quant_type = args.quant_type.replace("c8", "") cachekv_quant = True @@ -252,8 +256,8 @@ def prepare_qconfig(args): a_quant_bit = (4, 3) if args.fp8_type[args.use_fp8.index("a")] == "e4m3" else (5, 2) else: a_quant_bit = 8 - activation = act_observer(quant_bits=w_quant_bit) - weight = weight_observer(quant_bits=a_quant_bit) + activation = act_observer(quant_bits=a_quant_bit) + weight = weight_observer(quant_bits=w_quant_bit) elif quant_type in ["wint4", "w4a16", "weight_only_int8"]: activation = None @@ -290,7 +294,7 @@ def prepare_qconfig(args): q_config.add_qat_layer_mapping(FuncWrapper, QuantizedCustomAttentionLayer) elif cachekv_quant_bits == "fp8": - cachekv_quant_bit = 
(4, 3) if args.fp8_type[args.use_fp8.index("C")] == "e4m3" else (5, 2) + cachekv_quant_bit = (4, 3) if args.fp8_type[args.use_fp8.index("c")] == "e4m3" else (5, 2) if "headwise" in args.cachekv_quant_method: cachekv = [