Add FasterTokenizer on PPMiniLM #1542

Merged · 19 commits · Jan 11, 2022
Changes from 14 commits
49 changes: 22 additions & 27 deletions examples/model_compression/pp-minilm/README.md
@@ -27,11 +27,11 @@
<a name="PP-MiniLM中文小模型"></a>

# PP-MiniLM: A Compact Chinese Model
- [PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP), together with [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim), releases PP-MiniLM (6L768H), a compact model tailored for Chinese, along with a compression pipeline that cascades model distillation, pruning, and quantization. While preserving accuracy, the model infers 5.4x faster than BERT (12L768H), has 52% fewer parameters, and scores 0.62 points higher on the Chinese language understanding benchmark CLUE.
+ [PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP), together with [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim), releases PP-MiniLM (6L768H), a compact model tailored for Chinese, along with a compression pipeline that cascades model distillation, pruning, and quantization. While preserving accuracy, the model infers 9.3x faster than BERT (12L768H), has 52% fewer parameters, and scores 0.62 points higher on the Chinese language understanding benchmark CLUE.

The PP-MiniLM compression pipeline is built around task-agnostic distillation, pruning, and quantization for pretrained models, which makes PP-MiniLM **fast**, **accurate**, and **small**.

- 1. **Fast inference**: PaddleSlim's pruning and quantization compress and accelerate the PP-MiniLM small model, so the quantized PP-MiniLM achieves a GPU inference speedup of up to 5.4x over BERT base;
+ 1. **Fast inference**: PaddleSlim's pruning and quantization compress and accelerate the PP-MiniLM small model, so the quantized PP-MiniLM achieves a GPU inference speedup of up to 9.3x over BERT base;

2. **High accuracy**: Building on the Multi-Head Self-Attention Relation Distillation technique proposed in [MiniLMv2](https://arxiv.org/abs/2012.15828), we further optimize the algorithm by introducing inter-sample relation distillation. On the CLUE datasets, the 6-layer PP-MiniLM scores 0.62% higher than the 12-layer `bert-base-chinese`, and 2.57% and 2.24% higher than the similarly sized TinyBERT and UER-py RoBERTa, respectively;

@@ -43,13 +43,13 @@
| Model | #Params | #FLOPs | Speedup | AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | CLUEWSC2020 | CSL | CLUE Avg. |
| ----------------------- | ------- | ------ | ------- | ----- | ----- | ------- | ----- | ----- | ----------- | ----- | ----------- |
| BERT<sub>base</sub> | 102.3M | 10.87B | 1.00x | 74.14 | 56.81 | 61.10 | 81.19 | 74.85 | 79.93 | 81.47 | 72.78 |
- | TinyBERT<sub>6</sub> | 59.7M | 5.44B | 1.88x | 72.59 | 55.70 | 57.64 | 79.57 | 73.97 | 76.32 | 80.00 | 70.83 |
- | UER-py RoBERTa L6-H768 | 59.7M | 5.44B | 1.88x | 69.62 | 66.45 | 59.91 | 76.89 | 71.36 | 71.05 | 82.87 | 71.16 |
- | RBT6, Chinese | 59.7M | 5.44B | 1.88x | 73.93 | 56.63 | 59.79 | 79.28 | 73.12 | 77.30 | 80.80 | 71.55 |
- | ERNIE-Tiny | 90.7M | 4.83B | 2.22x | 71.55 | 58.34 | 61.41 | 76.81 | 71.46 | 72.04 | 79.13 | 70.11 |
- | PP-MiniLM 6L-768H | 59.7M | 5.44B | 1.88x | 74.14 | 57.43 | 61.75 | 81.01 | 76.17 | 86.18 | 79.17 | 73.69 |
- | PP-MiniLM after pruning | 49.1M | 4.08B | 2.39x | 73.91 | 57.44 | 61.64 | 81.10 | 75.59 | 85.86 | 78.53 | 73.44 |
- | PP-MiniLM after quantization | 49.2M | - | 5.35x | 74.00 | 57.37 | 61.33 | 81.09 | 75.56 | 85.85 | 78.57 | 73.40 |
+ | TinyBERT<sub>6</sub> | 59.7M | 5.44B | 2.04x | 72.59 | 55.70 | 57.64 | 79.57 | 73.97 | 76.32 | 80.00 | 70.83 |
+ | UER-py RoBERTa L6-H768 | 59.7M | 5.44B | 2.04x | 69.62 | 66.45 | 59.91 | 76.89 | 71.36 | 71.05 | 82.87 | 71.16 |
+ | RBT6, Chinese | 59.7M | 5.44B | 2.04x | 73.93 | 56.63 | 59.79 | 79.28 | 73.12 | 77.30 | 80.80 | 71.55 |
+ | ERNIE-Tiny | 90.7M | 4.83B | 2.30x | 71.55 | 58.34 | 61.41 | 76.81 | 71.46 | 72.04 | 79.13 | 70.11 |
+ | PP-MiniLM 6L-768H | 59.7M | 5.44B | 2.12x | 74.14 | 57.43 | 61.75 | 81.01 | 76.17 | 86.18 | 79.17 | 73.69 |
+ | PP-MiniLM after pruning | 49.1M | 4.08B | 2.60x | 73.91 | 57.44 | 61.64 | 81.10 | 75.59 | 85.86 | 78.53 | 73.44 |
+ | PP-MiniLM after pruning + quantization | 49.2M | - | 9.26x | 74.00 | 57.37 | 61.33 | 81.09 | 75.56 | 85.85 | 78.57 | 73.40 |


**NOTE:**
@@ -127,7 +127,7 @@ PP-MiniLM is a 6-layer pretrained model; it is loaded with `from_pretrained`

On the average accuracy over the 7 CLUE classification datasets, the pretrained PP-MiniLM small model is 0.62% higher than the 12-layer `bert-base-chinese`, and 2.57% and 2.24% higher than the similarly sized TinyBERT and UER-py RoBERTa, so we recommend applying PP-MiniLM to Chinese downstream tasks. If you want to further compress an existing model, you can also follow the compression pipeline described here, since the pipeline is generic.
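As a minimal sketch of loading PP-MiniLM with `from_pretrained` (the model name `ppminilm-6l-768h` mirrors the identifier used by the scripts in this example, and `num_classes=2` is an illustrative assumption for a binary task):

```python
import paddle
from paddlenlp.transformers import PPMiniLMForSequenceClassification, PPMiniLMTokenizer

# Load the 6-layer PP-MiniLM weights and the matching tokenizer by name.
model = PPMiniLMForSequenceClassification.from_pretrained(
    "ppminilm-6l-768h", num_classes=2)
tokenizer = PPMiniLMTokenizer.from_pretrained("ppminilm-6l-768h")

# Tokenize a sample sentence pair and run a forward pass to get logits.
encoded = tokenizer("百度飞桨", text_pair="PaddleNLP", max_seq_len=128)
logits = model(
    paddle.to_tensor([encoded["input_ids"]]),
    paddle.to_tensor([encoded["token_type_ids"]]))
```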

- This example uses the 7 CLUE classification datasets to show how to apply PP-MiniLM to downstream tasks. We first fine-tune the pretrained PP-MiniLM small model on the CLUE datasets, then provide a compression pipeline that uses [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) for pruning and quantization to shrink the model further, and finally deploy the quantized model for prediction with the TensorRT-based [Paddle Inference](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/inference_cn.html) library. Before pruning and quantization, the 6-layer PP-MiniLM already infers 1.9x faster than `bert-base-chinese`; after compression on the downstream tasks, inference is up to 5.4x faster than `bert-base-chinese`.
+ This example uses the 7 CLUE classification datasets to show how to apply PP-MiniLM to downstream tasks. We first fine-tune the pretrained PP-MiniLM small model on the CLUE datasets, then provide a compression pipeline that uses [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) for pruning and quantization to shrink the model further, and finally deploy the quantized model for prediction with the TensorRT-based [Paddle Inference](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/inference_cn.html) library. Before pruning and quantization, the 6-layer PP-MiniLM already infers 1.9x faster than `bert-base-chinese`; after compression on the downstream tasks, inference is up to 9.3x faster than `bert-base-chinese`.

<a name="数据介绍"></a>

@@ -189,14 +189,9 @@ sh run_clue.sh CLUEWSC2020 1e-4 32 50 128 0 ppminilm-6l-768h

#### Export the Fine-tuned Model

- If hyperparameter search was performed, at this step you can pick the best-performing model on each task and export it.
+ After training finishes, the `--save_inference_model` argument defaults to True, so the inference (prediction) model is saved automatically.

- Assuming the model to export is located at `ppminilm-6l-768h/models/CLUEWSC2020/1e-4_32`, you can run the command below to export the dynamic-graph model as a static-graph model ready for deployment:

- ```shell
- python export_model.py --model_type ppminilm --model_path ppminilm-6l-768h/models/CLUEWSC2020/1e-4_32 --output_path fine_tuned_infer_model/float
- cd ..
- ```
+ The static-graph (deployment) model is saved under the same path as the dynamic-graph model, in the files `inference.pdmodel`, `inference.pdiparams`, and `inference.pdiparams.info`.
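For illustration only (the checkpoint directory below reuses the example path mentioned earlier in this README; adjust it to your own output directory), the exported files can be verified like this:

```python
import os

# After fine-tuning with --save_inference_model True, the static-graph files
# sit next to the dynamic-graph checkpoint in the task's output directory.
ckpt_dir = "ppminilm-6l-768h/models/CLUEWSC2020/1e-4_32"
exported = sorted(f for f in os.listdir(ckpt_dir) if f.startswith("inference."))
print(exported)  # expected: inference.pdiparams, inference.pdiparams.info, inference.pdmodel
```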

<a name="裁剪"></a>

@@ -321,7 +316,7 @@ cd ..

#### Environment Requirements

- This step requires PaddlePaddle 2.2.1 with the inference library installed. You can pick a Python inference package suitable for your machine from the [PaddlePaddle website](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html#python).
+ This step requires PaddlePaddle 2.2.2 with the inference library installed. You can pick a Python inference package suitable for your machine from the [PaddlePaddle website](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html#python).

For a more pronounced speedup, we recommend testing on NVIDIA Tensor Core GPUs (e.g. T4, A10, A100); the results in this example were measured on a T4. If you test on a V-series GPU, which has no Int8 Tensor Cores, the speedup will not reach the figures in the tables of this document.
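As a rough sketch of what enabling the TensorRT Int8 path looks like with the Paddle Inference Python API (this is not the repository's `infer.py`; the model file names, batch size, and workspace size below are illustrative assumptions):

```python
import paddle.inference as paddle_infer

# Route supported subgraphs through TensorRT with Int8 precision.
# The file names are placeholders for the quantized inference model.
config = paddle_infer.Config("quant_model.pdmodel", "quant_model.pdiparams")
config.enable_use_gpu(100, 0)  # 100 MB initial GPU memory pool, device id 0
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=32,
    min_subgraph_size=5,
    precision_mode=paddle_infer.PrecisionType.Int8,
    use_static=False,
    use_calib_mode=False)
predictor = paddle_infer.create_predictor(config)
```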

Expand Down Expand Up @@ -367,19 +362,19 @@ python infer.py --task_name ${task} --model_path $MODEL_PATH --use_trt

```shell

- sh infer_perf.sh
+ bash infer_perf.sh
cd ..
```

The last three rows of the table below give the total time of the fine-tuned, pruned, and quantized models, respectively.
- Averaging the 5 durations printed outside the `--collect_shape` phase shows that the model pruned and quantized with PaddleSlim infers 5.4x faster than the original BERT<sub>base</sub> model, and the pruned model alone is 2.4x faster than BERT<sub>base</sub>.

- |                              | Average time (s) | Speedup |
- | ---------------------------- | ---------------- | ------- |
- | BERT<sub>base</sub>          | 18.73            | -       |
- | PP-MiniLM                    | 9.99             | 1.88x   |
- | PP-MiniLM after pruning      | 7.84             | 2.39x   |
- | PP-MiniLM after quantization | 3.50             | 5.35x   |
+ Averaging the 5 durations printed outside the `--collect_shape` phase shows that the model pruned and quantized with PaddleSlim infers 9.3x faster than the original BERT<sub>base</sub> model, and the pruned model alone is 2.6x faster than BERT<sub>base</sub>.

+ |                                        | Average time (s) | Speedup |
+ | -------------------------------------- | ---------------- | ------- |
+ | BERT<sub>base</sub>                    | 19.4549112       | 1.00x   |
+ | PP-MiniLM                              | 9.10495186       | 2.12x   |
+ | PP-MiniLM after pruning                | 7.45042658       | 2.60x   |
+ | PP-MiniLM after pruning + quantization | 2.1215384        | 9.26x   |


<a name="参考文献"></a>
40 changes: 22 additions & 18 deletions examples/model_compression/pp-minilm/data.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import numpy as np

+import numpy as np
from paddle.metric import Metric, Accuracy

from paddlenlp.transformers import PPMiniLMForSequenceClassification, PPMiniLMTokenizer
@@ -35,29 +35,25 @@


def convert_example(example,
-                    tokenizer,
                     label_list,
+                    tokenizer=None,
+                    is_test=False,
                     max_seq_length=512,
-                    is_test=False):
+                    **kwargs):
    """convert a glue example into necessary features"""
    if not is_test:
        # `label_list == None` is for regression task
        label_dtype = "int64" if label_list else "float32"
        # Get the label
-        label = example['label']
-        label = np.array([label], dtype=label_dtype)
+        example['label'] = np.array(example["label"], dtype="int64")
    # Convert raw text to feature
-    if 'sentence' in example:
-        example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
-    elif 'sentence1' in example:
-        example = tokenizer(
-            example['sentence1'],
-            text_pair=example['sentence2'],
-            max_seq_len=max_seq_length)
-    elif 'keyword' in example:  # CSL
+    if 'keyword' in example:  # CSL
        sentence1 = " ".join(example['keyword'])
-        example = tokenizer(
-            sentence1, text_pair=example['abst'], max_seq_len=max_seq_length)
+        example = {
+            'sentence1': sentence1,
+            'sentence2': example['abst'],
+            'label': example['label']
+        }
    elif 'target' in example:  # wsc
        text, query, pronoun, query_idx, pronoun_idx = example['text'], example[
            'target']['span1_text'], example['target']['span2_text'], example[
@@ -78,9 +74,17 @@ def convert_example(example,
            text_list.insert(query_idx + 2, "_")
            text_list.insert(query_idx + len(query) + 2 + 1, "_")
        text = "".join(text_list)
-        example = tokenizer(text, max_seq_len=max_seq_length)

+        example['sentence'] = text
+    if tokenizer is None:
+        return example
+    if 'sentence' in example:
+        example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
+    elif 'sentence1' in example:
+        example = tokenizer(
+            example['sentence1'],
+            text_pair=example['sentence2'],
+            max_seq_len=max_seq_length)
    if not is_test:
-        return example['input_ids'], example['token_type_ids'], label
+        return example['input_ids'], example['token_type_ids'], example['label']
    else:
        return example['input_ids'], example['token_type_ids']
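To make the two modes of the updated `convert_example` concrete, here is a small usage sketch (it mirrors the `partial(...)` calls in `run_clue.py`; the dataset choice, model name, and sequence length are illustrative assumptions):

```python
from functools import partial

from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import PPMiniLMTokenizer

from data import convert_example  # the function shown in the diff above

train_ds = load_dataset("clue", "tnews", splits="train")
tokenizer = PPMiniLMTokenizer.from_pretrained("ppminilm-6l-768h")

# Mode 1 - with a tokenizer: each sample becomes (input_ids, token_type_ids, label),
# which is what the regular fine-tuning DataLoader in run_clue.py consumes.
trans_func = partial(
    convert_example,
    label_list=train_ds.label_list,
    tokenizer=tokenizer,
    max_seq_length=128)
tokenized_ds = train_ds.map(trans_func, lazy=True)

# Mode 2 - tokenizer=None: the example keeps its raw text fields (only the label is
# converted), so a model exported with use_faster_tokenizer can tokenize inside the
# graph at inference time.
raw_func = partial(convert_example, label_list=train_ds.label_list)
raw_ds = load_dataset("clue", "tnews", splits="train").map(raw_func, lazy=True)
```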
78 changes: 0 additions & 78 deletions examples/model_compression/pp-minilm/finetuning/export_model.py

This file was deleted.

36 changes: 21 additions & 15 deletions examples/model_compression/pp-minilm/finetuning/run_clue.py
@@ -137,6 +137,16 @@ def parse_args():
        type=distutils.util.strtobool,
        default=False,
        help="Whether do train.")
+    parser.add_argument(
+        "--save_inference_model",
+        type=distutils.util.strtobool,
+        default=True,
+        help="Whether to save inference model.")
+    parser.add_argument(
+        "--save_inference_model_with_tokenizer",
+        type=distutils.util.strtobool,
+        default=True,
+        help="Whether to save inference model with tokenizer.")
    parser.add_argument(
        "--max_steps",
        default=-1,
@@ -202,8 +212,8 @@ def do_eval(args):
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    trans_func = partial(
        convert_example,
-        tokenizer=tokenizer,
        label_list=dev_ds.label_list,
+        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length)

    dev_ds = dev_ds.map(trans_func, lazy=True)
@@ -264,8 +274,8 @@ def do_train(args):

    trans_func = partial(
        convert_example,
-        tokenizer=tokenizer,
        label_list=train_ds.label_list,
+        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length)

    train_ds = train_ds.map(trans_func, lazy=True)
@@ -377,18 +387,13 @@ def do_train(args):
def export_model(args):
    save_path = os.path.join(args.output_dir, "inference")
    model = PPMiniLMForSequenceClassification.from_pretrained(args.output_dir)
-    model.eval()
-    # convert to static graph with specific input description
-    model = paddle.jit.to_static(
-        model,
-        input_spec=[
-            paddle.static.InputSpec(
-                shape=[None, None], dtype="int64"),  # input_ids
-            paddle.static.InputSpec(
-                shape=[None, None], dtype="int64")  # segment_ids
-        ])
-    # save converted static graph model
-    paddle.jit.save(model, save_path)
+    is_text_pair = True
+    if args.task_name in ('tnews', 'iflytek', 'cluewsc2020'):
+        is_text_pair = False
+    model.to_static(
+        save_path,
+        use_faster_tokenizer=args.save_inference_model_with_tokenizer,
+        is_text_pair=is_text_pair)


def print_arguments(args):
@@ -404,6 +409,7 @@ def print_arguments(args):
    print_arguments(args)
    if args.do_train:
        do_train(args)
-        export_model(args)
+        if args.save_inference_model:
+            export_model(args)
    if args.do_eval:
        do_eval(args)

Review comment: Why are two command-line arguments, `save_inference_model` and `save_inference_model_with_tokenizer`, needed?

Contributor Author: `save_inference_model_with_tokenizer` has been removed; the existing `use_faster_tokenizer` option already provides this functionality.