diff --git a/examples/model_compression/PP-MiniLM/README.md b/examples/model_compression/PP-MiniLM/README.md
new file mode 100644
index 000000000000..67f750f14ff6
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/README.md
@@ -0,0 +1,303 @@
+# PP-MiniLM: A Small Chinese Pre-trained Model
+
+The PP-MiniLM small Chinese model example aims to provide a high-accuracy, high-performance small-model solution that unifies training and inference.
+
+The solution builds on industry-leading task-agnostic model distillation, pruning, and quantization, so the small model combines three advantages: fast inference, strong accuracy, and a small parameter count.
+
+- Fast inference: PaddleSlim pruning and quantization further compress the small model, bringing its inference speed to 2.18x of the uncompressed model;
+
+- High accuracy: building on the Multi-Head Self-Attention Relation Distillation technique proposed by MiniLMv2, we further improve the algorithm by introducing inter-sample relation knowledge distillation. Our 6-layer model with hidden size 768 exceeds TinyBERT and UER-py RoBERTa models of the same size by 2.66% and 1.51% average accuracy on CLUE, respectively.
+
+- Small parameter count: with PaddleSlim pruning, the model width is reduced by 1/4 at almost no cost in accuracy (-0.15).
+
+Overall results:
+
+| Model                        | #Params           | #FLOPs | Speedup | AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC   | CSL   | CLUE avg |
+| ---------------------------- | ----------------- | ------ | ------- | ----- | ----- | ------- | ----- | ----- | ----- | ----- | -------- |
+| BERT-base                    | 102.3M            |        | 1.00x   | 74.17 | 57.17 | 61.14   |       | 75.08 | 80.26 | 81.47 |          |
+| TinyBERT6                    | 59.7M             |        | 1.64x   | 72.22 | 55.82 | 58.10   | 79.53 | 74.00 | 75.99 | 80.57 | 70.89    |
+| UER-py RoBERTa L6-H768       | 59.7M             |        | 1.64x   | 69.74 | 66.36 | 59.95   | 77.00 | 71.39 | 71.05 | 82.83 | 71.19    |
+| ERNIE-Tiny                   | 90.7M             |        | 1.76x   | 70.78 | 55.70 | 59.95   | 75.40 | 70.98 | 67.43 | 76.60 | 68.12    |
+| PP-MiniLM 6L-768H            | 59.7M             |        | 1.64x   | 74.28 | 57.33 | 61.72   | 81.06 | 76.20 | 86.51 | 78.77 | 73.70    |
+| PP-MiniLM after pruning      | 49.1M (pruned)    |        | 1.88x   | 73.82 | 57.33 | 61.60   | 81.38 | 76.20 | 85.52 | 79.00 | 73.55    |
+| PP-MiniLM after quantization | 49.2M (quantized) |        | 3.56x   | 73.61 | 57.18 | 61.49   | 81.26 | 76.31 | 84.54 | 77.67 | 73.15    |
+
+
+
+Pipeline overview:
+
+Overall pipeline diagram
+
+The directory structure of this example is as follows:
+
+```shell
+.
+├── general_distill               # general distillation directory
+│   ├── general_distill.py        # general distillation script
+│   └── run.sh                    # launch script for general distillation
+├── finetuning                    # downstream-task fine-tuning directory
+│   ├── run_clue.py               # fine-tuning script for CLUE tasks
+│   ├── run_clue.sh               # launch script for fine-tuning on a CLUE task
+│   ├── run_one_search.sh         # grid-search script for a single CLUE dataset
+│   ├── run_all_search.sh         # grid-search script for all CLUE datasets
+│   └── export_model.py           # script for exporting the deployable model
+├── ofa                           # OFA pruning and distillation directory
+│   ├── run_ofa.py                # OFA pruning and distillation script
+│   ├── run_ofa.sh                # launch script for OFA pruning and distillation
+│   └── export_model.py           # exports the sub-model obtained by OFA training
+├── quantization                  # post-training quantization directory
+│   ├── quant_post.py             # post-training quantization script
+│   └── quant.sh                  # launch script for post-training quantization
+├── inference                     # inference directory
+│   ├── infer.py                  # inference script
+│   ├── infer.sh                  # launch script for inference
+│   └── infer_all.sh              # launch script for batch inference over quantized models
+└── README.md                     # this document
+
+```
+
+## General Distillation (Optional)
+
+### Requirements
+
+This step runs on 8 NVIDIA Tesla V100 32G cards and takes roughly 2-3 days of training. If resources are limited, you can skip this step by downloading the model it produces and fine-tuning that model directly on the downstream-task data.
+
+
+### Method
+
+Distillation method of PP-MiniLM:
+
+Building on the Multi-Head Self-Attention Relation Distillation algorithm proposed by MiniLMv2, this solution further improves the algorithm by introducing inter-sample relation knowledge distillation: the inter-sample relations of Q-Q, K-K, and V-V are distilled from the 20th layer of the 24-layer RoBERTa-wwm-ext-large teacher into the 6th layer of the 6-layer PP-MiniLM student. The number of heads used for distillation is first unified between student and teacher, then Q, K, and V are all transposed to the shape [seq_len, head_num, batch_size, head_dim], and finally the Q-Q, K-K, and V-V relations are distilled. On CLUE this method achieves a 0.36 higher average accuracy than the original MiniLMv2 algorithm.
+
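+Below is a minimal sketch of the inter-sample relation loss for one of the Q-Q / K-K / V-V pairs, written for illustration only. It is not the implementation in `paddlenlp.transformers.distill_utils`; the regrouping into relation heads and the scaling are assumptions based on the description above.
+
+```python
+import math
+
+import paddle
+import paddle.nn.functional as F
+
+
+def to_relation_heads(x, num_relation_heads):
+    """Regroup [batch_size, num_heads, seq_len, head_dim] into
+    [seq_len, num_relation_heads, batch_size, new_head_dim]."""
+    bs, num_heads, seq_len, head_dim = x.shape
+    hidden = num_heads * head_dim
+    x = x.transpose([0, 2, 1, 3]).reshape([bs, seq_len, hidden])
+    x = x.reshape(
+        [bs, seq_len, num_relation_heads, hidden // num_relation_heads])
+    return x.transpose([1, 2, 0, 3])
+
+
+def sample_relation_loss(q_student, q_teacher, num_relation_heads=64):
+    """KL divergence between teacher and student sample-to-sample relations."""
+    q_s = to_relation_heads(q_student, num_relation_heads)
+    q_t = to_relation_heads(q_teacher, num_relation_heads)
+    # Relation matrices over the batch dimension: [seq_len, heads, batch, batch]
+    rel_s = paddle.matmul(q_s, q_s, transpose_y=True) / math.sqrt(q_s.shape[-1])
+    rel_t = paddle.matmul(q_t, q_t, transpose_y=True) / math.sqrt(q_t.shape[-1])
+    return F.kl_div(
+        F.log_softmax(rel_s, axis=-1),
+        F.softmax(rel_t, axis=-1),
+        reduction='sum')
+```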
+
+### Data
+
+Split the pre-training corpus into 64 files and place them under the `dataset` directory; each file contains one raw text sample per line.
+
+
+### How to Run
+
+```shell
+cd general_distill
+sh run.sh  # contains the launch configuration for general_distill.py
+cd ..
+```
+
+The arguments of general_distill.py are as follows:
+
+- `model_type` the student model type; currently only 'ernie' and 'roberta' are supported.
+- `num_relation_heads` the number of relation heads; usually 64 for a large-size teacher model and 48 for a base-size teacher model.
+- `teacher_model_type` the teacher model type; currently only 'ernie' and 'roberta' are supported.
+- `teacher_layer_index` the index of the teacher layer used for distillation.
+- `student_layer_index` the index of the student layer used for distillation.
+- `teacher_model_name_or_path` the name of the teacher model, e.g. `'roberta-wwm-ext-large'`.
+- `max_seq_length` the maximum sequence length.
+- `num_layers` the number of layers of the student model; currently only 2, 4, and 6 are supported.
+- `logging_steps` the logging interval, in steps.
+- `max_steps` the maximum number of training steps.
+- `warmup_steps` the number of steps over which the learning rate warms up to `learning_rate`.
+- `save_steps` the interval, in steps, at which the model is saved.
+- `weight_decay` the weight-decay coefficient used by the AdamW optimizer.
+- `output_dir` the output directory for checkpoints and other training artifacts.
+- `device` the device to use; gpu is recommended.
+- `input_dir` the training data directory.
+- `use_amp` whether to use mixed-precision training; defaults to False.
+- `alpha` the weight of the head-to-head relation loss; defaults to 0.0.
+- `beta` the weight of the sample-to-sample relation loss; defaults to 0.0.
+
+Save the absolute path of the resulting model in `$GENERAL_MODEL_DIR`, for example:
+
+```shell
+GENERAL_MODEL_DIR=PaddleNLP/examples/model_compression/PP-MiniLM/general_distill/pretrain/model_400000
+```
+
+## Fine-tuning on Downstream Tasks
+
+### Data
+
+This experiment uses the CLUE benchmark. Running the fine-tuning script downloads the datasets automatically to `~/.paddlenlp/datasets/Clue/` (on Linux).
+
+Fine-tune the general model `GENERAL_MODEL_DIR` obtained from the distillation step over the following hyperparameter ranges:
+
+- batch sizes: 16, 32, 64
+- learning rates: 3e-5, 5e-5, 1e-4
+
+### How to Run
+
+Run a grid search over the hyperparameter ranges above for the small model `GENERAL_MODEL_DIR` produced by the distillation step:
+
+```shell
+cd finetuning
+sh run_all_search.sh $GENERAL_MODEL_DIR
+```
+
+To fine-tune on a single dataset with a specific batch_size and learning_rate, refer to:
+
+```shell
+sh run_clue.sh CLUEWSC2020 1e-4 32 3 128 0 $GENERAL_MODEL_DIR
+```
+
+The positional arguments are, in order: the CLUE task name, learning_rate, batch_size, epochs, max_seq_len, the card id, and the model path.
+
+### Accuracy
+
+After fine-tuning, the accuracy on each CLUE task is as follows:
+
+| AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC   | CSL   | CLUE avg |
+| ----- | ----- | ------- | ----- | ----- | ----- | ----- | ---------- |
+| 74.28 | 57.33 | 61.72 | 81.06 | 76.20 | 86.51 | 78.77 | 73.70 |
+
+### Exporting the Fine-tuned Model for Deployment
+
+Assuming the fine-tuned model is saved at `$GENERAL_MODEL_DIR/models/CLUEWSC2020/1e-4_32`, run the following command to export the dynamic-graph model as a static-graph model ready for deployment:
+
+```shell
+python export_model.py --model_type ernie --model_path $GENERAL_MODEL_DIR/models/CLUEWSC2020/1e-4_32 --output_path fine_tuned_infer_model/float
+```
+
+## Pruning the Task Model with PaddleSlim OFA
+
+This step uses PaddleSlim's OFA feature to prune the width of the downstream-task model. To run it, install the latest PaddleSlim package:
+
+```shell
+pip install -U paddleslim -i https://pypi.org/simple
+cd ofa
+```
+
+This procedure uses the fine-tuned model as the teacher and distills a student whose width is 3/4 of the original. In our experiments with the 6L768H model, compressing the width to 3/4 of the original costs almost no accuracy (-0.15).
+
+### Launching Pruning and Distillation
+
+Assume you want to further prune the model `$GENERAL_MODEL_DIR/models/CLUEWSC2020/1e-4_32` obtained from the fine-tuning step; the learning_rate and batch_size used during fine-tuning can be reused. For example:
+
+```shell
+sh run_ofa.sh CLUEWSC2020 5e-5 16 50 128 4 $GENERAL_MODEL_DIR/models/CLUEWSC2020/1e-4_32/
+```
+
+After it finishes, the pruned model is saved under `ofa_models/CLUEWSC2020/0.75/best_model/`.
+
+### Exporting the Pruned Model
+
+This step produces both dynamic-graph and static-graph model parameters.
+
+Taking the CLUEWSC2020 dataset as an example, export the model:
+
+```shell
+MODEL_PATH=ofa_models
+TASK_NAME=CLUEWSC2020
+sh export.sh $MODEL_PATH $TASK_NAME
+```
+
+Or export the models for all tasks in batch:
+
+```shell
+sh export_all.sh
+```
+
+The final model is saved at `${MODEL_PATH}/${TASK_NAME}/0.75/sub_static/float`.
+
+```shell
+cd ..
+```
+
+### Accuracy
+
+After pruning and distillation, the accuracy on each CLUE task is shown below. Compared with the model before pruning, the CLUE average drops by 0.15, while the number of parameters shrinks from 59.7M to 49.1M.
+
+| AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC   | CSL   | CLUE avg |
+| ----- | ----- | ------- | ----- | ----- | ----- | ----- | ---------- |
+| 74.28 | 57.33 | 61.60 | 81.38 | 76.20 | 85.52 | 79.00 | 73.55 |
+
+
+
+## Further Reducing Model Size with PaddleSlim Quantization
+
+```shell
+cd quantization
+```
+
+About post-training quantization:
+
+In this step, the float32 model is passed through PaddleSlim's post-training quantization API to obtain a quantized model directly, without any retraining. Four calibration strategies are used (mse, avg, abs_max, hist), each with two calibration set sizes (4 and 8 batches).
+
+Run the following command to quantize the model of a single task:
+
+```shell
+python quant_post.py --task_name $TASK_NAME --input_dir ${MODEL_DIR}/${TASK_NAME}/0.75/sub_static
+```
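+
+For reference, the call inside `quant_post.py` is conceptually similar to the sketch below. This is only an illustrative outline of PaddleSlim's `quant_post_static` API, not the exact code of this example; the calibration-data generator and the hard-coded paths are assumptions.
+
+```python
+import paddle
+import paddleslim
+
+paddle.enable_static()
+exe = paddle.static.Executor(paddle.CUDAPlace(0))
+
+
+def calib_batches():
+    # Illustrative placeholder: yield feed lists of (input_ids, token_type_ids)
+    # built from a handful of calibration samples of the task's dev set.
+    ...
+
+
+paddleslim.quant.quant_post_static(
+    executor=exe,
+    model_dir="../ofa/ofa_models/TNEWS/0.75/sub_static",
+    quantize_model_path="TNEWS_quant_models/mse4/int8",
+    batch_generator=calib_batches,
+    model_filename="float.pdmodel",
+    params_filename="float.pdiparams",
+    batch_nums=4,  # calibration set size used in this example: 4 or 8
+    algo="mse")  # the example also sweeps 'avg', 'abs_max' and 'hist'
+```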
+
+You can also quantize the float models of all tasks in batch:
+
+```shell
+sh quant_all.sh
+```
+
+```shell
+cd ..
+```
+
+### Accuracy
+
+After quantization, the accuracy on each CLUE task is shown below; the CLUE average drops by 0.4 compared with the previous (pruned) model:
+
+| AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC   | CSL   | CLUE avg |
+| ----- | ----- | ------- | ----- | ----- | ----- | ----- | ---------- |
+| 73.61 | 57.18 | 61.49 | 81.26 | 76.31 | 84.54 | 77.67 | 73.15 |
+
+## Deployment with Paddle Inference
+
+### Requirements
+
+This step relies on Paddle Inference in Paddle 2.2.1. To see a clear speedup, we recommend testing on an NVIDIA Tensor Core GPU (e.g. T4, A10, A100); this example is benchmarked on a T4. V-series cards do not support INT8 Tensor Cores, so the speedups listed in the tables of this document will not be reached on them.
+
+Because dynamic shapes are enabled, the input shape ranges have to be collected first. Paddle Inference provides an interface for this: it first runs over offline input data to collect the shape ranges of all intermediate tensors; the input shape ranges of the TensorRT subgraphs can then be set directly from the tuned results, which completes the automatic shape-range setup. After collection, simply point to the recorded statistics file to enable the tuned_dynamic_shape feature. In this example, first run infer.py with the --collect_shape flag, then run infer.py again without it. For example:
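+
+The relevant Paddle Inference calls, condensed from `inference/infer.py` below (model and output paths here are illustrative):
+
+```python
+from paddle import inference
+
+config = inference.Config("tnews/float.pdmodel", "tnews/float.pdiparams")
+config.enable_use_gpu(100, 0)
+config.enable_tensorrt_engine(
+    workspace_size=1 << 30,
+    precision_mode=inference.PrecisionType.Int8,
+    max_batch_size=32,
+    min_subgraph_size=5,
+    use_static=False,
+    use_calib_mode=False)
+
+collect_shape = True  # first pass: True, second pass: False
+if collect_shape:
+    # Pass 1: run over the offline data and record the shape ranges.
+    config.collect_shape_range_info("tnews_shape_range_info.pbtxt")
+else:
+    # Pass 2: reuse the recorded ranges to enable tuned dynamic shapes.
+    config.enable_tuned_tensorrt_dynamic_shape("tnews_shape_range_info.pbtxt",
+                                                True)
+
+predictor = inference.create_predictor(config)
+```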
+
+INT8 inference:
+
+```shell
+cd inference
+
+python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt --collect_shape  # generate the shape range info file
+python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt  # load the shape range info file and run inference
+```
+
+To run inference over all the INT8 models in batch and compare the different quantization settings, use:
+
+```shell
+sh infer_all.sh
+```
+
+FP32 inference:
+
+```shell
+python infer.py --task_name ${task} --model_path $MODEL_PATH --use_trt --collect_shape
+python infer.py --task_name ${task} --model_path $MODEL_PATH --use_trt
+```
+
+### Performance
+
+Performance is measured on a single NVIDIA Tesla T4 card with CUDA 11.1, cuDNN 8.1 and TensorRT 7.2, using the inference/infer.py script to run the quantized model.
+
+The benchmark uses the models of the TNEWS task. The last three rows of the table below are the fine-tuned model, the OFA-pruned-and-distilled model, and the quantized model (mse strategy, calibration set size 4); the metric is the total inference time over the dev set, excluding the first 20 warmup steps.
+
+With PaddleSlim pruning and quantization the model runs 255.86% faster than the original BERT-base (about 3.56x), of which pruning alone contributes an 87.98% speedup.
+
+|                                    | Avg. time (s) | Speedup |
+| ---------------------------------- | ------------- | ------- |
+| BERT                               | 20.64         | 0       |
+| FP32                               | 12.61         | 63.68%  |
+| FP32 + pruning                     | 10.98         | 87.98%  |
+| FP32 + pruning + INT8 quantization | 5.80          | 255.86% |
+
+
+The performance benchmark above can be reproduced with:
+
+```shell
+sh infer.sh
+```
+
+```shell
+cd ..
+```
diff --git a/examples/model_compression/PP-MiniLM/finetuning/export_model.py b/examples/model_compression/PP-MiniLM/finetuning/export_model.py
new file mode 100644
index 000000000000..1e1e4fe459a6
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/finetuning/export_model.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import paddle
+
+from run_clue import MODEL_CLASSES
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ # Required parameters
+ parser.add_argument(
+ "--model_type",
+ default=None,
+ type=str,
+ required=True,
+ help="Model type selected in the list: " +
+ ", ".join(MODEL_CLASSES.keys()), )
+ parser.add_argument(
+ "--model_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path of the trained model to be exported.", )
+ parser.add_argument(
+ "--output_path",
+ default=None,
+ type=str,
+ required=True,
+ help="The output file prefix used to save the exported inference model.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ args.model_type = args.model_type.lower()
+ model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+
+ # build model and load trained parameters
+ model = model_class.from_pretrained(args.model_path)
+    # switch to eval mode
+ model.eval()
+ # convert to static graph with specific input description
+ model = paddle.jit.to_static(
+ model,
+ input_spec=[
+ paddle.static.InputSpec(
+ shape=[None, None], dtype="int64"), # input_ids
+ paddle.static.InputSpec(
+ shape=[None, None], dtype="int64") # segment_ids
+ ])
+ # save converted static graph model
+ paddle.jit.save(model, args.output_path)
+ # also save tokenizer for inference usage
+ tokenizer = tokenizer_class.from_pretrained(args.model_path)
+ tokenizer.save_pretrained(os.path.dirname(args.output_path))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_all_search.sh b/examples/model_compression/PP-MiniLM/finetuning/run_all_search.sh
new file mode 100644
index 000000000000..f19ad7bfaaa5
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/finetuning/run_all_search.sh
@@ -0,0 +1,35 @@
+# $1 is GENERAL_DIR
+
+# The card id is the last argument of run_one_search.sh and the penultimate
+# argument of run_clue.sh; adjust the card assignments below if necessary.
+bash run_one_search.sh $1 afqmc 0 &
+bash run_one_search.sh $1 tnews 1 &
+bash run_one_search.sh $1 ifly 2 &
+bash run_one_search.sh $1 ocnli 3 &
+bash run_one_search.sh $1 csl 4 &
+bash run_one_search.sh $1 wsc 5 &
+
+# Because the CMNLI data set is significantly larger than other data sets,
+# it needs to be placed on different cards.
+lr=1e-4
+bs=16
+sh run_clue.sh CMNLI $lr $bs 3 128 0 $1 > $1/cmnli/${lr}_${bs}_3_128.log &
+bs=32
+sh run_clue.sh CMNLI $lr $bs 3 128 1 $1 > $1/cmnli/${lr}_${bs}_3_128.log &
+bs=64
+sh run_clue.sh CMNLI $lr $bs 3 128 2 $1 > $1/cmnli/${lr}_${bs}_3_128.log &
+
+lr=5e-5
+bs=16
+sh run_clue.sh CMNLI $lr $bs 3 128 3 $1 > $1/cmnli/${lr}_${bs}_3_128.log &
+bs=32
+sh run_clue.sh CMNLI $lr $bs 3 128 4 $1 > $1/cmnli/${lr}_${bs}_3_128.log &
+bs=64
+sh run_clue.sh CMNLI $lr $bs 3 128 5 $1 > $1/cmnli/${lr}_${bs}_3_128.log &
+
+lr=3e-5
+bs=16
+sh run_clue.sh CMNLI $lr $bs 3 128 6 $1 > $1/cmnli/${lr}_${bs}_3_128.log &
+bs=32
+sh run_clue.sh CMNLI $lr $bs 3 128 5 $1 > $1/cmnli/${lr}_${bs}_3_128.log &
+bs=64
+sh run_clue.sh CMNLI $lr $bs 3 128 7 $1 > $1/cmnli/${lr}_${bs}_3_128.log &
diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_clue.py b/examples/model_compression/PP-MiniLM/finetuning/run_clue.py
new file mode 100644
index 000000000000..3c3cc18b61ef
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/finetuning/run_clue.py
@@ -0,0 +1,450 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+import sys
+import random
+import time
+import math
+import distutils.util
+from functools import partial
+
+import numpy as np
+import paddle
+from paddle.io import DataLoader
+import paddle.nn as nn
+from paddle.metric import Accuracy
+
+from paddlenlp.datasets import load_dataset
+from paddlenlp.data import Stack, Tuple, Pad, Dict
+from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer, BertModel
+from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
+from paddlenlp.transformers import LinearDecayWithWarmup
+
+FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logger = logging.getLogger(__name__)
+
+METRIC_CLASSES = {
+ "afqmc": Accuracy,
+ "tnews": Accuracy,
+ "iflytek": Accuracy,
+ "ocnli": Accuracy,
+ "cmnli": Accuracy,
+ "cluewsc2020": Accuracy,
+ "csl": Accuracy,
+}
+
+MODEL_CLASSES = {
+ "bert": (BertForSequenceClassification, BertTokenizer),
+ "ernie": (ErnieForSequenceClassification, ErnieTokenizer),
+}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ # Required parameters
+ parser.add_argument(
+ "--task_name",
+ default=None,
+ type=str,
+ required=True,
+ help="The name of the task to train selected in the list: " +
+ ", ".join(METRIC_CLASSES.keys()), )
+ parser.add_argument(
+ "--model_type",
+ default=None,
+ type=str,
+ required=True,
+ help="Model type selected in the list: " +
+ ", ".join(MODEL_CLASSES.keys()), )
+ parser.add_argument(
+ "--model_name_or_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to pre-trained model or shortcut name selected in the list: "
+ + ", ".join(
+ sum([
+ list(classes[-1].pretrained_init_configuration.keys())
+ for classes in MODEL_CLASSES.values()
+ ], [])), )
+ parser.add_argument(
+ "--output_dir",
+ default="best_clue_model",
+ type=str,
+ required=True,
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--max_seq_length",
+ default=128,
+ type=int,
+ help="The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded.", )
+ parser.add_argument(
+ "--learning_rate",
+ default=1e-4,
+ type=float,
+ help="The initial learning rate for Adam.")
+ parser.add_argument(
+ "--num_train_epochs",
+ default=3,
+ type=int,
+ help="Total number of training epochs to perform.", )
+ parser.add_argument(
+ "--logging_steps",
+ type=int,
+ default=100,
+ help="Log every X updates steps.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=100,
+ help="Save checkpoint every X updates steps.")
+ parser.add_argument(
+ "--batch_size",
+ default=32,
+ type=int,
+ help="Batch size per GPU/CPU for training.", )
+ parser.add_argument(
+ "--weight_decay",
+ default=0.0,
+ type=float,
+ help="Weight decay if we apply some.")
+ parser.add_argument(
+ "--warmup_steps",
+ default=0,
+ type=int,
+ help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion"
+ )
+ parser.add_argument(
+ "--warmup_proportion",
+ default=0.1,
+ type=float,
+ help="Linear warmup proportion over total steps.")
+ parser.add_argument(
+ "--adam_epsilon",
+ default=1e-6,
+ type=float,
+ help="Epsilon for Adam optimizer.")
+ parser.add_argument(
+ "--do_train",
+ type=distutils.util.strtobool,
+ default=True,
+ help="Whether do train.")
+ parser.add_argument(
+ "--max_steps",
+ default=-1,
+ type=int,
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
+ )
+ parser.add_argument(
+ "--seed", default=42, type=int, help="random seed for initialization")
+ parser.add_argument(
+ "--device",
+ default="gpu",
+ type=str,
+ help="The device to select to train the model, is must be cpu/gpu/xpu.")
+ parser.add_argument(
+ "--max_grad_norm",
+ default=1.0,
+ type=float,
+ help="The max value of grad norm.")
+ args = parser.parse_args()
+ return args
+
+
+def set_seed(args):
+ # Use the same data seed(for data shuffle) for all procs to guarantee data
+ # consistency after sharding.
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ # Maybe different op seeds(for dropout) for different procs is better. By:
+ # `paddle.seed(args.seed + paddle.distributed.get_rank())`
+ paddle.seed(args.seed)
+
+
+@paddle.no_grad()
+def evaluate(model, loss_fct, metric, data_loader):
+ model.eval()
+ metric.reset()
+ for batch in data_loader:
+ input_ids, segment_ids, labels = batch
+ logits = model(input_ids, segment_ids)
+ loss = loss_fct(logits, labels)
+ correct = metric.compute(logits, labels)
+ metric.update(correct)
+ res = metric.accumulate()
+ print("eval loss: %f, acc: %s, " % (loss.numpy(), res), end='')
+ model.train()
+ return res
+
+
+def convert_example(example,
+ tokenizer,
+ label_list,
+ max_seq_length=512,
+ is_test=False):
+ """convert a glue example into necessary features"""
+ if not is_test:
+ # `label_list == None` is for regression task
+ label_dtype = "int64" if label_list else "float32"
+ # Get the label
+ label = example['label']
+ label = np.array([label], dtype=label_dtype)
+ # Convert raw text to feature
+ if 'sentence' in example:
+ example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
+ elif 'sentence1' in example:
+ example = tokenizer(
+ example['sentence1'],
+ text_pair=example['sentence2'],
+ max_seq_len=max_seq_length)
+ elif 'keyword' in example: # CSL
+ sentence1 = " ".join(example['keyword'])
+ example = tokenizer(
+ sentence1, text_pair=example['abst'], max_seq_len=max_seq_length)
+ elif 'target' in example: # wsc
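+        # WSC: mark the candidate noun span with "_" and the pronoun with "[]"
+        # in the raw text so the model can locate both spans.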
+ text, query, pronoun, query_idx, pronoun_idx = example['text'], example[
+ 'target']['span1_text'], example['target']['span2_text'], example[
+ 'target']['span1_index'], example['target']['span2_index']
+ text_list = list(text)
+ assert text[pronoun_idx:(pronoun_idx + len(pronoun)
+ )] == pronoun, "pronoun: {}".format(pronoun)
+ assert text[query_idx:(query_idx + len(query)
+ )] == query, "query: {}".format(query)
+ if pronoun_idx > query_idx:
+ text_list.insert(query_idx, "_")
+ text_list.insert(query_idx + len(query) + 1, "_")
+ text_list.insert(pronoun_idx + 2, "[")
+ text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
+ else:
+ text_list.insert(pronoun_idx, "[")
+ text_list.insert(pronoun_idx + len(pronoun) + 1, "]")
+ text_list.insert(query_idx + 2, "_")
+ text_list.insert(query_idx + len(query) + 2 + 1, "_")
+ text = "".join(text_list)
+ example = tokenizer(text, max_seq_len=max_seq_length)
+ if not is_test:
+ return example['input_ids'], example['token_type_ids'], label
+ else:
+ return example['input_ids'], example['token_type_ids']
+
+
+def do_eval(args):
+ paddle.set_device(args.device)
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ set_seed(args)
+
+ args.task_name = args.task_name.lower()
+ metric_class = METRIC_CLASSES[args.task_name]
+ args.model_type = args.model_type.lower()
+ model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+
+ dev_ds = load_dataset('clue', args.task_name, splits='dev')
+
+ tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+
+ trans_func = partial(
+ convert_example,
+ tokenizer=tokenizer,
+ label_list=dev_ds.label_list,
+ max_seq_length=args.max_seq_length)
+
+ dev_ds = dev_ds.map(trans_func, lazy=True)
+ dev_batch_sampler = paddle.io.BatchSampler(
+ dev_ds, batch_size=args.batch_size, shuffle=False)
+
+ batchify_fn = lambda samples, fn=Tuple(
+ Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
+ Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment
+ Stack(dtype="int64" if dev_ds.label_list else "float32") # label
+ ): fn(samples)
+
+ dev_data_loader = DataLoader(
+ dataset=dev_ds,
+ batch_sampler=dev_batch_sampler,
+ collate_fn=batchify_fn,
+ num_workers=0,
+ return_list=True)
+
+    num_classes = 1 if dev_ds.label_list is None else len(dev_ds.label_list)
+ model = model_class.from_pretrained(
+ args.model_name_or_path, num_classes=num_classes)
+ if paddle.distributed.get_world_size() > 1:
+ model = paddle.DataParallel(model)
+
+ metric = metric_class()
+ model.eval()
+ metric.reset()
+ for batch in dev_data_loader:
+ input_ids, segment_ids, labels = batch
+ logits = model(input_ids, segment_ids)
+ correct = metric.compute(logits, labels)
+ metric.update(correct)
+ res = metric.accumulate()
+ print("acc: %s\n, " % (res), end='')
+
+
+def do_train(args):
+ paddle.set_device(args.device)
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ set_seed(args)
+
+ args.task_name = args.task_name.lower()
+ metric_class = METRIC_CLASSES[args.task_name]
+ args.model_type = args.model_type.lower()
+ model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+
+ train_ds = load_dataset('clue', args.task_name, splits='train')
+
+ tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+
+ trans_func = partial(
+ convert_example,
+ tokenizer=tokenizer,
+ label_list=train_ds.label_list,
+ max_seq_length=args.max_seq_length)
+ train_ds = train_ds.map(trans_func, lazy=True)
+ train_batch_sampler = paddle.io.DistributedBatchSampler(
+ train_ds, batch_size=args.batch_size, shuffle=True)
+ batchify_fn = lambda samples, fn=Tuple(
+ Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
+ Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment
+ Stack(dtype="int64" if train_ds.label_list else "float32") # label
+ ): fn(samples)
+ train_data_loader = DataLoader(
+ dataset=train_ds,
+ batch_sampler=train_batch_sampler,
+ collate_fn=batchify_fn,
+ num_workers=0,
+ return_list=True)
+
+ dev_ds = load_dataset('clue', args.task_name, splits='dev')
+
+ dev_ds = dev_ds.map(trans_func, lazy=True)
+ dev_batch_sampler = paddle.io.BatchSampler(
+ dev_ds, batch_size=args.batch_size, shuffle=False)
+ dev_data_loader = DataLoader(
+ dataset=dev_ds,
+ batch_sampler=dev_batch_sampler,
+ collate_fn=batchify_fn,
+ num_workers=0,
+ return_list=True)
+
+    num_classes = 1 if train_ds.label_list is None else len(train_ds.label_list)
+ model = model_class.from_pretrained(
+ args.model_name_or_path, num_classes=num_classes)
+ if paddle.distributed.get_world_size() > 1:
+ model = paddle.DataParallel(model)
+
+ if args.max_steps > 0:
+ num_training_steps = args.max_steps
+ num_train_epochs = math.ceil(num_training_steps /
+ len(train_data_loader))
+ else:
+ num_training_steps = len(train_data_loader) * args.num_train_epochs
+ num_train_epochs = args.num_train_epochs
+
+ warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
+
+ lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
+ warmup)
+
+ # Generate parameter names needed to perform weight decay.
+ # All bias and LayerNorm parameters are excluded.
+ decay_params = [
+ p.name for n, p in model.named_parameters()
+ if not any(nd in n for nd in ["bias", "norm"])
+ ]
+ optimizer = paddle.optimizer.AdamW(
+ learning_rate=lr_scheduler,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=args.adam_epsilon,
+ parameters=model.parameters(),
+ weight_decay=args.weight_decay,
+ apply_decay_param_fun=lambda x: x in decay_params,
+ grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm))
+
+ loss_fct = paddle.nn.loss.CrossEntropyLoss(
+ ) if train_ds.label_list else paddle.nn.loss.MSELoss()
+
+ metric = metric_class()
+ best_acc = 0.0
+ global_step = 0
+ tic_train = time.time()
+ for epoch in range(num_train_epochs):
+ for step, batch in enumerate(train_data_loader):
+ global_step += 1
+ input_ids, segment_ids, labels = batch
+ logits = model(input_ids, segment_ids)
+ loss = loss_fct(logits, labels)
+ loss.backward()
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.clear_grad()
+ if global_step % args.logging_steps == 0:
+ print(
+ "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
+ % (global_step, num_training_steps, epoch, step,
+ paddle.distributed.get_rank(), loss, optimizer.get_lr(),
+ args.logging_steps / (time.time() - tic_train)))
+ tic_train = time.time()
+ if global_step % args.save_steps == 0 or global_step == num_training_steps:
+ tic_eval = time.time()
+ acc = evaluate(model, loss_fct, metric, dev_data_loader)
+ print("eval done total : %s s" % (time.time() - tic_eval))
+ if acc > best_acc:
+ best_acc = acc
+ output_dir = args.output_dir
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ #Need better way to get inner model of DataParallel
+ model_to_save = model._layers if isinstance(
+ model, paddle.DataParallel) else model
+ model_to_save.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ if global_step >= num_training_steps:
+ print("best_acc: ", best_acc)
+ return
+ print("best_acc: ", best_acc)
+
+
+def print_arguments(args):
+ """print arguments"""
+ print('----------- Configuration Arguments -----------')
+ for arg, value in sorted(vars(args).items()):
+ print('%s: %s' % (arg, value))
+ print('------------------------------------------------')
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ print_arguments(args)
+ if args.do_train:
+ do_train(args)
+ else:
+ do_eval(args)
diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_clue.sh b/examples/model_compression/PP-MiniLM/finetuning/run_clue.sh
new file mode 100644
index 000000000000..ad74187f5a4b
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/finetuning/run_clue.sh
@@ -0,0 +1,25 @@
+
+export TASK_NAME=$1
+export LR=$2
+export BS=$3
+export EPOCH=$4
+export MAX_SEQ_LEN=$5
+export CUDA_VISIBLE_DEVICES=$6
+export MODEL_PATH=$7
+
+python -u ./run_clue.py \
+ --model_type ernie \
+ --model_name_or_path ${MODEL_PATH} \
+ --task_name ${TASK_NAME} \
+ --max_seq_length ${MAX_SEQ_LEN} \
+ --batch_size ${BS} \
+ --learning_rate ${LR} \
+ --num_train_epochs ${EPOCH} \
+ --logging_steps 100 \
+ --seed 42 \
+ --save_steps 100 \
+ --warmup_proportion 0.1 \
+ --weight_decay 0.01 \
+ --adam_epsilon 1e-8 \
+ --output_dir ${MODEL_PATH}/models/${TASK_NAME}/${LR}_${BS}/ \
+ --device gpu \
diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_one_search.sh b/examples/model_compression/PP-MiniLM/finetuning/run_one_search.sh
new file mode 100644
index 000000000000..ef0871d71302
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/finetuning/run_one_search.sh
@@ -0,0 +1,54 @@
+OUTPUT_DIR=$1
+TASK_NAME=$2
+
+mkdir ${OUTPUT_DIR}/afqmc
+mkdir ${OUTPUT_DIR}/tnews
+mkdir ${OUTPUT_DIR}/ifly
+mkdir ${OUTPUT_DIR}/ocnli
+mkdir ${OUTPUT_DIR}/wsc
+mkdir ${OUTPUT_DIR}/csl
+mkdir ${OUTPUT_DIR}/cmnli
+
+
+for lr in 1e-4 5e-5 3e-5
+do
+ for bs in 16 32 64
+ do
+ echo bs: $bs, lr: $lr
+
+ if [ $TASK_NAME == afqmc ]
+ then
+ sh run_clue.sh AFQMC $lr $bs 3 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/afqmc/${lr}_${bs}_3_128.log
+ fi
+
+ if [ $TASK_NAME == tnews ]
+ then
+ sh run_clue.sh TNEWS $lr $bs 3 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/tnews/${lr}_${bs}_3_128.log
+ fi
+
+ if [ $TASK_NAME == ifly ]
+ then
+ sh run_clue.sh IFLYTEK $lr $bs 6 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/ifly/${lr}_${bs}_6_128.log
+ fi
+
+ if [ $TASK_NAME == ocnli ]
+ then
+ sh run_clue.sh OCNLI $lr $bs 6 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/ocnli/${lr}_${bs}_6_128.log
+ fi
+
+ if [ $TASK_NAME == wsc ]
+ then
+ sh run_clue.sh CLUEWSC2020 $lr $bs 50 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/wsc/${lr}_${bs}_50_128.log
+ fi
+
+ if [ $TASK_NAME == csl ]
+ then
+ sh run_clue.sh CSL $lr $bs 8 256 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/csl/${lr}_${bs}_8_256.log
+ fi
+
+ if [ $TASK_NAME == cmnli ]
+ then
+ sh run_clue.sh CMNLI $lr $bs 3 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/cmnli/${lr}_${bs}_3_128.log
+ fi
+    done
+done
+
diff --git a/examples/model_compression/PP-MiniLM/general_distill/general_distill.py b/examples/model_compression/PP-MiniLM/general_distill/general_distill.py
new file mode 100644
index 000000000000..df26030eeb02
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/general_distill/general_distill.py
@@ -0,0 +1,498 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import random
+import time
+from functools import partial
+from concurrent.futures import ThreadPoolExecutor
+import distutils.util
+import math
+
+import numpy as np
+import paddle
+from paddle.io import DataLoader
+import paddle.nn.functional as F
+from paddle import tensor
+
+from paddlenlp.utils.log import logger
+from paddlenlp.data import Tuple, Pad
+from paddlenlp.utils.tools import TimeCostAverage
+from paddlenlp.transformers import LinearDecayWithWarmup
+from paddlenlp.transformers import RobertaModel, RobertaTokenizer
+from paddlenlp.transformers import ErnieModel, ErnieForSequenceClassification, ErnieTokenizer
+from paddlenlp.transformers.distill_utils import to_distill, calc_minilm_loss_multi_relation
+
+MODEL_CLASSES = {
+ "roberta": (RobertaModel, RobertaTokenizer),
+ "ernie": (ErnieForSequenceClassification, ErnieTokenizer)
+}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ # Required parameters
+ parser.add_argument(
+ "--model_type",
+ default="ernie",
+ type=str,
+ required=True,
+ help="Model type selected in the list: " +
+ ", ".join(MODEL_CLASSES.keys()), )
+ parser.add_argument(
+ "--teacher_model_type",
+ default="ernie",
+ type=str,
+ required=True,
+ help="Model type selected in the list: " +
+ ", ".join(MODEL_CLASSES.keys()), )
+ parser.add_argument(
+ "--student_model_name_or_path",
+ default=None,
+ type=str,
+ required=False,
+ help="Path to pre-trained model or shortcut name selected in the list: "
+ + ", ".join(
+ sum([
+ list(classes[-1].pretrained_init_configuration.keys())
+ for classes in MODEL_CLASSES.values()
+ ], [])), )
+ parser.add_argument(
+ "--teacher_model_name_or_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to pre-trained model.")
+ parser.add_argument(
+ "--input_dir",
+ default=None,
+ type=str,
+ required=True,
+ help="The input directory where the data will be read from.", )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ required=True,
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+
+ parser.add_argument(
+ "--max_seq_length",
+ default=128,
+ type=int,
+ help="The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded.", )
+ parser.add_argument(
+ "--learning_rate",
+ default=6e-4,
+ type=float,
+ help="The initial learning rate for Adam.")
+ parser.add_argument(
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument(
+ "--num_layers",
+ default=6,
+ type=int,
+ help="Number layers of student model.", )
+ parser.add_argument(
+ "--teacher_layer_index",
+ default=19,
+ type=int,
+ help="Total number of training epochs to perform.", )
+ parser.add_argument(
+ "--student_layer_index",
+ default=5,
+ type=int,
+ help="Total number of training epochs to perform.", )
+ parser.add_argument(
+ "--num_train_epochs",
+ default=3,
+ type=int,
+ help="Total number of training epochs to perform.", )
+ parser.add_argument(
+ "--logging_steps",
+ type=int,
+ default=100,
+ help="Log every X updates steps.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=100,
+ help="Save checkpoint every X updates steps.")
+ parser.add_argument(
+ "--batch_size",
+ default=512,
+ type=int,
+ help="Batch size per GPU/CPU for training.", )
+ parser.add_argument(
+ "--num_relation_heads",
+ default=64,
+ type=int,
+ help="The number of relation heads is 48 and 64 for base and large-size teacher model.",
+ )
+ parser.add_argument("--beta", default=0.0, type=float, help="0.0 usually")
+ parser.add_argument("--alpha", default=0.0, type=float, help="0.0 usually")
+ parser.add_argument(
+ "--weight_decay",
+ default=0.01,
+ type=float,
+ help="Weight decay if we apply some.")
+ parser.add_argument(
+ "--warmup_steps",
+ default=-1,
+ type=int,
+ help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion"
+ )
+ parser.add_argument(
+ "--warmup_proportion",
+ default=0.01,
+ type=float,
+ help="Linear warmup proportion over total steps.")
+ parser.add_argument(
+ "--adam_epsilon",
+ default=1e-8,
+ type=float,
+ help="Epsilon for Adam optimizer.")
+ parser.add_argument(
+ "--max_steps",
+ default=400000,
+ type=int,
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
+ )
+ parser.add_argument(
+ "--seed", default=42, type=int, help="random seed for initialization")
+ parser.add_argument(
+ "--device",
+ default="gpu",
+ type=str,
+ help="The device to select to train the model, is must be cpu/gpu/xpu.")
+ parser.add_argument(
+ "--use_amp",
+ type=distutils.util.strtobool,
+ default=False,
+ help="Enable mixed precision training.")
+ parser.add_argument(
+ "--scale_loss",
+ type=float,
+ default=2**15,
+ help="The value of scale_loss for fp16.")
+ args = parser.parse_args()
+ return args
+
+
+def set_seed(args):
+ random.seed(args.seed + paddle.distributed.get_rank())
+ np.random.seed(args.seed + paddle.distributed.get_rank())
+ paddle.seed(args.seed + paddle.distributed.get_rank())
+
+
+class WorkerInitObj(object):
+ def __init__(self, seed):
+ self.seed = seed
+
+ def __call__(self, id):
+ np.random.seed(seed=self.seed + id)
+ random.seed(self.seed + id)
+
+
+def create_pretraining_dataset(input_file, shared_list, args, worker_init,
+ tokenizer):
+ train_data = PretrainingDataset(
+ input_file=input_file,
+ tokenizer=tokenizer,
+ max_seq_length=args.max_seq_length)
+ # files have been sharded, no need to dispatch again
+ train_batch_sampler = paddle.io.BatchSampler(
+ train_data, batch_size=args.batch_size, shuffle=True)
+
+ # DataLoader cannot be pickled because of its place.
+ # If it can be pickled, use global function instead of lambda and use
+ # ProcessPoolExecutor instead of ThreadPoolExecutor to prefetch.
+ batchify_fn = lambda samples, fn=Tuple(
+ Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
+ ): fn(samples)
+
+ train_data_loader = DataLoader(
+ dataset=train_data,
+ batch_sampler=train_batch_sampler,
+ collate_fn=batchify_fn,
+ num_workers=0,
+ worker_init_fn=worker_init,
+ return_list=True)
+ return train_data_loader, input_file
+
+
+class PretrainingDataset(paddle.io.Dataset):
+ def __init__(self, input_file, tokenizer, max_seq_length):
+ self.input_file = input_file
+ f = open(input_file, 'r')
+ input_ids = []
+ for i, line in enumerate(f):
+ line = line[:max_seq_length]
+ tokenized_example = tokenizer(line, max_seq_len=max_seq_length)
+ input_ids.append(tokenized_example['input_ids'])
+ self.inputs = np.asarray(input_ids)
+ f.close()
+
+ def __len__(self):
+ 'Denotes the total number of samples'
+ return len(self.inputs)
+
+ def __getitem__(self, index):
+ input_ids = [np.asarray(self.inputs[index])]
+ return input_ids
+
+
+def do_train(args):
+ paddle.set_device(args.device)
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ set_seed(args)
+
+ worker_init = WorkerInitObj(args.seed + paddle.distributed.get_rank())
+ args.model_type = args.model_type.lower()
+
+ # For teacher
+ teacher_model_class, tokenizer_class = MODEL_CLASSES[
+ args.teacher_model_type]
+ tokenizer = tokenizer_class.from_pretrained(args.teacher_model_name_or_path)
+
+ # For student
+ model_class, _ = MODEL_CLASSES[args.model_type]
+ if args.num_layers == 6:
+ ernie = ErnieModel(
+ vocab_size=tokenizer.vocab_size,
+ num_hidden_layers=6,
+ hidden_act='relu',
+ intermediate_size=3072,
+ hidden_size=768) # layer: 6
+ elif args.num_layers == 4:
+ ernie = ErnieModel(
+ vocab_size=tokenizer.vocab_size,
+ num_hidden_layers=4,
+ hidden_act='relu',
+ intermediate_size=1024,
+ hidden_size=256,
+ num_attention_heads=16) # layer: 4
+ else:
+ ernie = ErnieModel(
+ vocab_size=tokenizer.vocab_size,
+ num_hidden_layers=2,
+ hidden_act='relu',
+ hidden_size=128,
+ intermediate_size=512) # layer: 2
+ student = model_class(ernie)
+
+ teacher = teacher_model_class.from_pretrained(
+ args.teacher_model_name_or_path)
+ pad_token_id = 0
+
+ if paddle.distributed.get_world_size() > 1:
+ student = paddle.DataParallel(student, find_unused_parameters=True)
+ teacher = paddle.DataParallel(teacher, find_unused_parameters=True)
+
+ num_training_steps = args.max_steps
+
+ warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
+ lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
+ warmup)
+
+ # Generate parameter names needed to perform weight decay.
+ # All bias and LayerNorm parameters are excluded.
+ decay_params = [
+ p.name for n, p in student.named_parameters()
+ if not any(nd in n for nd in ["bias", "norm"])
+ ]
+
+ optimizer = paddle.optimizer.AdamW(
+ learning_rate=lr_scheduler,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=args.adam_epsilon,
+ parameters=student.parameters(),
+ weight_decay=args.weight_decay,
+ apply_decay_param_fun=lambda x: x in decay_params,
+ grad_clip=paddle.nn.ClipGradByGlobalNorm(args.max_grad_norm))
+
+ if args.use_amp:
+ scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
+ pool = ThreadPoolExecutor(1)
+
+ teacher = to_distill(
+ teacher, return_qkv=True, layer_index=args.teacher_layer_index)
+ student = to_distill(
+ student, return_qkv=True, layer_index=args.student_layer_index)
+
+ global_step = 0
+ tic_train = time.time()
+ for epoch in range(args.num_train_epochs):
+ files = [
+ os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
+ if os.path.isfile(os.path.join(args.input_dir, f))
+ ]
+ files.sort()
+ num_files = len(files)
+ random.Random(args.seed + epoch).shuffle(files)
+ f_start_id = 0
+
+ shared_file_list = {}
+
+ if paddle.distributed.get_world_size() > num_files:
+ remainder = paddle.distributed.get_world_size() % num_files
+
+ data_file = files[(
+ f_start_id * paddle.distributed.get_world_size() +
+ paddle.distributed.get_rank() + remainder * f_start_id) %
+ num_files]
+ else:
+ data_file = files[(f_start_id * paddle.distributed.get_world_size()
+ + paddle.distributed.get_rank()) % num_files]
+
+ previous_file = data_file
+
+ train_data_loader, _ = create_pretraining_dataset(
+ data_file, shared_file_list, args, worker_init, tokenizer)
+
+ # TODO(guosheng): better way to process single file
+ single_file = True if f_start_id + 1 == len(files) else False
+
+ for f_id in range(f_start_id, len(files)):
+ if not single_file and f_id == f_start_id:
+ continue
+ if paddle.distributed.get_world_size() > num_files:
+ data_file = files[(
+ f_id * paddle.distributed.get_world_size() +
+ paddle.distributed.get_rank() + remainder * f_id) %
+ num_files]
+ else:
+ data_file = files[(f_id * paddle.distributed.get_world_size() +
+ paddle.distributed.get_rank()) % num_files]
+ previous_file = data_file
+ dataset_future = pool.submit(create_pretraining_dataset, data_file,
+ shared_file_list, args, worker_init,
+ tokenizer)
+
+ kl_loss_fct = paddle.nn.KLDivLoss('sum')
+ train_cost_avg = TimeCostAverage()
+ total_samples = 0
+ batch_start = time.time()
+ for step, batch in enumerate(train_data_loader):
+ global_step += 1
+ input_ids = batch[0]
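+                # Additive attention mask: pad positions are set to -1e9 so they
+                # are ignored when the Q/K/V relation matrices are normalized.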
+ attention_mask = paddle.unsqueeze(
+ (input_ids == pad_token_id
+ ).astype(paddle.get_default_dtype()) * -1e9,
+ axis=[1, 2])
+ with paddle.amp.auto_cast(
+ args.use_amp,
+ custom_white_list=["layer_norm", "gelu", "softmax"]):
+ student(input_ids)
+ with paddle.no_grad():
+ teacher(input_ids)
+ # Q-Q relation
+ q_t, q_s = teacher.outputs.q, student.outputs.q
+ batch_size = q_t.shape[0]
+ pad_seq_len = q_t.shape[2]
+ loss_qr1, loss_qr2, loss_qr3 = calc_minilm_loss_multi_relation(
+ kl_loss_fct, q_s, q_t, attention_mask,
+ args.num_relation_heads, args.alpha, args.beta)
+ del q_t, q_s
+ # K-K relation
+ k_t, k_s = teacher.outputs.k, student.outputs.k
+ loss_kr1, loss_kr2, loss_kr3 = calc_minilm_loss_multi_relation(
+ kl_loss_fct, k_s, k_t, attention_mask,
+ args.num_relation_heads, args.alpha, args.beta)
+ del k_t, k_s
+
+ # V-V relation
+ v_t, v_s = teacher.outputs.v, student.outputs.v
+ loss_vr1, loss_vr2, loss_vr3 = calc_minilm_loss_multi_relation(
+ kl_loss_fct, v_s, v_t, attention_mask,
+ args.num_relation_heads, args.alpha, args.beta)
+
+ del v_t, v_s
+
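+                # Sum the Q-Q, K-K and V-V relation losses, normalize by the number
+                # of relation heads, padded length and batch size, then mix the three
+                # relation variants with weights (1 - alpha - beta), alpha and beta.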
+ loss1 = (loss_qr1 + loss_kr1 + loss_vr1)
+ loss1 /= args.num_relation_heads * pad_seq_len * batch_size
+
+ loss2 = loss_qr2 + loss_kr2 + loss_vr2
+ loss2 /= args.num_relation_heads * pad_seq_len * batch_size
+
+ loss3 = loss_qr3 + loss_kr3 + loss_vr3
+ loss3 /= args.num_relation_heads * pad_seq_len * batch_size
+ loss = (1 - args.alpha - args.beta
+ ) * loss1 + loss2 * args.alpha + loss3 * args.beta
+
+ if args.use_amp:
+ scaler.scale(loss).backward()
+ scaler.minimize(optimizer, loss)
+ else:
+ loss.backward()
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.clear_grad()
+
+ total_samples += args.batch_size
+ train_run_cost = time.time() - batch_start
+ train_cost_avg.record(train_run_cost)
+ if global_step % args.logging_steps == 0:
+ logger.info(
+ "global step: %d, epoch: %d, batch: %d, loss: %f, loss1: %f, loss2: %f, loss3: %f,"
+ "lr: %f, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
+ % (global_step, epoch, step, loss, loss1, loss2, loss3,
+ optimizer.get_lr(), train_cost_avg.get_average(),
+ total_samples / args.logging_steps, total_samples /
+ (args.logging_steps * train_cost_avg.get_average())))
+ total_samples = 0
+ train_cost_avg.reset()
+ if global_step % args.save_steps == 0 or global_step == num_training_steps:
+ if paddle.distributed.get_rank() == 0:
+ output_dir = os.path.join(args.output_dir,
+ "model_%d" % global_step)
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ # need better way to get inner model of DataParallel
+ model_to_save = student._layers if isinstance(
+ student, paddle.DataParallel) else student
+ model_to_save.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ paddle.save(
+ optimizer.state_dict(),
+ os.path.join(output_dir, "model_state.pdopt"))
+ if global_step >= args.max_steps:
+ del train_data_loader
+ return
+ batch_start = time.time()
+
+ del train_data_loader
+ train_data_loader, data_file = dataset_future.result(timeout=None)
+
+
+def print_arguments(args):
+ """print arguments"""
+ print('----------- Configuration Arguments -----------')
+ for arg, value in sorted(vars(args).items()):
+ print('%s: %s' % (arg, value))
+ print('------------------------------------------------')
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ print_arguments(args)
+ do_train(args)
diff --git a/examples/model_compression/PP-MiniLM/general_distill/run.sh b/examples/model_compression/PP-MiniLM/general_distill/run.sh
new file mode 100644
index 000000000000..12ae63f7971b
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/general_distill/run.sh
@@ -0,0 +1,55 @@
+set -eux
+
+unset CUDA_VISIBLE_DEVICES
+
+bs=128
+maxlen=128
+numH=64
+lr=6e-4
+maxStep=400000
+warmStep=4000
+wd=1e-2
+
+teacher=roberta
+teacherModel=roberta-wwm-ext-large
+
+alpha=0
+beta=1.0
+mode=hardest
+use_amp=True
+teacher_layer_index=19
+student_layer_index=5
+num_layers=6
+
+hp_config=bs${bs}_maxlen${maxlen}_lr${lr}_wd${wd}_numH${numH}_maxStep${maxStep}_warmStep${warmStep}_adamW_maxnorm1p0_teacher_${teacherModel}_coldboot_teacher_vocab_index${teacher_layer_index}_4l-312d-batchbatch
+
+export PYTHONPATH=../../../../:$PYTHONPATH
+output_dir="./pretrain_${hp_config}"
+
+mkdir -p ${output_dir}
+cp ./general_distill.py ${output_dir}/
+cp ../../../../paddlenlp/transformers/distill_utils.py ${output_dir}/
+
+
+python3 -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" general_distill.py \
+ --model_type ernie \
+ --num_relation_heads ${numH} \
+ --teacher_model_type ${teacher} \
+ --teacher_layer_index ${teacher_layer_index} \
+ --student_layer_index ${student_layer_index} \
+ --teacher_model_name_or_path ${teacherModel} \
+ --max_seq_length ${maxlen} \
+ --num_layers ${num_layers} \
+ --batch_size ${bs} \
+ --learning_rate ${lr} \
+ --logging_steps 20 \
+ --max_steps ${maxStep} \
+ --warmup_steps ${warmStep} \
+ --save_steps 20000 \
+ --weight_decay ${wd} \
+ --output_dir ${output_dir} \
+ --device gpu \
+ --input_dir dataset/ \
+ --use_amp ${use_amp} \
+ --alpha ${alpha} \
+ --beta ${beta} \
diff --git a/examples/model_compression/PP-MiniLM/inference/infer.py b/examples/model_compression/PP-MiniLM/inference/infer.py
new file mode 100644
index 000000000000..9cc8422119a5
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/inference/infer.py
@@ -0,0 +1,311 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import time
+from functools import partial
+import numpy as np
+
+import paddle
+from paddle import inference
+from paddlenlp.datasets import load_dataset
+from paddlenlp.data import Stack, Tuple, Pad
+from paddle.metric import Metric, Accuracy, Precision, Recall
+from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
+
+from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
+
+METRIC_CLASSES = {
+ "afqmc": Accuracy,
+ "tnews": Accuracy,
+ "iflytek": Accuracy,
+ "ocnli": Accuracy,
+ "cmnli": Accuracy,
+ "cluewsc2020": Accuracy,
+ "csl": Accuracy,
+}
+
+MODEL_CLASSES = {"ernie": (ErnieForSequenceClassification, ErnieTokenizer), }
+
+
+def convert_example(example,
+ tokenizer,
+ label_list,
+ max_seq_length=512,
+ is_test=False):
+ """convert a glue example into necessary features"""
+ if not is_test:
+ # `label_list == None` is for regression task
+ label_dtype = "int64" if label_list else "float32"
+ # Get the label
+ label = example['label']
+ label = np.array([label], dtype=label_dtype)
+ # Convert raw text to feature
+ if 'sentence' in example:
+ example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
+ elif 'sentence1' in example:
+ example = tokenizer(
+ example['sentence1'],
+ text_pair=example['sentence2'],
+ max_seq_len=max_seq_length)
+ elif 'keyword' in example: # CSL
+ sentence1 = " ".join(example['keyword'])
+ example = tokenizer(
+ sentence1, text_pair=example['abst'], max_seq_len=max_seq_length)
+ elif 'target' in example: # wsc
+ text, query, pronoun, query_idx, pronoun_idx = example['text'], example[
+ 'target']['span1_text'], example['target']['span2_text'], example[
+ 'target']['span1_index'], example['target']['span2_index']
+ text_list = list(text)
+ assert text[pronoun_idx:(pronoun_idx + len(pronoun)
+ )] == pronoun, "pronoun: {}".format(pronoun)
+ assert text[query_idx:(query_idx + len(query)
+ )] == query, "query: {}".format(query)
+ if pronoun_idx > query_idx:
+ text_list.insert(query_idx, "_")
+ text_list.insert(query_idx + len(query) + 1, "_")
+ text_list.insert(pronoun_idx + 2, "[")
+ text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
+ else:
+ text_list.insert(pronoun_idx, "[")
+ text_list.insert(pronoun_idx + len(pronoun) + 1, "]")
+ text_list.insert(query_idx + 2, "_")
+ text_list.insert(query_idx + len(query) + 2 + 1, "_")
+ text = "".join(text_list)
+ example = tokenizer(text, max_seq_len=max_seq_length)
+
+ if not is_test:
+ return example['input_ids'], example['token_type_ids'], label
+ else:
+ return example['input_ids'], example['token_type_ids']
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ # Required parameters
+ parser.add_argument(
+ "--task_name",
+ default='afqmc',
+ type=str,
+ help="The name of the task to perform predict, selected in the list: " +
+ ", ".join(METRIC_CLASSES.keys()), )
+ parser.add_argument(
+ "--model_type",
+ default='ernie',
+ type=str,
+ help="Model type selected in the list: " +
+ ", ".join(MODEL_CLASSES.keys()), )
+ parser.add_argument(
+ "--tokenizer_path",
+ default='../general_distill/ernie-batchbatch-50w_400000/',
+ type=str,
+ required=True,
+ help="The directory for tokenizer.", )
+ parser.add_argument(
+ "--model_path",
+ default='./quant_models/model',
+ type=str,
+ required=True,
+ help="The path prefix of inference model to be used.", )
+ parser.add_argument(
+ "--device",
+ default="gpu",
+ choices=["gpu", "cpu", "xpu"],
+ help="Device selected for inference.", )
+ parser.add_argument(
+ "--batch_size",
+ default=32,
+ type=int,
+ help="Batch size for predict.", )
+ parser.add_argument(
+ "--max_seq_length",
+ default=128,
+ type=int,
+ help="The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded.", )
+ parser.add_argument(
+ "--use_trt",
+ action='store_true',
+ help="Whether to use inference engin TensorRT.", )
+
+ parser.add_argument(
+ "--collect_shape",
+ action='store_true',
+ help="Whether collect shape range info.", )
+ parser.add_argument(
+ "--int8",
+ action='store_true',
+ help="Whether int8 inference.", )
+ args = parser.parse_args()
+ return args
+
+
+@paddle.no_grad()
+def evaluate(outputs, metric, data_loader):
+ metric.reset()
+ for i, batch in enumerate(data_loader):
+ input_ids, segment_ids, labels = batch
+ logits = paddle.to_tensor(outputs[i][0])
+ correct = metric.compute(logits, labels)
+ metric.update(correct)
+ res = metric.accumulate()
+ print("acc: %s, " % res, end='')
+
+
+class Predictor(object):
+ def __init__(self, predictor, input_handles, output_handles):
+ self.predictor = predictor
+ self.input_handles = input_handles
+ self.output_handles = output_handles
+
+ @classmethod
+ def create_predictor(cls, args):
+ config = paddle.inference.Config(args.model_path + ".pdmodel",
+ args.model_path + ".pdiparams")
+ if args.device == "gpu":
+ # set GPU configs accordingly
+ config.enable_use_gpu(100, 0)
+ cls.device = paddle.set_device("gpu")
+ elif args.device == "cpu":
+ # set CPU configs accordingly,
+ # such as enable_mkldnn, set_cpu_math_library_num_threads
+ config.disable_gpu()
+ cls.device = paddle.set_device("cpu")
+ elif args.device == "xpu":
+ # set XPU configs accordingly
+ config.enable_xpu(100)
+ config.switch_use_feed_fetch_ops(False) # could be deleted
+ if args.use_trt:
+ if args.int8:
+ print("int8")
+ config.enable_tensorrt_engine(
+ workspace_size=1 << 30,
+ precision_mode=inference.PrecisionType.Int8,
+ max_batch_size=args.batch_size,
+ min_subgraph_size=5,
+ use_static=False,
+ use_calib_mode=False)
+ else:
+ config.enable_tensorrt_engine(
+ workspace_size=1 << 30,
+ precision_mode=inference.PrecisionType.Float32,
+ max_batch_size=args.batch_size,
+ min_subgraph_size=5,
+ use_static=False,
+ use_calib_mode=False)
+ print("Enable TensorRT is: {}".format(
+ config.tensorrt_engine_enabled()))
+ if args.collect_shape:
+ config.collect_shape_range_info(
+ os.path.join(
+ os.path.dirname(args.model_path), args.task_name +
+ '_shape_range_info.pbtxt'))
+ else:
+ config.enable_tuned_tensorrt_dynamic_shape(
+ os.path.join(
+ os.path.dirname(args.model_path),
+ args.task_name + "_shape_range_info.pbtxt"), True)
+ predictor = paddle.inference.create_predictor(config)
+ input_handles = [
+ predictor.get_input_handle(name)
+ for name in predictor.get_input_names()
+ ]
+ output_handles = [
+ predictor.get_output_handle(name)
+ for name in predictor.get_output_names()
+ ]
+ cls.time = 0.0
+
+ return cls(predictor, input_handles, output_handles)
+
+ def predict_batch(self, data):
+ for input_field, input_handle in zip(data, self.input_handles):
+ input_handle.copy_from_cpu(input_field.numpy() if isinstance(
+ input_field, paddle.Tensor) else input_field)
+ time1 = time.time()
+ self.predictor.run()
+ paddle.fluid.core._cuda_synchronize(self.device)
+ self.time += time.time() - time1
+ output = [
+ output_handle.copy_to_cpu() for output_handle in self.output_handles
+ ]
+
+ return output
+
+ def predict(self, dataset, collate_fn, args, batch_size=1):
+ metric = METRIC_CLASSES[args.task_name]()
+ batch_sampler = paddle.io.BatchSampler(
+ dataset, batch_size=batch_size, shuffle=False)
+ data_loader = paddle.io.DataLoader(
+ dataset=dataset,
+ batch_sampler=batch_sampler,
+ collate_fn=collate_fn,
+ num_workers=0,
+ return_list=True)
+ outputs = []
+ metric.reset()
+ for i, data in enumerate(data_loader):
+ # warmup for performance test
+ if i < 20:
+ continue
+ if len(data) == 2:
+ output = self.predict_batch(data)
+ else:
+ output = self.predict_batch([data[0], data[1]])
+ logits = paddle.to_tensor(output)
+ correct = metric.compute(logits, data[2])
+ metric.update(correct)
+ outputs.append(output)
+ if len(data) > 2:
+ res = metric.accumulate()
+ print("task name: %s, acc: %s, " % (args.task_name, res), end='')
+ print("time: ", self.time)
+
+ return outputs
+
+
+def main():
+ paddle.seed(42)
+ args = parse_args()
+
+ args.task_name = args.task_name.lower()
+ args.model_type = args.model_type.lower()
+
+ predictor = Predictor.create_predictor(args)
+
+ model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+
+ dev_ds = load_dataset('clue', args.task_name, splits='dev')
+ tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path)
+ trans_func = partial(
+ convert_example,
+ tokenizer=tokenizer,
+ label_list=dev_ds.label_list,
+ max_seq_length=args.max_seq_length,
+ is_test=False)
+
+ dev_ds = dev_ds.map(trans_func, lazy=True)
+ batchify_fn = lambda samples, fn=Tuple(
+ Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
+ Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment
+ Stack(dtype="int64" if dev_ds.label_list else "float32") # label
+ ): fn(samples)
+ outputs = predictor.predict(
+ dev_ds, batch_size=args.batch_size, collate_fn=batchify_fn, args=args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/model_compression/PP-MiniLM/inference/infer.sh b/examples/model_compression/PP-MiniLM/inference/infer.sh
new file mode 100644
index 000000000000..b50f23e37fd7
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/inference/infer.sh
@@ -0,0 +1,26 @@
+echo "original fp32 model"
+python infer.py --task_name tnews --model_path tnews/float --use_trt --collect_shape
+python infer.py --task_name tnews --model_path tnews/float --use_trt
+python infer.py --task_name tnews --model_path tnews/float --use_trt
+python infer.py --task_name tnews --model_path tnews/float --use_trt
+python infer.py --task_name tnews --model_path tnews/float --use_trt
+python infer.py --task_name tnews --model_path tnews/float --use_trt
+
+
+echo "after pruning"
+python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt --collect_shape
+python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
+python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
+python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
+python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
+python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
+
+
+echo "INT8 inference"
+python infer.py --task_name tnews --model_path ../quantization/tnews_quant_models/mse4/int8 --int8 --use_trt --collect_shape
+python infer.py --task_name tnews --model_path ../quantization/tnews_quant_models/mse4/int8 --int8 --use_trt
+python infer.py --task_name tnews --model_path ../quantization/tnews_quant_models/mse4/int8 --int8 --use_trt
+python infer.py --task_name tnews --model_path ../quantization/tnews_quant_models/mse4/int8 --int8 --use_trt
+python infer.py --task_name tnews --model_path ../quantization/tnews_quant_models/mse4/int8 --int8 --use_trt
+python infer.py --task_name tnews --model_path ../quantization/tnews_quant_models/mse4/int8 --int8 --use_trt
+
diff --git a/examples/model_compression/PP-MiniLM/inference/infer_all.sh b/examples/model_compression/PP-MiniLM/inference/infer_all.sh
new file mode 100644
index 000000000000..ef1d9bfbfcfb
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/inference/infer_all.sh
@@ -0,0 +1,12 @@
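+# Evaluate every quantized model: all CLUE tasks, calibration batch sizes and quantization algorithms.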
+for task in afqmc tnews iflytek cmnli ocnli cluewsc2020 csl
+do
+ for bs in 4 8
+ do
+ for algo in abs_max avg hist mse
+ do
+ python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt --collect_shape
+ python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt
+ echo this is ${task}, ${algo}, ${bs}
+ done
+ done
+done
diff --git a/examples/model_compression/PP-MiniLM/ofa/export.sh b/examples/model_compression/PP-MiniLM/ofa/export.sh
new file mode 100644
index 000000000000..74b7c3002db4
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/ofa/export.sh
@@ -0,0 +1,7 @@
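+# Export the 0.75-width static sub-model for one task.
+# Usage: sh export.sh <MODEL_PATH> <TASK_NAME>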
+MODEL_PATH=$1
+TASK_NAME=$2
+python export_model.py --model_type ernie \
+ --model_name_or_path ${MODEL_PATH}/${TASK_NAME}/0.75/best_model \
+ --sub_model_output_dir ${MODEL_PATH}/${TASK_NAME}/0.75/sub/ \
+ --static_sub_model ${MODEL_PATH}/${TASK_NAME}/0.75/sub_static/float \
+ --n_gpu 1 --width_mult 0.75
diff --git a/examples/model_compression/PP-MiniLM/ofa/export_all.sh b/examples/model_compression/PP-MiniLM/ofa/export_all.sh
new file mode 100644
index 000000000000..8e295fca5c05
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/ofa/export_all.sh
@@ -0,0 +1,12 @@
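+# Export the 0.75-width static sub-models for all CLUE tasks.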
+MODEL_PATH=ofa_models
+
+for TASK_NAME in AFQMC TNEWS IFLYTEK CMNLI OCNLI CLUEWSC2020 CSL
+do
+ python export_model.py --model_type ernie \
+ --model_name_or_path ${MODEL_PATH}/${TASK_NAME}/0.75/best_model \
+ --sub_model_output_dir ${MODEL_PATH}/${TASK_NAME}/0.75/sub/ \
+ --static_sub_model ${MODEL_PATH}/${TASK_NAME}/0.75/sub_static/float \
+ --n_gpu 1 --width_mult 0.75
+
+done
diff --git a/examples/model_compression/PP-MiniLM/ofa/export_model.py b/examples/model_compression/PP-MiniLM/ofa/export_model.py
new file mode 100644
index 000000000000..c9edb13e6301
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/ofa/export_model.py
@@ -0,0 +1,226 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+import math
+import random
+import time
+import json
+from functools import partial
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer
+from paddlenlp.transformers import ErnieModel, ErnieForSequenceClassification, ErnieTokenizer
+from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer
+from paddlenlp.utils.log import logger
+from paddleslim.nas.ofa import OFA, utils
+from paddleslim.nas.ofa.convert_super import Convert, supernet
+from paddleslim.nas.ofa.layers import BaseBlock
+
+MODEL_CLASSES = {
+ "bert": (BertForSequenceClassification, BertTokenizer),
+ "ernie": (ErnieForSequenceClassification, ErnieTokenizer),
+}
+
+
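+# Monkey patch ErnieModel.forward to build the attention mask from pad tokens when
+# none is given; the `.fn` check makes it work for both the original model and the
+# OFA supernet, where the pooler weight is wrapped in `.fn`.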
+def ernie_forward(self,
+ input_ids,
+ token_type_ids=None,
+ position_ids=None,
+ attention_mask=None):
+ wtype = self.pooler.dense.fn.weight.dtype if hasattr(
+ self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
+ if attention_mask is None:
+ attention_mask = paddle.unsqueeze(
+ (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2])
+ embedding_output = self.embeddings(input_ids, token_type_ids, position_ids)
+ encoded_layer = self.encoder(embedding_output, attention_mask)
+ pooled_output = self.pooler(encoded_layer)
+
+ return encoded_layer, pooled_output
+
+
+ErnieModel.forward = ernie_forward
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ # Required parameters
+ parser.add_argument(
+ "--model_type",
+ default=None,
+ type=str,
+ required=True,
+ help="Model type selected in the list: " +
+ ", ".join(MODEL_CLASSES.keys()), )
+ parser.add_argument(
+ "--model_name_or_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to pre-trained model or shortcut name selected in the list: "
+ + ", ".join(
+ sum([
+ list(classes[-1].pretrained_init_configuration.keys())
+ for classes in MODEL_CLASSES.values()
+ ], [])), )
+ parser.add_argument(
+ "--sub_model_output_dir",
+ default=None,
+ type=str,
+ required=True,
+ help="The output directory where the sub model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--static_sub_model",
+ default=None,
+ type=str,
+ help="The output directory where the sub static model will be written. If set to None, not export static model",
+ )
+ parser.add_argument(
+ "--max_seq_length",
+ default=128,
+ type=int,
+ help="The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded.", )
+ parser.add_argument(
+ "--n_gpu",
+ type=int,
+ default=1,
+ help="number of gpus to use, 0 for cpu.")
+ parser.add_argument(
+ '--width_mult',
+ type=float,
+ default=1.0,
+ help="width mult you want to export")
+ parser.add_argument(
+ '--depth_mult',
+ type=float,
+ default=1.0,
+ help="depth mult you want to export")
+ args = parser.parse_args()
+ return args
+
+
+def export_static_model(model, model_path, max_seq_length):
+ input_shape = [
+ paddle.static.InputSpec(
+ shape=[None, max_seq_length], dtype='int64'),
+ paddle.static.InputSpec(
+ shape=[None, max_seq_length], dtype='int64')
+ ]
+ net = paddle.jit.to_static(model, input_spec=input_shape)
+ paddle.jit.save(net, model_path)
+
+
+def do_train(args):
+ paddle.set_device("gpu" if args.n_gpu else "cpu")
+ args.model_type = args.model_type.lower()
+ model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+ config_path = os.path.join(args.model_name_or_path, 'model_config.json')
+ cfg_dict = dict(json.loads(open(config_path).read()))
+
+ kept_layers_index = {}
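+    # When depth_mult < 1.0, shrink the number of hidden layers and record which
+    # original layer each kept layer maps to.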
+ if args.depth_mult < 1.0:
+ depth = round(cfg_dict["init_args"][0]['num_hidden_layers'] *
+ args.depth_mult)
+ cfg_dict["init_args"][0]['num_hidden_layers'] = depth
+ for idx, i in enumerate(range(1, depth + 1)):
+ kept_layers_index[idx] = math.floor(i / args.depth_mult) - 1
+
+ os.rename(config_path, config_path + '_bak')
+ with open(config_path, "w", encoding="utf-8") as f:
+ f.write(json.dumps(cfg_dict, ensure_ascii=False))
+
+ num_labels = cfg_dict['num_classes']
+
+ model = model_class.from_pretrained(
+ args.model_name_or_path, num_classes=num_labels)
+
+ origin_model = model_class.from_pretrained(
+ args.model_name_or_path, num_classes=num_labels)
+
+ os.rename(config_path + '_bak', config_path)
+
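+    # Convert the model into a width-adaptive supernet and wrap it with OFA so the
+    # sub-model specified by width_mult can be exported below.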
+ sp_config = supernet(expand_ratio=[1.0, args.width_mult])
+ model = Convert(sp_config).convert(model)
+
+ ofa_model = OFA(model)
+
+ sd = paddle.load(
+ os.path.join(args.model_name_or_path, 'model_state.pdparams'))
+
+ if len(kept_layers_index) == 0:
+ ofa_model.model.set_state_dict(sd)
+ else:
+ for name, params in ofa_model.model.named_parameters():
+ if 'encoder' not in name:
+ params.set_value(sd[name])
+ else:
+ idx = int(name.strip().split('.')[3])
+ mapping_name = name.replace(
+ '.' + str(idx) + '.',
+ '.' + str(kept_layers_index[idx]) + '.')
+ params.set_value(sd[mapping_name])
+
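+    # Build the sub-network config for the target width and shrink the number of
+    # attention heads of each MultiHeadAttention layer accordingly.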
+ best_config = utils.dynabert_config(ofa_model, args.width_mult)
+ for name, sublayer in ofa_model.model.named_sublayers():
+ if isinstance(sublayer, paddle.nn.MultiHeadAttention):
+ sublayer.num_heads = int(args.width_mult * sublayer.num_heads)
+
+ origin_model_new = ofa_model.export(
+ best_config,
+ input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],
+ input_dtypes=['int64', 'int64'],
+ origin_model=origin_model)
+ for name, sublayer in origin_model_new.named_sublayers():
+ if isinstance(sublayer, paddle.nn.MultiHeadAttention):
+ sublayer.num_heads = int(args.width_mult * sublayer.num_heads)
+
+ output_dir = os.path.join(args.sub_model_output_dir,
+ "model_width_%.5f" % args.width_mult)
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ model_to_save = origin_model_new
+ model_to_save.save_pretrained(output_dir)
+    if args.static_sub_model is not None:
+ export_static_model(origin_model_new, args.static_sub_model,
+ args.max_seq_length)
+
+
+def print_arguments(args):
+ """print arguments"""
+ print('----------- Configuration Arguments -----------')
+ for arg, value in sorted(vars(args).items()):
+ print('%s: %s' % (arg, value))
+ print('------------------------------------------------')
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ print_arguments(args)
+ do_train(args)
diff --git a/examples/model_compression/PP-MiniLM/ofa/run_ofa.py b/examples/model_compression/PP-MiniLM/ofa/run_ofa.py
new file mode 100644
index 000000000000..94a431d3bee7
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/ofa/run_ofa.py
@@ -0,0 +1,541 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+import random
+import time
+import math
+from functools import partial
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.io import DataLoader
+from paddle.metric import Accuracy, Precision, Recall
+
+from paddlenlp.data import Stack, Tuple, Pad, Dict
+from paddlenlp.datasets import load_dataset
+from paddlenlp.data.sampler import SamplerHelper
+from paddlenlp.transformers import LinearDecayWithWarmup
+from paddlenlp.utils.log import logger
+from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
+from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
+from paddlenlp.transformers import TinyBertForSequenceClassification, TinyBertTokenizer
+from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer, ErnieModel
+from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer
+from paddleslim.nas.ofa import OFA, DistillConfig, utils
+from paddleslim.nas.ofa.utils import nlp_utils
+from paddleslim.nas.ofa.convert_super import Convert, supernet
+
+METRIC_CLASSES = {
+ "afqmc": Accuracy,
+ "tnews": Accuracy,
+ "iflytek": Accuracy,
+ "ocnli": Accuracy,
+ "cmnli": Accuracy,
+ "cluewsc2020": Accuracy,
+ "csl": Accuracy,
+}
+
+MODEL_CLASSES = {
+ "bert": (BertForSequenceClassification, BertTokenizer),
+ "roberta": (RobertaForSequenceClassification, RobertaTokenizer),
+ "tinybert": (TinyBertForSequenceClassification, TinyBertTokenizer),
+ "ernie": (ErnieForSequenceClassification, ErnieTokenizer),
+}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ # Required parameters
+ parser.add_argument(
+ "--task_name",
+ default=None,
+ type=str,
+ required=True,
+ help="The name of the task to train selected in the list: " +
+ ", ".join(METRIC_CLASSES.keys()), )
+ parser.add_argument(
+ "--model_type",
+ default=None,
+ type=str,
+ required=True,
+ help="Model type selected in the list: " +
+ ", ".join(MODEL_CLASSES.keys()), )
+ parser.add_argument(
+ "--model_name_or_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to pre-trained model or shortcut name selected in the list: "
+ + ", ".join(
+ sum([
+ list(classes[-1].pretrained_init_configuration.keys())
+ for classes in MODEL_CLASSES.values()
+ ], [])), )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ required=True,
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--glue_dir",
+ default="/root/.paddlenlp/datasets/Clue/",
+ type=str,
+ required=False,
+ help="The Glue directory.", )
+ parser.add_argument(
+ "--max_seq_length",
+ default=128,
+ type=int,
+ help="The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded.", )
+ parser.add_argument(
+ "--batch_size",
+ default=8,
+ type=int,
+ help="Batch size per GPU/CPU for training.", )
+ parser.add_argument(
+ "--learning_rate",
+ default=5e-5,
+ type=float,
+ help="The initial learning rate for Adam.")
+ parser.add_argument(
+ "--weight_decay",
+ default=0.0,
+ type=float,
+ help="Weight decay if we apply some.")
+ parser.add_argument(
+ "--adam_epsilon",
+ default=1e-8,
+ type=float,
+ help="Epsilon for Adam optimizer.")
+ parser.add_argument(
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument(
+ "--lambda_logit",
+ default=1.0,
+ type=float,
+ help="lambda for logit loss.")
+ parser.add_argument(
+ "--num_train_epochs",
+ default=3,
+ type=int,
+ help="Total number of training epochs to perform.", )
+ parser.add_argument(
+ "--max_steps",
+ default=-1,
+ type=int,
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
+ )
+ parser.add_argument(
+ "--warmup_steps",
+ default=0,
+ type=int,
+ help="Linear warmup over warmup_steps.")
+ parser.add_argument(
+ "--warmup_proportion",
+ default=0.1,
+ type=float,
+ help="Linear warmup proportion over total steps.")
+ parser.add_argument(
+ "--logging_steps",
+ type=int,
+ default=500,
+ help="Log every X updates steps.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save checkpoint every X updates steps.")
+ parser.add_argument(
+ "--seed", type=int, default=42, help="random seed for initialization")
+ parser.add_argument(
+ "--device",
+ default="gpu",
+ type=str,
+ choices=["gpu", "cpu", "xpu"],
+ help="The device to select to train the model, is must be cpu/gpu/xpu.")
+ parser.add_argument(
+ '--width_mult_list',
+ nargs='+',
+ type=float,
+ default=[1.0, 5 / 6, 2 / 3, 0.5],
+ help="width mult in compress")
+ args = parser.parse_args()
+ return args
+
+
+def set_seed(args):
+ # Use the same data seed(for data shuffle) for all procs to guarantee data
+ # consistency after sharding.
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ # Maybe different op seeds(for dropout) for different procs is better. By:
+ # `paddle.seed(args.seed + paddle.distributed.get_rank())`
+ paddle.seed(args.seed)
+
+
+@paddle.no_grad()
+def evaluate(model, metric, data_loader, width_mult, student=False):
+ model.eval()
+ metric.reset()
+ for i, batch in enumerate(data_loader):
+ input_ids, segment_ids, labels = batch
+ logits = model(input_ids, segment_ids, attention_mask=[None, None])
+ if isinstance(logits, tuple):
+ logits = logits[0]
+ correct = metric.compute(logits, labels)
+ metric.update(correct)
+
+ res = metric.accumulate()
+ print("width_mult: %s, acc: %s, " % (str(width_mult), res), end='')
+ model.train()
+ return res
+
+
+### Monkey patch for ErnieModel.forward to accept [attention_mask, head_mask] as attention_mask
+def ernie_forward(self,
+ input_ids,
+ token_type_ids=None,
+ position_ids=None,
+ attention_mask=[None, None]):
+ wtype = self.pooler.dense.fn.weight.dtype if hasattr(
+ self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
+ if attention_mask[0] is None:
+ attention_mask[0] = paddle.unsqueeze(
+ (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2])
+ embedding_output = self.embeddings(input_ids, token_type_ids, position_ids)
+ encoded_layer = self.encoder(embedding_output, attention_mask)
+ pooled_output = self.pooler(encoded_layer)
+
+ return encoded_layer, pooled_output
+
+
+ErnieModel.forward = ernie_forward
+
+
+### Reorder weights according to head importance and neuron importance
+def reorder_neuron_head(model, head_importance, neuron_importance):
+ # reorder heads and ffn neurons
+ for layer, current_importance in enumerate(neuron_importance):
+ # reorder heads
+ idx = paddle.argsort(head_importance[layer], descending=True)
+ nlp_utils.reorder_head(model.ernie.encoder.layers[layer].self_attn, idx)
+ # reorder neurons
+ idx = paddle.argsort(
+ paddle.to_tensor(current_importance), descending=True)
+ nlp_utils.reorder_neuron(
+ model.ernie.encoder.layers[layer].linear1.fn, idx, dim=1)
+ nlp_utils.reorder_neuron(
+ model.ernie.encoder.layers[layer].linear2.fn, idx, dim=0)
+
+
+def soft_cross_entropy(inp, target):
+ inp_likelihood = F.log_softmax(inp, axis=-1)
+ target_prob = F.softmax(target, axis=-1)
+ return -1. * paddle.mean(paddle.sum(inp_likelihood * target_prob, axis=-1))
+
+
+def convert_example(example,
+ tokenizer,
+ label_list,
+ max_seq_length=512,
+ is_test=False):
+ """convert a glue example into necessary features"""
+ if not is_test:
+ # `label_list == None` is for regression task
+ label_dtype = "int64" if label_list else "float32"
+ # Get the label
+ label = example['label']
+ label = np.array([label], dtype=label_dtype)
+ # Convert raw text to feature
+ if 'sentence' in example:
+ example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
+ elif 'sentence1' in example:
+ example = tokenizer(
+ example['sentence1'],
+ text_pair=example['sentence2'],
+ max_seq_len=max_seq_length)
+ elif 'keyword' in example: # CSL
+ sentence1 = " ".join(example['keyword'])
+ example = tokenizer(
+ sentence1, text_pair=example['abst'], max_seq_len=max_seq_length)
+ elif 'target' in example: # wsc
+ text, query, pronoun, query_idx, pronoun_idx = example['text'], example[
+ 'target']['span1_text'], example['target']['span2_text'], example[
+ 'target']['span1_index'], example['target']['span2_index']
+ text_list = list(text)
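+        # Mark the query span with "_" and the pronoun span with "[" "]" in the raw
+        # text so the model can locate both spans after tokenization.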
+ assert text[pronoun_idx:(pronoun_idx + len(pronoun)
+ )] == pronoun, "pronoun: {}".format(pronoun)
+ assert text[query_idx:(query_idx + len(query)
+ )] == query, "query: {}".format(query)
+ if pronoun_idx > query_idx:
+ text_list.insert(query_idx, "_")
+ text_list.insert(query_idx + len(query) + 1, "_")
+ text_list.insert(pronoun_idx + 2, "[")
+ text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
+ else:
+ text_list.insert(pronoun_idx, "[")
+ text_list.insert(pronoun_idx + len(pronoun) + 1, "]")
+ text_list.insert(query_idx + 2, "_")
+ text_list.insert(query_idx + len(query) + 2 + 1, "_")
+ text = "".join(text_list)
+ example = tokenizer(text, max_seq_len=max_seq_length)
+
+ if not is_test:
+ return example['input_ids'], example['token_type_ids'], label
+ else:
+ return example['input_ids'], example['token_type_ids']
+
+
+def do_train(args):
+ paddle.set_device(args.device)
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ set_seed(args)
+
+ args.task_name = args.task_name.lower()
+ metric_class = METRIC_CLASSES[args.task_name]
+ args.model_type = args.model_type.lower()
+ model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+ train_ds = load_dataset('clue', args.task_name, splits='train')
+ tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+
+ trans_func = partial(
+ convert_example,
+ tokenizer=tokenizer,
+ label_list=train_ds.label_list,
+ max_seq_length=args.max_seq_length)
+ train_ds = train_ds.map(trans_func, lazy=True)
+ train_batch_sampler = paddle.io.DistributedBatchSampler(
+ train_ds, batch_size=args.batch_size, shuffle=True)
+ batchify_fn = lambda samples, fn=Tuple(
+ Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
+ Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment
+ Stack(dtype="int64" if train_ds.label_list else "float32") # label
+ ): fn(samples)
+
+ train_data_loader = DataLoader(
+ dataset=train_ds,
+ batch_sampler=train_batch_sampler,
+ collate_fn=batchify_fn,
+ num_workers=0,
+ return_list=True)
+
+ dev_ds = load_dataset('clue', args.task_name, splits='dev')
+ dev_ds = dev_ds.map(trans_func, lazy=True)
+ dev_batch_sampler = paddle.io.BatchSampler(
+ dev_ds, batch_size=args.batch_size, shuffle=False)
+ dev_data_loader = DataLoader(
+ dataset=dev_ds,
+ batch_sampler=dev_batch_sampler,
+ collate_fn=batchify_fn,
+ num_workers=0,
+ return_list=True)
+    num_labels = 1 if train_ds.label_list is None else len(train_ds.label_list)
+
+ model = model_class.from_pretrained(
+ args.model_name_or_path, num_classes=num_labels)
+
+ # Step1: Initialize a dictionary to save the weights from the origin BERT model.
+ origin_weights = model.state_dict()
+
+ # Step2: Convert origin model to supernet.
+ sp_config = supernet(
+ expand_ratio=[1.0]) #expand_ratio=args.width_mult_list)
+ model = Convert(sp_config).convert(model)
+ # Use weights saved in the dictionary to initialize supernet.
+ utils.set_state_dict(model, origin_weights)
+ del origin_weights
+
+ super_sd = paddle.load(
+ os.path.join(args.model_name_or_path, 'model_state.pdparams'))
+ model.set_state_dict(super_sd)
+
+ # Step3: Define teacher model.
+ teacher_model = model_class.from_pretrained(
+ args.model_name_or_path, num_classes=num_labels)
+
+ # Step4: Config about distillation.
+ mapping_layers = ['ernie.embeddings']
+ for idx in range(model.ernie.config['num_hidden_layers']):
+ mapping_layers.append('ernie.encoder.layers.{}'.format(idx))
+
+ default_distill_config = {
+ 'lambda_distill': 0.1,
+ 'teacher_model': teacher_model,
+ 'mapping_layers': mapping_layers,
+ }
+ distill_config = DistillConfig(**default_distill_config)
+
+ # Step5: Config in supernet training.
+ ofa_model = OFA(model,
+ distill_config=distill_config,
+ elastic_order=['width'])
+
+ criterion = paddle.nn.loss.CrossEntropyLoss(
+ ) if train_ds.label_list else paddle.nn.loss.MSELoss()
+
+ metric = metric_class()
+
+    #### Step6: Calculate the importance of neurons and heads,
+    #### and then reorder them according to their importance.
+ head_importance, neuron_importance = nlp_utils.compute_neuron_head_importance(
+ args.task_name,
+ ofa_model.model,
+ dev_data_loader,
+ loss_fct=criterion,
+ num_layers=model.ernie.config['num_hidden_layers'],
+ num_heads=model.ernie.config['num_attention_heads'])
+ reorder_neuron_head(ofa_model.model, head_importance, neuron_importance)
+
+ if paddle.distributed.get_world_size() > 1:
+ ofa_model.model = paddle.DataParallel(ofa_model.model)
+
+ if args.max_steps > 0:
+ num_training_steps = args.max_steps
+ num_train_epochs = math.ceil(num_training_steps /
+ len(train_data_loader))
+ else:
+ num_training_steps = len(train_data_loader) * args.num_train_epochs
+ num_train_epochs = args.num_train_epochs
+
+ warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
+
+ lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
+ warmup)
+
+ # Generate parameter names needed to perform weight decay.
+ # All bias and LayerNorm parameters are excluded.
+ decay_params = [
+ p.name for n, p in model.named_parameters()
+ if not any(nd in n for nd in ["bias", "norm"])
+ ]
+
+ optimizer = paddle.optimizer.AdamW(
+ learning_rate=lr_scheduler,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=args.adam_epsilon,
+ parameters=model.parameters(),
+ weight_decay=args.weight_decay,
+ apply_decay_param_fun=lambda x: x in decay_params,
+ grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm))
+
+ global_step = 0
+ tic_train = time.time()
+ best_res = 0.0
+ for epoch in range(num_train_epochs):
+ # Step7: Set current epoch and task.
+ ofa_model.set_epoch(epoch)
+ ofa_model.set_task('width')
+
+ for step, batch in enumerate(train_data_loader):
+ global_step += 1
+ input_ids, segment_ids, labels = batch
+
+ for width_mult in args.width_mult_list:
+ # Step8: Broadcast supernet config from width_mult,
+ # and use this config in supernet training.
+ net_config = utils.dynabert_config(ofa_model, width_mult)
+ ofa_model.set_net_config(net_config)
+ logits, teacher_logits = ofa_model(
+ input_ids, segment_ids, attention_mask=[None, None])
+ rep_loss = ofa_model.calc_distill_loss()
+ if args.task_name == 'sts-b':
+ logit_loss = paddle.zeros(shape=[1], dtype='float32')
+ else:
+ logit_loss = soft_cross_entropy(logits,
+ teacher_logits.detach())
+ loss = rep_loss + args.lambda_logit * logit_loss
+ loss.backward()
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.clear_grad()
+
+ if global_step % args.logging_steps == 0:
+ logger.info(
+ "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
+ % (global_step, epoch, step, loss,
+ args.logging_steps / (time.time() - tic_train)))
+ tic_train = time.time()
+
+ if global_step % args.save_steps == 0 or global_step == num_training_steps:
+ tic_eval = time.time()
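+                # Evaluate the teacher model as a reference; width_mult here is only a label for logging.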
+ evaluate(teacher_model, metric, dev_data_loader, width_mult=100)
+ print("eval done total : %s s" % (time.time() - tic_eval))
+ for idx, width_mult in enumerate(args.width_mult_list):
+ net_config = utils.dynabert_config(ofa_model, width_mult)
+ ofa_model.set_net_config(net_config)
+ tic_eval = time.time()
+ res = evaluate(ofa_model, metric, dev_data_loader,
+ width_mult)
+ print("eval done total : %s s" % (time.time() - tic_eval))
+
+ if best_res < res:
+ output_dir = args.output_dir
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ # need better way to get inner model of DataParallel
+ model_to_save = model._layers if isinstance(
+ model, paddle.DataParallel) else model
+ model_to_save.save_pretrained(output_dir)
+ tokenizer.save_pretrained(output_dir)
+ print("res saved", res)
+ #res = evaluate(ofa_model, metric,
+ #dev_data_loader, width_mult, True)
+ best_res = res
+ #sys.exit()
+ if global_step >= num_training_steps:
+ print("best_res: ", best_res)
+ return
+ print("best_res: ", best_res)
+
+
+def print_arguments(args):
+ """print arguments"""
+ print('----------- Configuration Arguments -----------')
+ for arg, value in sorted(vars(args).items()):
+ print('%s: %s' % (arg, value))
+ print('------------------------------------------------')
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ print_arguments(args)
+ do_train(args)
diff --git a/examples/model_compression/PP-MiniLM/ofa/run_ofa.sh b/examples/model_compression/PP-MiniLM/ofa/run_ofa.sh
new file mode 100644
index 000000000000..011ef885d3fd
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/ofa/run_ofa.sh
@@ -0,0 +1,19 @@
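+# Usage: sh run_ofa.sh <TASK_NAME> <LR> <BATCH_SIZE> <PRE_EPOCHS> <SEQ_LEN> <CUDA_VISIBLE_DEVICES> <STUDENT_DIR>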
+export TASK_NAME=$1
+export LR=$2
+export BATCH_SIZE=$3
+export PRE_EPOCHS=$4
+export SEQ_LEN=$5
+export CUDA_VISIBLE_DEVICES=$6
+export STUDENT_DIR=$7
+
+python -u ./run_ofa.py --model_type ernie \
+ --model_name_or_path ${STUDENT_DIR} \
+ --task_name $TASK_NAME --max_seq_length ${SEQ_LEN} \
+ --batch_size ${BATCH_SIZE} \
+ --learning_rate ${LR} \
+ --num_train_epochs ${PRE_EPOCHS} \
+ --logging_steps 100 \
+ --save_steps 100 \
+ --output_dir ./ofa_models/$TASK_NAME/0.75/best_model/ \
+ --device gpu \
+ --width_mult_list 0.75
diff --git a/examples/model_compression/PP-MiniLM/pp-minilm.png b/examples/model_compression/PP-MiniLM/pp-minilm.png
new file mode 100644
index 000000000000..148c7406a783
Binary files /dev/null and b/examples/model_compression/PP-MiniLM/pp-minilm.png differ
diff --git a/examples/model_compression/PP-MiniLM/quantization/quant_all.sh b/examples/model_compression/PP-MiniLM/quantization/quant_all.sh
new file mode 100644
index 000000000000..4414c3d6a621
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/quantization/quant_all.sh
@@ -0,0 +1,8 @@
+MODEL_DIR=../ofa/ofa_models/
+for task in AFQMC TNEWS IFLYTEK CMNLI OCNLI CLUEWSC2020 CSL
+do
+    python quant_post.py --task_name ${task} --input_dir ${MODEL_DIR}/${task}/0.75/sub_static
+done
diff --git a/examples/model_compression/PP-MiniLM/quantization/quant_post.py b/examples/model_compression/PP-MiniLM/quantization/quant_post.py
new file mode 100644
index 000000000000..c8d05612a166
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/quantization/quant_post.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+import sys
+import os
+import time
+import argparse
+from functools import partial
+import numpy as np
+
+import paddle
+from paddle.metric import Accuracy
+
+import paddlenlp
+import paddleslim
+from paddlenlp.data import Stack, Tuple, Pad, Dict
+from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+ "--task_name", type=str, default="afqmc", required=False, help="task_name")
+parser.add_argument(
+ "--input_dir",
+ type=str,
+ default="afqmc",
+ required=False,
+ help="input task model dire")
+
+args = parser.parse_args()
+
+METRIC_CLASSES = {
+ "afqmc": Accuracy,
+ "tnews": Accuracy,
+ "iflytek": Accuracy,
+ "ocnli": Accuracy,
+ "cmnli": Accuracy,
+ "cluewsc2020": Accuracy,
+ "csl": Accuracy,
+}
+
+MODEL_CLASSES = {"ernie": (ErnieForSequenceClassification, ErnieTokenizer), }
+
+
+def convert_example(example,
+ tokenizer,
+ label_list,
+ max_seq_length=512,
+ is_test=False):
+ """convert a glue example into necessary features"""
+ if not is_test:
+ # `label_list == None` is for regression task
+ label_dtype = "int64" if label_list else "float32"
+ # Get the label
+ label = example['label']
+ label = np.array([label], dtype=label_dtype)
+ # Convert raw text to feature
+ if 'sentence' in example:
+ example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
+ elif 'sentence1' in example:
+ example = tokenizer(
+ example['sentence1'],
+ text_pair=example['sentence2'],
+ max_seq_len=max_seq_length)
+ elif 'keyword' in example: # CSL
+ sentence1 = " ".join(example['keyword'])
+ example = tokenizer(
+ sentence1, text_pair=example['abst'], max_seq_len=max_seq_length)
+ elif 'target' in example: # wsc
+ text, query, pronoun, query_idx, pronoun_idx = example['text'], example[
+ 'target']['span1_text'], example['target']['span2_text'], example[
+ 'target']['span1_index'], example['target']['span2_index']
+ text_list = list(text)
+ assert text[pronoun_idx:(pronoun_idx + len(pronoun)
+ )] == pronoun, "pronoun: {}".format(pronoun)
+ assert text[query_idx:(query_idx + len(query)
+ )] == query, "query: {}".format(query)
+ if pronoun_idx > query_idx:
+ text_list.insert(query_idx, "_")
+ text_list.insert(query_idx + len(query) + 1, "_")
+ text_list.insert(pronoun_idx + 2, "[")
+ text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
+ else:
+ text_list.insert(pronoun_idx, "[")
+ text_list.insert(pronoun_idx + len(pronoun) + 1, "]")
+ text_list.insert(query_idx + 2, "_")
+ text_list.insert(query_idx + len(query) + 2 + 1, "_")
+ text = "".join(text_list)
+ example = tokenizer(text, max_seq_len=max_seq_length)
+
+ if not is_test:
+ return example['input_ids'], example['token_type_ids'], label
+ else:
+ return example['input_ids'], example['token_type_ids']
+
+
+def quant_post(args, batch_size=8, algo='avg'):
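+    # Run PaddleSlim post-training static quantization on the exported sub-model,
+    # using the task's CLUE dev set as calibration data.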
+ paddle.enable_static()
+ place = paddle.set_device("gpu")
+ exe = paddle.static.Executor(place)
+ args.task_name = args.task_name.lower()
+
+ train_ds = paddlenlp.datasets.load_dataset(
+ "clue", args.task_name, splits="dev")
+
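+    # NOTE: this path points to a locally fine-tuned checkpoint; change it to your own model directory.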
+ tokenizer = ErnieTokenizer.from_pretrained(
+ "../ernie-batchbatch-50w_400000/best_models/AFQMC/")
+
+ trans_func = partial(
+ convert_example,
+ tokenizer=tokenizer,
+ label_list=train_ds.label_list,
+ max_seq_length=128,
+ is_test=True)
+ train_ds = train_ds.map(trans_func, lazy=True)
+
+ def test():
+ batch_data = [[], []]
+ for data in train_ds:
+ batch_data[0].append(data[0])
+ batch_data[1].append(data[1])
+ if len(batch_data[0]) == batch_size:
+ input_ids = Pad(axis=0, pad_val=0)(batch_data[0])
+ segment_ids = Pad(axis=0, pad_val=0)(batch_data[1])
+ ones = np.ones_like(input_ids, dtype="int64")
+ seq_length = np.cumsum(ones, axis=-1)
+
+ position_ids = seq_length - ones
+ attention_mask = np.expand_dims(
+ (input_ids == 0).astype("float32") * -1e9, axis=[1, 2])
+ yield [input_ids, segment_ids]
+ batch_data = [[], []]
+
+ paddleslim.quant.quant_post_static(
+ exe,
+ args.input_dir,
+ os.path.join(args.task_name + '_quant_models', algo + str(batch_size)),
+ save_model_filename='int8.pdmodel',
+ save_params_filename='int8.pdiparams',
+ algo=algo,
+ hist_percent=0.9999,
+ batch_generator=test,
+ model_filename='float.pdmodel',
+ params_filename='float.pdiparams',
+ quantizable_op_type=['matmul', 'matmul_v2'],
+ weight_bits=8,
+ weight_quantize_type='channel_wise_abs_max',
+ batch_nums=1, )
+
+
+if __name__ == '__main__':
+ paddle.enable_static()
+ for batch_size in [4, 8]:
+ for algo in ['abs_max', 'avg', 'mse', 'hist']:
+ quant_post(args, batch_size, algo)
diff --git a/paddlenlp/transformers/distill_utils.py b/paddlenlp/transformers/distill_utils.py
index 037cd8ceb849..7b296b308882 100644
--- a/paddlenlp/transformers/distill_utils.py
+++ b/paddlenlp/transformers/distill_utils.py
@@ -39,7 +39,6 @@ def calc_multi_relation_loss(loss_fct,
Calculates loss for multiple Q-Q, K-K and V-V relation. It supports
head-head relation, sample-sample relation and origin token-token relation.
The final loss value could be balanced by weight `alpha` and `beta`.
-
Args:
loss_fct (callable):
Loss function for distillation. It only supports kl_div loss now.
@@ -59,11 +58,9 @@ def calc_multi_relation_loss(loss_fct,
beta (float):
The weight for sample-sample relation.
Defaults to 0.0.
-
Returns:
Tensor: Weighted loss of token-token loss, head-head loss and
sample-sample loss.
-
"""
# Initialize head_num
if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
@@ -140,7 +137,6 @@ def calc_multi_relation_loss(loss_fct,
def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0):
"""
Calculates loss for Q-Q, K-K, V-V relation from MiniLMv2.
-
Args:
loss_fct (callable):
Loss function for distillation. It only supports kl_div loss now.
@@ -154,10 +150,8 @@ def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0):
The number of relation heads. 0 means `num_relation_heads` equals
to origin head num.
Defaults to 0.
-
Returns:
Tensor: MiniLM loss value.
-
"""
# Initialize head_num
if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
@@ -197,7 +191,6 @@ def to_distill(self,
expose attributes `outputs.q`, `outputs.k`, `outputs.v`,
`outputs.scaled_qks`, `outputs.hidden_states`and `outputs.attentions` of
the object for distillation.
-
It could be returned intermediate tensor using in MiniLM and TinyBERT
strategy.
"""
@@ -435,4 +428,4 @@ def bert_forward(self, input_ids, token_type_ids=None, attention_mask=None):
sequence_output, pooled_output = model(input_ids, token_type_ids,
attention_mask)
- return encoder.attentions, encoder.hidden_states
+ return encoder.attentions, encoder.hidden_states
\ No newline at end of file