Commit

[Inference] update fakequant support (#9047)
* 1. add a8w8(fp8) a8w8c8(int8) quant_type support
2. add llama3.1 and qwen2 ptq config
3. update quantization.md

* fix load_quant_model bug

* fix load quant bug

* update llm/README.md

* remove useless code

* update quant observer config

* revert an incorrect modification

* fix prepare_qconfig

* remove unused files
lixcli authored Aug 29, 2024
1 parent c28caf7 commit e0ba7ef
Showing 13 changed files with 34 additions and 274 deletions.
6 changes: 3 additions & 3 deletions llm/README.md
@@ -224,13 +224,13 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo

```shell
# PTQ quantization launch command (reference)
python run_finetune.py ./config/llama/ptq_argument.json

# GPTQ quantization launch command (reference)
python run_finetune.py ./config/llama/ptq_argument.json

# W8A8C8(INT) quantization launch command (reference)
python run_finetune.py ./config/llama/ptq_c8_argument.json

# W8A8(FP8) quantization launch command (reference)
python run_finetune.py ./config/llama/fp8_ptq_argument.json
6 changes: 3 additions & 3 deletions llm/docs/quantization.md
@@ -1,4 +1,4 @@
# Large Model Quantization Tutorial

## 1. Algorithm Introduction

@@ -111,8 +111,8 @@ python run_finetune.py ./config/llama/ceval_quant_argument.json
- `use_fp8`: whether to use FP8 quantization; defaults to the empty string. Passing `"WA"` (case-insensitive) switches the 8-bit quantization of weights and activations to FP8.
- `fp8_type`: FP8 quantization type; its length should match `use_fp8`. Defaults to `["e4m3","e4m3"]`.
- `do_ptq`: whether to run PTQ quantization; defaults to False.
-- `weight_quant_method`: weight quantization method; choose groupwise or abs_max_channel_wise.
-- `act_quant_method`: activation quantization method; choose avg or abs_max.
+- `weight_quant_method`: weight quantization method; for INT8 choose groupwise or abs_max_channel_wise, for FP8 choose abs_max or avg.
+- `act_quant_method`: activation quantization method; for INT8 choose avg or abs_max, for FP8 choose abs_max or avg.
- `cachekv_quant_method`: KV cache quantization method; choose abs_max_headwise or avg_headwise.
- `ptq_step`: number of PTQ calibration steps (i.e. model forward passes); defaults to 32.
- `shift`: whether to apply the [Shift strategy](https://arxiv.org/abs/2304.09145) before PTQ; defaults to False. Using Shift requires `do_ptq` to be True.
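Taken together, these flags form a PTQ config JSON like the `ptq_argument.json` files referenced above. A minimal hypothetical sketch — the model path and the chosen values are assumptions for illustration, not part of this commit:

```json
{
  "model_name_or_path": "meta-llama/Meta-Llama-3.1-8B",
  "do_ptq": true,
  "ptq_step": 32,
  "use_fp8": "WA",
  "fp8_type": ["e4m3", "e4m3"],
  "weight_quant_method": "abs_max",
  "act_quant_method": "abs_max",
  "cachekv_quant_method": "avg_headwise",
  "shift": false
}
```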
7 changes: 1 addition & 6 deletions llm/experimental/ceval/default/eval.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,25 +36,20 @@ def run_eval_one_time(args, evaluator, take):
subject_list = [val_file.replace("_val.csv", "") for val_file in filenames]
accuracy, summary = {}, {}

-# run_date = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(time.time()))
output_dir = args.output_dir
save_result_dir = os.path.join(output_dir, f"take{take}")
if not os.path.exists(save_result_dir):
os.makedirs(save_result_dir, exist_ok=True)

all_answers = {}
for index, subject_name in enumerate(subject_list):
-# print(
-#     f"{index/len(subject_list)} Inference starts at {run_date} on {args.model_name_or_path} with subject of {subject_name}!"
-# )
val_file_path = os.path.join(val_path, f"{subject_name}_val.csv")
dev_file_path = os.path.join(dev_path, f"{subject_name}_dev.csv")
test_file_path = os.path.join(test_path, f"{subject_name}_test.csv")

val_df = pd.read_csv(val_file_path) if args.do_test is False else pd.read_csv(test_file_path)
dev_df = pd.read_csv(dev_file_path) if args.few_shot else None

-# import pdb;pdb.set_trace()
correct_ratio, answers = evaluator.eval_subject(
subject_name,
val_df,
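The subject-discovery step in `run_eval_one_time` can be sketched on its own; the file names below are hypothetical examples, not files shipped in the repo:

```python
# Each C-Eval subject contributes one "<subject>_val.csv" file; the subject
# list is recovered by stripping that suffix, as in run_eval_one_time.
filenames = ["computer_network_val.csv", "high_school_physics_val.csv"]
subject_list = [f.replace("_val.csv", "") for f in filenames]
print(subject_list)  # ['computer_network', 'high_school_physics']
```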
2 changes: 1 addition & 1 deletion llm/experimental/ceval/default/evaluator.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion llm/experimental/ceval/default/model_evaluator.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
8 changes: 1 addition & 7 deletions llm/experimental/layers/cache_kv.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -166,13 +166,7 @@ def forward(

def _smooth(self, x, y, use_smooth_x):
# For ShiftSmooth
-# smooth_shape = y.shape[2:]
self.dtype = y.dtype
-# if not hasattr(self, "smooth_weight"):
-#     self.smooth_weight = self.create_parameter(
-#         shape=smooth_shape,
-#         attr=ParamAttr(initializer=Constant(value=1.)),
-#         dtype=self.dtype)
smooth_y = y
smooth_y = paddle.divide(smooth_y, self.smooth_weight)

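The `_smooth` step divides activations by a per-channel `smooth_weight`. A minimal pure-Python sketch of why that is safe (hypothetical values; the real code operates on paddle tensors): scaling an activation channel down and the matching weight entry up by the same factor leaves the matmul output unchanged, while shrinking activation outliers for quantization.

```python
def dot(x, w):
    """1-D dot product standing in for one output channel of a matmul."""
    return sum(a * b for a, b in zip(x, w))

x = [4.0, 0.5, 2.0]   # activation with an outlier channel
w = [0.25, 2.0, 1.0]  # one weight column
s = [2.0, 0.5, 1.0]   # per-channel smoothing factors

x_smooth = [a / f for a, f in zip(x, s)]  # what _smooth does to activations
w_smooth = [b * f for b, f in zip(w, s)]  # compensation folded into weights

assert abs(dot(x, w) - dot(x_smooth, w_smooth)) < 1e-9  # output preserved
```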
1 change: 0 additions & 1 deletion llm/experimental/layers/custom_attention.py
@@ -58,7 +58,6 @@ def forward(
**kwargs
):
"""forward"""
-# import pdb;pdb.set_trace()
if self.enable_fake_quant:
self.collect_kv_quant_policy(q, k, v, **kwargs)
perm = [0, 2, 1, 3] # [1, 2, 0, 3] if self.sequence_parallel else [0, 2, 1, 3]
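A small sketch of what `perm = [0, 2, 1, 3]` does to the tensor layout (the shape values below are hypothetical): it swaps the sequence and head axes, turning `[batch, seq_len, num_heads, head_dim]` into `[batch, num_heads, seq_len, head_dim]` before attention.

```python
def permuted_shape(shape, perm):
    """Shape a tensor would have after transpose(perm)."""
    return [shape[p] for p in perm]

print(permuted_shape([2, 128, 16, 64], [0, 2, 1, 3]))  # [2, 16, 128, 64]
```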
4 changes: 1 addition & 3 deletions llm/experimental/observer/abs_max_headwise.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,8 +14,6 @@

import numpy as np
import paddle

-# from paddleslim.quant.observers.channel_wise import ChannelWiseObserver
from experimental.observer.channel_wise import ChannelWiseObserver
from paddle.quantization.factory import ObserverFactory

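The assumed semantics of this observer — one quantization scale per attention head, tracking the maximum absolute value seen across calibration steps — can be sketched without paddle. The function name and data are illustrative, not the repo's API:

```python
def abs_max_headwise(batches):
    """batches: list of calibration steps; each step holds one list of values per head."""
    num_heads = len(batches[0])
    scales = [0.0] * num_heads
    for step in batches:
        for h, values in enumerate(step):
            # Running abs-max per head across all calibration steps.
            scales[h] = max(scales[h], max(abs(v) for v in values))
    return scales

print(abs_max_headwise([[[1.0, -3.0], [0.5]], [[2.0], [-4.0, 1.0]]]))  # [3.0, 4.0]
```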
2 changes: 1 addition & 1 deletion llm/experimental/observer/avg_headwise.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
7 changes: 1 addition & 6 deletions llm/experimental/observer/channel_wise.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,13 +14,8 @@

from typing import Dict

-# import numpy as np
import paddle

-# from paddle.quantization.factory import ObserverFactory
from experimental.layers.cache_kv import CacheKVMatMul

-# from paddleslim.quant.observers.mse import MSEObserverLayer
from paddleslim.quant.observers.uniform import UniformObserver

CHANNEL_AXIS: Dict[type, int] = {
133 changes: 0 additions & 133 deletions llm/experimental/observer/mse.py

This file was deleted.

92 changes: 0 additions & 92 deletions llm/experimental/observer/static.py

This file was deleted.

