diff --git a/examples/model_compression/PP-MiniLM/README.md b/examples/model_compression/PP-MiniLM/README.md new file mode 100644 index 000000000000..67f750f14ff6 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/README.md @@ -0,0 +1,303 @@

# PP-MiniLM: A Small Model for Chinese

The PP-MiniLM example provides a high-accuracy, high-performance small-model solution that covers both training and inference.

The solution builds on industry-leading task-agnostic model distillation, pruning, and quantization, so the resulting small model combines three advantages: fast inference, strong accuracy, and a small parameter count.

- Fast inference: PaddleSlim pruning and quantization are integrated to compress the small model further, making inference 2.18x faster than before compression.

- High accuracy: building on the Multi-Head Self-Attention Relation Distillation proposed in MiniLMv2, we further improve the algorithm by introducing sample-wise relation distillation. Our 6-layer model with hidden size 768 exceeds TinyBERT and UER-py RoBERTa models of the same size by 2.66% and 1.51% average accuracy on CLUE, respectively.

- Small parameter count: with PaddleSlim pruning, 1/4 of the model width is removed with almost no accuracy loss (-0.15).

Overall results:

| Model                        | #Params                | #FLOPs | Speedup | AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC   | CSL   | CLUE avg. |
| ---------------------------- | ---------------------- | ------ | ------- | ----- | ----- | ------- | ----- | ----- | ----- | ----- | --------- |
| BERT-base                    | 102.3M                 |        | 1.00x   | 74.17 | 57.17 | 61.14   |       | 75.08 | 80.26 | 81.47 |           |
| TinyBERT6                    | 59.7M                  |        | 1.64x   | 72.22 | 55.82 | 58.10   | 79.53 | 74.00 | 75.99 | 80.57 | 70.89     |
| UER-py RoBERTa L6-H768       | 59.7M                  |        | 1.64x   | 69.74 | 66.36 | 59.95   | 77.00 | 71.39 | 71.05 | 82.83 | 71.19     |
| ERNIE-Tiny                   | 90.7M                  |        | 1.76x   | 70.78 | 55.70 | 59.95   | 75.40 | 70.98 | 67.43 | 76.60 | 68.12     |
| PP-MiniLM 6L-768H            | 59.7M                  |        | 1.64x   | 74.28 | 57.33 | 61.72   | 81.06 | 76.20 | 86.51 | 78.77 | 73.70     |
| PP-MiniLM after pruning      | 49.1M (+ pruning)      |        | 1.88x   | 73.82 | 57.33 | 61.60   | 81.38 | 76.20 | 85.52 | 79.00 | 73.55     |
| PP-MiniLM after quantization | 49.2M (+ quantization) |        | 3.56x   | 73.61 | 57.18 | 61.49   | 81.26 | 76.31 | 84.54 | 77.67 | 73.15     |

Workflow overview:

(Figure: overall workflow of the PP-MiniLM solution)

The brief directory structure of this example and what each file does:

```shell
.
├── general_distill                # general distillation
│   ├── general_distill.py         # general distillation script
│   └── run.sh                     # launch script for general distillation
├── finetuning                     # downstream-task training
│   ├── run_clue.py                # fine-tuning script for CLUE tasks
│   ├── run_clue.sh                # launch script for fine-tuning on CLUE
│   ├── run_one_search.sh          # hyper-parameter search script for a single dataset
│   ├── run_all_search.sh          # hyper-parameter search script for all CLUE datasets
│   └── export_model.py            # exports the fine-tuned model for deployment
├── ofa                            # OFA pruning and distillation
│   ├── run_ofa.py                 # OFA pruning/distillation script
│   ├── run_ofa.sh                 # launch script for OFA pruning/distillation
│   └── export_model.py            # exports the sub-model obtained by OFA training
├── quantization                   # offline (post-training) quantization
│   ├── quant_post.py              # offline quantization script
│   └── quant.sh                   # launch script for offline quantization
├── inference                      # inference
│   ├── infer.py                   # inference script
│   ├── infer.sh                   # launch script for inference
│   └── infer_all.sh               # launch script for batch inference of the quantized models
└── README.md                      # this document

```

## General Distillation (Optional)

### Requirements

This experiment runs on 8 NVIDIA Tesla V100 32G cards and takes roughly 2-3 days. If resources are limited, you can skip this step, download the model it produces, and fine-tune that model directly on the downstream-task data.

### How It Works

Distillation method used by PP-MiniLM:

This solution builds on the Multi-Head Self-Attention Relation Distillation algorithm proposed in MiniLMv2 and improves it further by introducing sample-wise relation distillation. Concretely, layer 20 of the 24-layer RoBERTa-wwm-ext-large teacher distills the sample-wise Q-Q, K-K, and V-V relations of layer 6 of the 6-layer PP-MiniLM student. The number of attention heads used for distillation is first unified between teacher and student, the Q, K, and V tensors are transposed to the shape [seq_len, head_num, batch_size, head_dim], and the Q-Q, K-K, and V-V relations are then distilled. On CLUE, this improves average accuracy by 0.36 over the original MiniLMv2 algorithm.

### Data

Split the training data into 64 files and place them under the `dataset` directory.

### How to Run

```shell
cd general_distill
sh run.sh # contains the run configuration for general_distill.py
cd ..
```

The arguments of general_distill.py are:

- `model_type`: student model type; currently only 'ernie' and 'roberta' are supported.
- `num_relation_heads`: number of relation heads; usually 64 for a large-size teacher and 48 for a base-size teacher.
- `teacher_model_type`: teacher model type; currently only 'ernie' and 'roberta' are supported.
- `teacher_layer_index`: index of the teacher layer used for distillation.
- `student_layer_index`: index of the student layer used for distillation.
- `teacher_model_name_or_path`: name of the teacher model, e.g. `'roberta-wwm-ext-large'`.
- `max_seq_length`: maximum sequence length.
- `num_layers`: number of layers of the student model; currently only 2, 4, and 6 are supported.
- `logging_steps`: logging interval.
- `max_steps`: maximum number of training steps.
- `warmup_steps`: number of warmup steps over which the learning rate increases to `learning_rate`.
- `save_steps`: interval (in steps) at which the model is saved.
- `weight_decay`: weight-decay coefficient used by the AdamW optimizer.
- `output_dir`: output path for training artifacts and saved models.
- `device`: device to use; gpu is recommended.
- `input_dir`: training data directory.
- `use_amp`: whether to use mixed-precision training; defaults to False.
- `alpha`: weight of the head-wise relation loss; defaults to 0.0.
- `beta`: weight of the sample-wise relation loss; defaults to 0.0.

Save the absolute path of the resulting model to `$GENERAL_MODEL_DIR`, for example:

```shell
GENERAL_MODEL_DIR=PaddleNLP/examples/model_compression/PP-MiniLM/general_distill/pretrain/model_400000
```

## Fine-tuning on Downstream Tasks

### Data

This experiment uses the CLUE datasets. Running the fine-tuning script downloads them automatically to `~/.paddlenlp/datasets/Clue/` (on Linux).

The general model `GENERAL_MODEL_DIR` obtained from the general distillation step is fine-tuned over the following hyper-parameter ranges:

- batch sizes: 16, 32, 64
- learning rates: 3e-5, 5e-5, 1e-4

### How to Launch

Run a grid search over the hyper-parameter ranges above on the small model `GENERAL_MODEL_DIR` produced by the distillation step:

```shell
cd finetuning
sh run_all_search.sh $GENERAL_MODEL_DIR
```

To fine-tune on a single dataset with a specific batch_size and learning_rate, you can run, for example:

```
sh run_clue.sh CLUEWSC2020 1e-4 32 3 128 0 $GENERAL_MODEL_DIR
```

where the positional arguments are, in order: the CLUE task name, learning_rate, batch_size, number of epochs, max_seq_len, the card id, and the model directory.

### Accuracy

After fine-tuning, the accuracy on each CLUE task is:

| AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC   | CSL   | CLUE avg. |
| ----- | ----- | ------- | ----- | ----- | ----- | ----- | --------- |
| 74.28 | 57.33 | 61.72   | 81.06 | 76.20 | 86.51 | 78.77 | 73.70     |

### Exporting the Fine-tuned Model for Deployment

Assuming the fine-tuned model is saved at `$GENERAL_MODEL_DIR/models/CLUEWSC2020/1e-4_32`, the following command exports the dynamic graph model as a static graph model ready for deployment:

```shell
python export_model.py --model_type ernie --model_path $GENERAL_MODEL_DIR/models/CLUEWSC2020/1e-4_32 --output_path fine_tuned_infer_model/float
```

## Pruning the Task-Specific Model with PaddleSlim OFA

This step prunes the width of the downstream-task model using PaddleSlim's OFA feature. To run it, install the latest paddleslim package:

```shell
pip install -U paddleslim -i https://pypi.org/simple
cd ofa
```
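For reference, the sketch below illustrates how a fine-tuned model is wrapped into a width-adaptive supernet with PaddleSlim's OFA API; the same calls are used by `ofa/export_model.py` in this example. The model directory name and the fixed 0.75 width multiplier are placeholders for illustration only.

```python
# Minimal sketch (illustrative only): wrap a fine-tuned ERNIE classifier into a
# width-adaptive supernet so that a 0.75x-width sub-model can be trained and exported.
from paddlenlp.transformers import ErnieForSequenceClassification
from paddleslim.nas.ofa import OFA
from paddleslim.nas.ofa.convert_super import Convert, supernet

model = ErnieForSequenceClassification.from_pretrained("path/to/fine_tuned_model")  # placeholder path
sp_config = supernet(expand_ratio=[1.0, 0.75])  # allow full width and 3/4 width
model = Convert(sp_config).convert(model)       # convert Linear layers into elastic "super" layers
ofa_model = OFA(model)                          # OFA wrapper used during pruning/distillation
```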
This step uses the fine-tuned model as the teacher and distills a student whose width is 3/4 of the original. In our experiments, under the 6L768H setting the model width is compressed to 3/4 of the original with almost no accuracy loss (-0.15).

### Launching Pruning and Distillation

Suppose the model `$GENERAL_MODEL_DIR/models/CLUEWSC2020/1e-4_32` obtained from the previous fine-tuning step needs further pruning. The learning_rate and batch_size used for fine-tuning can be reused, for example:

```shell
sh run_ofa.sh CLUEWSC2020 5e-5 16 50 128 4 ../general_distill/ernie-batchbatch-50w_400000/models/CLUEWSC2020/1e-4_32/
```

After it finishes, the model is saved under `ofa_models/CLUEWSC2020/0.75/best_model/`.

### Exporting the Pruned Model

This step produces both the dynamic graph and the static graph parameters.

Taking the CLUEWSC2020 dataset as an example, export the model with:

```shell
MODEL_PATH=ofa_models
TASK_NAME=CLUEWSC2020
sh export.sh $MODEL_PATH $TASK_NAME
```

Or export the models of all tasks in one batch:

```shell
sh export_all.sh
```

The exported model is saved at `${MODEL_PATH}/${TASK_NAME}/0.75/sub_static/float`.

```shell
cd ..
```

### Accuracy

After pruning and distillation, the accuracy on each CLUE task is shown below. Compared with the unpruned model, the CLUE average drops by 0.15, while the number of parameters goes from 59.7M down to 49.1M.

| AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC   | CSL   | CLUE avg. |
| ----- | ----- | ------- | ----- | ----- | ----- | ----- | --------- |
| 74.28 | 57.33 | 61.60   | 81.38 | 76.20 | 85.52 | 79.00 | 73.55     |

## Further Reducing Model Size with PaddleSlim Quantization

```shell
cd quantization
```

About offline (post-training) quantization:

In this step, the float32 model is quantized directly through PaddleSlim's offline quantization API, without any further training. Four calibration strategies (mse, avg, abs_max, hist) and two calibration set sizes (4 and 8) are used.

Run the following script to quantize a single exported model:

```shell
python quant_post.py --task_name $TASK_NAME --input_dir ${MODEL_DIR}/${TASK_NAME}/0.75/sub_static
```

Or quantize the float models of all datasets in one batch:

```shell
sh quant_all.sh
```

```
cd ..
```

### Accuracy

After quantization, the accuracy on each CLUE task is as follows, 0.40 lower than the previous (pruned) step:

| AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC   | CSL   | CLUE avg. |
| ----- | ----- | ------- | ----- | ----- | ----- | ----- | --------- |
| 73.61 | 57.18 | 61.49   | 81.26 | 76.31 | 84.54 | 77.67 | 73.15     |

## Deployment with Paddle Inference

### Requirements

This step relies on Paddle Inference in PaddlePaddle 2.2.1. For a more pronounced speedup, test on an NVIDIA GPU with INT8 Tensor Cores (e.g. T4, A10, A100); this example was benchmarked on a T4. V-series cards do not support INT8 Tensor Cores, so they will not reach the speedups reported in this document.

Because dynamic shape is enabled, the shape ranges must be collected first. Paddle Inference provides an interface for this: offline input data is run once to collect the shape range of every temporary tensor, and the input shape ranges of the TensorRT subgraphs are then set directly from the tuned results, completing the automatic shape-range setup. Once collection is done, simply point to the saved statistics to enable the tuned_dynamic_shape feature. In this example, first run infer.py with the --collect_shape flag, then run infer.py again without it. For example:

INT8 inference:

```shell
cd inference

python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt --collect_shape # generate the shape range info file
python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt # load the shape range info file and run inference
```

To run all INT8 models in a batch and compare the effect of the different quantization settings, use:

```shell
sh infer_all.sh
```

FP32 inference:

```shell
python infer.py --task_name ${task} --model_path $MODEL_PATH --use_trt --collect_shape
python infer.py --task_name ${task} --model_path $MODEL_PATH --use_trt
```

### Performance

Performance was measured on a single NVIDIA Tesla T4 (CUDA 11.1, cuDNN 8.1, TensorRT 7.2) with the inference/infer.py script, running the quantized model.

The models trained on the TNEWS dataset are used. Apart from the BERT baseline, the three rows of the table below are the fine-tuned model, the OFA-pruned/distilled model, and the quantized model (mse strategy, calibration set size 4); the numbers are the total prediction time on the dev set, excluding the first 20 steps.

With PaddleSlim pruning and quantization, inference is 255.86% faster than the original BERT-base model; pruning alone contributes an 87.98% speedup.

|                         | Avg. time (s) | Speedup |
| ----------------------- | ------------- | ------- |
| BERT                    | 20.64         | 0       |
| FP32                    | 12.61         | 63.68%  |
| FP32 + pruning          | 10.98         | 87.98%  |
| FP32 + pruning + INT8   | 5.80          | 255.86% |

INT8 inference script:

```shell
sh infer.sh
```

```shell
cd ..
+``` diff --git a/examples/model_compression/PP-MiniLM/finetuning/export_model.py b/examples/model_compression/PP-MiniLM/finetuning/export_model.py new file mode 100644 index 000000000000..1e1e4fe459a6 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/finetuning/export_model.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import paddle + +from run_clue import MODEL_CLASSES + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_path", + default=None, + type=str, + required=True, + help="Path of the trained model to be exported.", ) + parser.add_argument( + "--output_path", + default=None, + type=str, + required=True, + help="The output file prefix used to save the exported inference model.", + ) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + # build model and load trained parameters + model = model_class.from_pretrained(args.model_path) + # switch to eval model + model.eval() + # convert to static graph with specific input description + model = paddle.jit.to_static( + model, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None], dtype="int64"), # input_ids + paddle.static.InputSpec( + shape=[None, None], dtype="int64") # segment_ids + ]) + # save converted static graph model + paddle.jit.save(model, args.output_path) + # also save tokenizer for inference usage + tokenizer = tokenizer_class.from_pretrained(args.model_path) + tokenizer.save_pretrained(os.path.dirname(args.output_path)) + + +if __name__ == "__main__": + main() diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_all_search.sh b/examples/model_compression/PP-MiniLM/finetuning/run_all_search.sh new file mode 100644 index 000000000000..f19ad7bfaaa5 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/finetuning/run_all_search.sh @@ -0,0 +1,35 @@ +# $1 means GENERAL_DIR + +# The penultimate parameter is the card id, this script can be changed if necessary +bash run_one_search.sh $1 afqmc 0 & +bash run_one_search.sh $1 tnews 1 & +bash run_one_search.sh $1 ifly 2 & +bash run_one_search.sh $1 ocnli 3 & +bash run_one_search.sh $1 csl 4 & +bash run_one_search.sh $1 wsc 5 & + +# Because the CMNLI data set is significantly larger than other data sets, +# it needs to be placed on different cards. 
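# For reference, run_clue.sh takes its arguments positionally (see finetuning/run_clue.sh):
# TASK_NAME  LEARNING_RATE  BATCH_SIZE  EPOCHS  MAX_SEQ_LEN  CUDA_VISIBLE_DEVICES(card id)  MODEL_PATH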
+lr=1e-4 +bs=16 +sh run_clue.sh CMNLI $lr $bs 3 128 0 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=32 +sh run_clue.sh CMNLI $lr $bs 3 128 1 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=64 +sh run_clue.sh CMNLI $lr $bs 3 128 2 $1 > $1/cmnli/${lr}_${bs}_3_128.log & + +lr=5e-5 +bs=16 +sh run_clue.sh CMNLI $lr $bs 3 128 3 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=32 +sh run_clue.sh CMNLI $lr $bs 3 128 4 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=64 +sh run_clue.sh CMNLI $lr $bs 3 128 5 $1 > $1/cmnli/${lr}_${bs}_3_128.log & + +lr=3e-5 +bs=16 +sh run_clue.sh CMNLI $lr $bs 3 128 6 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=32 +sh run_clue.sh CMNLI $lr $bs 3 128 5 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=64 +sh run_clue.sh CMNLI $lr $bs 3 128 7 $1 > $1/cmnli/${lr}_${bs}_3_128.log & diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_clue.py b/examples/model_compression/PP-MiniLM/finetuning/run_clue.py new file mode 100644 index 000000000000..3c3cc18b61ef --- /dev/null +++ b/examples/model_compression/PP-MiniLM/finetuning/run_clue.py @@ -0,0 +1,450 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import sys +import random +import time +import math +import distutils.util +from functools import partial + +import numpy as np +import paddle +from paddle.io import DataLoader +import paddle.nn as nn +from paddle.metric import Accuracy + +from paddlenlp.datasets import load_dataset +from paddlenlp.data import Stack, Tuple, Pad, Dict +from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer, BertModel +from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer +from paddlenlp.transformers import LinearDecayWithWarmup + +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + +METRIC_CLASSES = { + "afqmc": Accuracy, + "tnews": Accuracy, + "iflytek": Accuracy, + "ocnli": Accuracy, + "cmnli": Accuracy, + "cluewsc2020": Accuracy, + "csl": Accuracy, +} + +MODEL_CLASSES = { + "bert": (BertForSequenceClassification, BertTokenizer), + "ernie": (ErnieForSequenceClassification, ErnieTokenizer), +} + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--task_name", + default=None, + type=str, + required=True, + help="The name of the task to train selected in the list: " + + ", ".join(METRIC_CLASSES.keys()), ) + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--output_dir", + 
default="best_clue_model", + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--learning_rate", + default=1e-4, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--num_train_epochs", + default=3, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--logging_steps", + type=int, + default=100, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=100, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--batch_size", + default=32, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--weight_decay", + default=0.0, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--warmup_steps", + default=0, + type=int, + help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion" + ) + parser.add_argument( + "--warmup_proportion", + default=0.1, + type=float, + help="Linear warmup proportion over total steps.") + parser.add_argument( + "--adam_epsilon", + default=1e-6, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--do_train", + type=distutils.util.strtobool, + default=True, + help="Whether do train.") + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="random seed for initialization") + parser.add_argument( + "--device", + default="gpu", + type=str, + help="The device to select to train the model, is must be cpu/gpu/xpu.") + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="The max value of grad norm.") + args = parser.parse_args() + return args + + +def set_seed(args): + # Use the same data seed(for data shuffle) for all procs to guarantee data + # consistency after sharding. + random.seed(args.seed) + np.random.seed(args.seed) + # Maybe different op seeds(for dropout) for different procs is better. 
By: + # `paddle.seed(args.seed + paddle.distributed.get_rank())` + paddle.seed(args.seed) + + +@paddle.no_grad() +def evaluate(model, loss_fct, metric, data_loader): + model.eval() + metric.reset() + for batch in data_loader: + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids) + loss = loss_fct(logits, labels) + correct = metric.compute(logits, labels) + metric.update(correct) + res = metric.accumulate() + print("eval loss: %f, acc: %s, " % (loss.numpy(), res), end='') + model.train() + return res + + +def convert_example(example, + tokenizer, + label_list, + max_seq_length=512, + is_test=False): + """convert a glue example into necessary features""" + if not is_test: + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + label = example['label'] + label = np.array([label], dtype=label_dtype) + # Convert raw text to feature + if 'sentence' in example: + example = tokenizer(example['sentence'], max_seq_len=max_seq_length) + elif 'sentence1' in example: + example = tokenizer( + example['sentence1'], + text_pair=example['sentence2'], + max_seq_len=max_seq_length) + elif 'keyword' in example: # CSL + sentence1 = " ".join(example['keyword']) + example = tokenizer( + sentence1, text_pair=example['abst'], max_seq_len=max_seq_length) + elif 'target' in example: # wsc + text, query, pronoun, query_idx, pronoun_idx = example['text'], example[ + 'target']['span1_text'], example['target']['span2_text'], example[ + 'target']['span1_index'], example['target']['span2_index'] + text_list = list(text) + assert text[pronoun_idx:(pronoun_idx + len(pronoun) + )] == pronoun, "pronoun: {}".format(pronoun) + assert text[query_idx:(query_idx + len(query) + )] == query, "query: {}".format(query) + if pronoun_idx > query_idx: + text_list.insert(query_idx, "_") + text_list.insert(query_idx + len(query) + 1, "_") + text_list.insert(pronoun_idx + 2, "[") + text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]") + else: + text_list.insert(pronoun_idx, "[") + text_list.insert(pronoun_idx + len(pronoun) + 1, "]") + text_list.insert(query_idx + 2, "_") + text_list.insert(query_idx + len(query) + 2 + 1, "_") + text = "".join(text_list) + example = tokenizer(text, max_seq_len=max_seq_length) + if not is_test: + return example['input_ids'], example['token_type_ids'], label + else: + return example['input_ids'], example['token_type_ids'] + + +def do_eval(args): + paddle.set_device(args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + + args.task_name = args.task_name.lower() + metric_class = METRIC_CLASSES[args.task_name] + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + dev_ds = load_dataset('clue', args.task_name, splits='dev') + + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=dev_ds.label_list, + max_seq_length=args.max_seq_length) + + dev_ds = dev_ds.map(trans_func, lazy=True) + dev_batch_sampler = paddle.io.BatchSampler( + dev_ds, batch_size=args.batch_size, shuffle=False) + + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment + Stack(dtype="int64" if dev_ds.label_list else "float32") # label + ): fn(samples) + + dev_data_loader = DataLoader( + dataset=dev_ds, + 
batch_sampler=dev_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + num_classes = 1 if dev_ds.label_list == None else len(dev_ds.label_list) + model = model_class.from_pretrained( + args.model_name_or_path, num_classes=num_classes) + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + metric = metric_class() + best_acc = 0.0 + global_step = 0 + tic_train = time.time() + model.eval() + metric.reset() + for batch in dev_data_loader: + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids) + correct = metric.compute(logits, labels) + metric.update(correct) + res = metric.accumulate() + print("acc: %s\n, " % (res), end='') + + +def do_train(args): + paddle.set_device(args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + + args.task_name = args.task_name.lower() + metric_class = METRIC_CLASSES[args.task_name] + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + train_ds = load_dataset('clue', args.task_name, splits='train') + + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_ds.label_list, + max_seq_length=args.max_seq_length) + train_ds = train_ds.map(trans_func, lazy=True) + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_ds, batch_size=args.batch_size, shuffle=True) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment + Stack(dtype="int64" if train_ds.label_list else "float32") # label + ): fn(samples) + train_data_loader = DataLoader( + dataset=train_ds, + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + dev_ds = load_dataset('clue', args.task_name, splits='dev') + + dev_ds = dev_ds.map(trans_func, lazy=True) + dev_batch_sampler = paddle.io.BatchSampler( + dev_ds, batch_size=args.batch_size, shuffle=False) + dev_data_loader = DataLoader( + dataset=dev_ds, + batch_sampler=dev_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list) + model = model_class.from_pretrained( + args.model_name_or_path, num_classes=num_classes) + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + if args.max_steps > 0: + num_training_steps = args.max_steps + num_train_epochs = math.ceil(num_training_steps / + len(train_data_loader)) + else: + num_training_steps = len(train_data_loader) * args.num_train_epochs + num_train_epochs = args.num_train_epochs + + warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion + + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, + warmup) + + # Generate parameter names needed to perform weight decay. + # All bias and LayerNorm parameters are excluded. 
+ decay_params = [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=args.adam_epsilon, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_params, + grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm)) + + loss_fct = paddle.nn.loss.CrossEntropyLoss( + ) if train_ds.label_list else paddle.nn.loss.MSELoss() + + metric = metric_class() + best_acc = 0.0 + global_step = 0 + tic_train = time.time() + for epoch in range(num_train_epochs): + for step, batch in enumerate(train_data_loader): + global_step += 1 + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids) + loss = loss_fct(logits, labels) + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + if global_step % args.logging_steps == 0: + print( + "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s" + % (global_step, num_training_steps, epoch, step, + paddle.distributed.get_rank(), loss, optimizer.get_lr(), + args.logging_steps / (time.time() - tic_train))) + tic_train = time.time() + if global_step % args.save_steps == 0 or global_step == num_training_steps: + tic_eval = time.time() + acc = evaluate(model, loss_fct, metric, dev_data_loader) + print("eval done total : %s s" % (time.time() - tic_eval)) + if acc > best_acc: + best_acc = acc + output_dir = args.output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + #Need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + if global_step >= num_training_steps: + print("best_acc: ", best_acc) + return + print("best_acc: ", best_acc) + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + if args.do_train: + do_train(args) + else: + do_eval(args) diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_clue.sh b/examples/model_compression/PP-MiniLM/finetuning/run_clue.sh new file mode 100644 index 000000000000..ad74187f5a4b --- /dev/null +++ b/examples/model_compression/PP-MiniLM/finetuning/run_clue.sh @@ -0,0 +1,25 @@ + +export TASK_NAME=$1 +export LR=$2 +export BS=$3 +export EPOCH=$4 +export MAX_SEQ_LEN=$5 +export CUDA_VISIBLE_DEVICES=$6 +export MODEL_PATH=$7 + +python -u ./run_clue.py \ + --model_type ernie \ + --model_name_or_path ${MODEL_PATH} \ + --task_name ${TASK_NAME} \ + --max_seq_length ${MAX_SEQ_LEN} \ + --batch_size ${BS} \ + --learning_rate ${LR} \ + --num_train_epochs ${EPOCH} \ + --logging_steps 100 \ + --seed 42 \ + --save_steps 100 \ + --warmup_proportion 0.1 \ + --weight_decay 0.01 \ + --adam_epsilon 1e-8 \ + --output_dir ${MODEL_PATH}/models/${TASK_NAME}/${LR}_${BS}/ \ + --device gpu \ diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_one_search.sh b/examples/model_compression/PP-MiniLM/finetuning/run_one_search.sh new file mode 100644 index 000000000000..ef0871d71302 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/finetuning/run_one_search.sh @@ -0,0 +1,54 @@ +OUTPUT_DIR=$1 
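# $1 is the general-distillation model directory (logs and fine-tuned checkpoints are written under it),
# $2 is the task name (one of afqmc/tnews/ifly/ocnli/wsc/csl/cmnli),
# $3 is the GPU card id forwarded to run_clue.sh.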
TASK_NAME=$2

mkdir ${OUTPUT_DIR}/afqmc
mkdir ${OUTPUT_DIR}/tnews
mkdir ${OUTPUT_DIR}/ifly
mkdir ${OUTPUT_DIR}/ocnli
mkdir ${OUTPUT_DIR}/wsc
mkdir ${OUTPUT_DIR}/csl
mkdir ${OUTPUT_DIR}/cmnli


for lr in 1e-4 5e-5 3e-5
do
    for bs in 16 32 64
    do
        echo bs: $bs, lr: $lr

        if [ $TASK_NAME == afqmc ]
        then
            sh run_clue.sh AFQMC $lr $bs 3 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/afqmc/${lr}_${bs}_3_128.log
        fi

        if [ $TASK_NAME == tnews ]
        then
            sh run_clue.sh TNEWS $lr $bs 3 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/tnews/${lr}_${bs}_3_128.log
        fi

        if [ $TASK_NAME == ifly ]
        then
            sh run_clue.sh IFLYTEK $lr $bs 6 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/ifly/${lr}_${bs}_6_128.log
        fi

        if [ $TASK_NAME == ocnli ]
        then
            sh run_clue.sh OCNLI $lr $bs 6 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/ocnli/${lr}_${bs}_6_128.log
        fi

        if [ $TASK_NAME == wsc ]
        then
            sh run_clue.sh CLUEWSC2020 $lr $bs 50 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/wsc/${lr}_${bs}_50_128.log
        fi

        if [ $TASK_NAME == csl ]
        then
            sh run_clue.sh CSL $lr $bs 8 256 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/csl/${lr}_${bs}_8_256.log
        fi

        if [ $TASK_NAME == cmnli ]
        then
            sh run_clue.sh CMNLI $lr $bs 3 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/cmnli/${lr}_${bs}_3_128.log
        fi
    done
done
 diff --git a/examples/model_compression/PP-MiniLM/general_distill/general_distill.py b/examples/model_compression/PP-MiniLM/general_distill/general_distill.py new file mode 100644 index 000000000000..df26030eeb02 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/general_distill/general_distill.py @@ -0,0 +1,498 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
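# General (task-agnostic) distillation for PP-MiniLM. A high-level summary of what this
# script does (see the README of this example for details):
#   - a large teacher (e.g. roberta-wwm-ext-large) and a small, randomly initialized
#     ERNIE student are both wrapped with to_distill(return_qkv=True, ...) so that the
#     Q, K, V tensors of one chosen layer each (--teacher_layer_index /
#     --student_layer_index) are exposed;
#   - for every batch, the Q-Q, K-K, and V-V self-attention relations of teacher and
#     student are matched with a KL-divergence loss (calc_minilm_loss_multi_relation),
#     after reshaping Q/K/V to a common number of relation heads (--num_relation_heads);
#     --alpha and --beta weight the additional head-wise and sample-wise relation terms.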
+ +import argparse +import os +import random +import time +from functools import partial +from concurrent.futures import ThreadPoolExecutor +import distutils.util +import math + +import numpy as np +import paddle +from paddle.io import DataLoader +import paddle.nn.functional as F +from paddle import tensor + +from paddlenlp.utils.log import logger +from paddlenlp.data import Tuple, Pad +from paddlenlp.utils.tools import TimeCostAverage +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.transformers import RobertaModel, RobertaTokenizer +from paddlenlp.transformers import ErnieModel, ErnieForSequenceClassification, ErnieTokenizer +from paddlenlp.transformers.distill_utils import to_distill, calc_minilm_loss_multi_relation + +MODEL_CLASSES = { + "roberta": (RobertaModel, RobertaTokenizer), + "ernie": (ErnieForSequenceClassification, ErnieTokenizer) +} + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_type", + default="ernie", + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--teacher_model_type", + default="ernie", + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--student_model_name_or_path", + default=None, + type=str, + required=False, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--teacher_model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model.") + parser.add_argument( + "--input_dir", + default=None, + type=str, + required=True, + help="The input directory where the data will be read from.", ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. 
Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--learning_rate", + default=6e-4, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--num_layers", + default=6, + type=int, + help="Number layers of student model.", ) + parser.add_argument( + "--teacher_layer_index", + default=19, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--student_layer_index", + default=5, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--num_train_epochs", + default=3, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--logging_steps", + type=int, + default=100, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=100, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--batch_size", + default=512, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--num_relation_heads", + default=64, + type=int, + help="The number of relation heads is 48 and 64 for base and large-size teacher model.", + ) + parser.add_argument("--beta", default=0.0, type=float, help="0.0 usually") + parser.add_argument("--alpha", default=0.0, type=float, help="0.0 usually") + parser.add_argument( + "--weight_decay", + default=0.01, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--warmup_steps", + default=-1, + type=int, + help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion" + ) + parser.add_argument( + "--warmup_proportion", + default=0.01, + type=float, + help="Linear warmup proportion over total steps.") + parser.add_argument( + "--adam_epsilon", + default=1e-8, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--max_steps", + default=400000, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="random seed for initialization") + parser.add_argument( + "--device", + default="gpu", + type=str, + help="The device to select to train the model, is must be cpu/gpu/xpu.") + parser.add_argument( + "--use_amp", + type=distutils.util.strtobool, + default=False, + help="Enable mixed precision training.") + parser.add_argument( + "--scale_loss", + type=float, + default=2**15, + help="The value of scale_loss for fp16.") + args = parser.parse_args() + return args + + +def set_seed(args): + random.seed(args.seed + paddle.distributed.get_rank()) + np.random.seed(args.seed + paddle.distributed.get_rank()) + paddle.seed(args.seed + paddle.distributed.get_rank()) + + +class WorkerInitObj(object): + def __init__(self, seed): + self.seed = seed + + def __call__(self, id): + np.random.seed(seed=self.seed + id) + random.seed(self.seed + id) + + +def create_pretraining_dataset(input_file, shared_list, args, worker_init, + tokenizer): + train_data = PretrainingDataset( + input_file=input_file, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length) + # files have been sharded, no need to dispatch again + train_batch_sampler = paddle.io.BatchSampler( + train_data, batch_size=args.batch_size, shuffle=True) + + # DataLoader cannot be pickled because of its place. 
+ # If it can be pickled, use global function instead of lambda and use + # ProcessPoolExecutor instead of ThreadPoolExecutor to prefetch. + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + ): fn(samples) + + train_data_loader = DataLoader( + dataset=train_data, + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + worker_init_fn=worker_init, + return_list=True) + return train_data_loader, input_file + + +class PretrainingDataset(paddle.io.Dataset): + def __init__(self, input_file, tokenizer, max_seq_length): + self.input_file = input_file + f = open(input_file, 'r') + input_ids = [] + for i, line in enumerate(f): + line = line[:max_seq_length] + tokenized_example = tokenizer(line, max_seq_len=max_seq_length) + input_ids.append(tokenized_example['input_ids']) + self.inputs = np.asarray(input_ids) + f.close() + + def __len__(self): + 'Denotes the total number of samples' + return len(self.inputs) + + def __getitem__(self, index): + input_ids = [np.asarray(self.inputs[index])] + return input_ids + + +def do_train(args): + paddle.set_device(args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + + worker_init = WorkerInitObj(args.seed + paddle.distributed.get_rank()) + args.model_type = args.model_type.lower() + + # For teacher + teacher_model_class, tokenizer_class = MODEL_CLASSES[ + args.teacher_model_type] + tokenizer = tokenizer_class.from_pretrained(args.teacher_model_name_or_path) + + # For student + model_class, _ = MODEL_CLASSES[args.model_type] + if args.num_layers == 6: + ernie = ErnieModel( + vocab_size=tokenizer.vocab_size, + num_hidden_layers=6, + hidden_act='relu', + intermediate_size=3072, + hidden_size=768) # layer: 6 + elif args.num_layers == 4: + ernie = ErnieModel( + vocab_size=tokenizer.vocab_size, + num_hidden_layers=4, + hidden_act='relu', + intermediate_size=1024, + hidden_size=256, + num_attention_heads=16) # layer: 4 + else: + ernie = ErnieModel( + vocab_size=tokenizer.vocab_size, + num_hidden_layers=2, + hidden_act='relu', + hidden_size=128, + intermediate_size=512) # layer: 2 + student = model_class(ernie) + + teacher = teacher_model_class.from_pretrained( + args.teacher_model_name_or_path) + pad_token_id = 0 + + if paddle.distributed.get_world_size() > 1: + student = paddle.DataParallel(student, find_unused_parameters=True) + teacher = paddle.DataParallel(teacher, find_unused_parameters=True) + + num_training_steps = args.max_steps + + warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, + warmup) + + # Generate parameter names needed to perform weight decay. + # All bias and LayerNorm parameters are excluded. 
+ decay_params = [ + p.name for n, p in student.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=args.adam_epsilon, + parameters=student.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_params, + grad_clip=paddle.nn.ClipGradByGlobalNorm(args.max_grad_norm)) + + if args.use_amp: + scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss) + pool = ThreadPoolExecutor(1) + + teacher = to_distill( + teacher, return_qkv=True, layer_index=args.teacher_layer_index) + student = to_distill( + student, return_qkv=True, layer_index=args.student_layer_index) + + global_step = 0 + tic_train = time.time() + for epoch in range(args.num_train_epochs): + files = [ + os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) + if os.path.isfile(os.path.join(args.input_dir, f)) + ] + files.sort() + num_files = len(files) + random.Random(args.seed + epoch).shuffle(files) + f_start_id = 0 + + shared_file_list = {} + + if paddle.distributed.get_world_size() > num_files: + remainder = paddle.distributed.get_world_size() % num_files + + data_file = files[( + f_start_id * paddle.distributed.get_world_size() + + paddle.distributed.get_rank() + remainder * f_start_id) % + num_files] + else: + data_file = files[(f_start_id * paddle.distributed.get_world_size() + + paddle.distributed.get_rank()) % num_files] + + previous_file = data_file + + train_data_loader, _ = create_pretraining_dataset( + data_file, shared_file_list, args, worker_init, tokenizer) + + # TODO(guosheng): better way to process single file + single_file = True if f_start_id + 1 == len(files) else False + + for f_id in range(f_start_id, len(files)): + if not single_file and f_id == f_start_id: + continue + if paddle.distributed.get_world_size() > num_files: + data_file = files[( + f_id * paddle.distributed.get_world_size() + + paddle.distributed.get_rank() + remainder * f_id) % + num_files] + else: + data_file = files[(f_id * paddle.distributed.get_world_size() + + paddle.distributed.get_rank()) % num_files] + previous_file = data_file + dataset_future = pool.submit(create_pretraining_dataset, data_file, + shared_file_list, args, worker_init, + tokenizer) + + kl_loss_fct = paddle.nn.KLDivLoss('sum') + train_cost_avg = TimeCostAverage() + total_samples = 0 + batch_start = time.time() + for step, batch in enumerate(train_data_loader): + global_step += 1 + input_ids = batch[0] + attention_mask = paddle.unsqueeze( + (input_ids == pad_token_id + ).astype(paddle.get_default_dtype()) * -1e9, + axis=[1, 2]) + with paddle.amp.auto_cast( + args.use_amp, + custom_white_list=["layer_norm", "gelu", "softmax"]): + student(input_ids) + with paddle.no_grad(): + teacher(input_ids) + # Q-Q relation + q_t, q_s = teacher.outputs.q, student.outputs.q + batch_size = q_t.shape[0] + pad_seq_len = q_t.shape[2] + loss_qr1, loss_qr2, loss_qr3 = calc_minilm_loss_multi_relation( + kl_loss_fct, q_s, q_t, attention_mask, + args.num_relation_heads, args.alpha, args.beta) + del q_t, q_s + # K-K relation + k_t, k_s = teacher.outputs.k, student.outputs.k + loss_kr1, loss_kr2, loss_kr3 = calc_minilm_loss_multi_relation( + kl_loss_fct, k_s, k_t, attention_mask, + args.num_relation_heads, args.alpha, args.beta) + del k_t, k_s + + # V-V relation + v_t, v_s = teacher.outputs.v, student.outputs.v + loss_vr1, loss_vr2, loss_vr3 = calc_minilm_loss_multi_relation( + kl_loss_fct, v_s, v_t, attention_mask, + 
args.num_relation_heads, args.alpha, args.beta) + + del v_t, v_s + + loss1 = (loss_qr1 + loss_kr1 + loss_vr1) + loss1 /= args.num_relation_heads * pad_seq_len * batch_size + + loss2 = loss_qr2 + loss_kr2 + loss_vr2 + loss2 /= args.num_relation_heads * pad_seq_len * batch_size + + loss3 = loss_qr3 + loss_kr3 + loss_vr3 + loss3 /= args.num_relation_heads * pad_seq_len * batch_size + loss = (1 - args.alpha - args.beta + ) * loss1 + loss2 * args.alpha + loss3 * args.beta + + if args.use_amp: + scaler.scale(loss).backward() + scaler.minimize(optimizer, loss) + else: + loss.backward() + + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + + total_samples += args.batch_size + train_run_cost = time.time() - batch_start + train_cost_avg.record(train_run_cost) + if global_step % args.logging_steps == 0: + logger.info( + "global step: %d, epoch: %d, batch: %d, loss: %f, loss1: %f, loss2: %f, loss3: %f," + "lr: %f, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec" + % (global_step, epoch, step, loss, loss1, loss2, loss3, + optimizer.get_lr(), train_cost_avg.get_average(), + total_samples / args.logging_steps, total_samples / + (args.logging_steps * train_cost_avg.get_average()))) + total_samples = 0 + train_cost_avg.reset() + if global_step % args.save_steps == 0 or global_step == num_training_steps: + if paddle.distributed.get_rank() == 0: + output_dir = os.path.join(args.output_dir, + "model_%d" % global_step) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # need better way to get inner model of DataParallel + model_to_save = student._layers if isinstance( + student, paddle.DataParallel) else student + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + paddle.save( + optimizer.state_dict(), + os.path.join(output_dir, "model_state.pdopt")) + if global_step >= args.max_steps: + del train_data_loader + return + batch_start = time.time() + + del train_data_loader + train_data_loader, data_file = dataset_future.result(timeout=None) + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + do_train(args) diff --git a/examples/model_compression/PP-MiniLM/general_distill/run.sh b/examples/model_compression/PP-MiniLM/general_distill/run.sh new file mode 100644 index 000000000000..12ae63f7971b --- /dev/null +++ b/examples/model_compression/PP-MiniLM/general_distill/run.sh @@ -0,0 +1,55 @@ +set -eux + +unset CUDA_VISIBLE_DEVICES + +bs=128 +maxlen=128 +numH=64 +lr=6e-4 +maxStep=400000 +warmStep=4000 +wd=1e-2 + +teacher=roberta +teacherModel=roberta-wwm-ext-large + +alpha=0 +beta=1.0 +mode=hardest +use_amp=True +teacher_layer_index=19 +student_layer_index=5 +num_layers=6 + +hp_config=bs${bs}_maxlen${maxlen}_lr${lr}_wd${wd}_numH${numH}_maxStep${maxStep}_warmStep${warmStep}_adamW_maxnorm1p0_teacher_${teacherModel}_coldboot_teacher_vocab_index${teacher_layer_index}_4l-312d-batchbatch + +export PYTHONPATH=../../../../:$PYTHONPATH +output_dir="./pretrain_${hp_config}" + +mkdir -p ${output_dir} +cp ./general_distill.py ${output_dir}/ +cp ../../../../paddlenlp/transformers/distill_utils.py ${output_dir}/ + + +python3 -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" general_distill.py \ + --model_type ernie \ + --num_relation_heads ${numH} \ + 
--teacher_model_type ${teacher} \ + --teacher_layer_index ${teacher_layer_index} \ + --student_layer_index ${student_layer_index} \ + --teacher_model_name_or_path ${teacherModel} \ + --max_seq_length ${maxlen} \ + --num_layers ${num_layers} \ + --batch_size ${bs} \ + --learning_rate ${lr} \ + --logging_steps 20 \ + --max_steps ${maxStep} \ + --warmup_steps ${warmStep} \ + --save_steps 20000 \ + --weight_decay ${wd} \ + --output_dir ${output_dir} \ + --device gpu \ + --input_dir dataset/ \ + --use_amp ${use_amp} \ + --alpha ${alpha} \ + --beta ${beta} \ diff --git a/examples/model_compression/PP-MiniLM/inference/infer.py b/examples/model_compression/PP-MiniLM/inference/infer.py new file mode 100644 index 000000000000..9cc8422119a5 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/inference/infer.py @@ -0,0 +1,311 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import time +from functools import partial +import numpy as np + +import paddle +from paddle import inference +from paddlenlp.datasets import load_dataset +from paddlenlp.data import Stack, Tuple, Pad +from paddle.metric import Metric, Accuracy, Precision, Recall +from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman + +from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer + +METRIC_CLASSES = { + "afqmc": Accuracy, + "tnews": Accuracy, + "iflytek": Accuracy, + "ocnli": Accuracy, + "cmnli": Accuracy, + "cluewsc2020": Accuracy, + "csl": Accuracy, +} + +MODEL_CLASSES = {"ernie": (ErnieForSequenceClassification, ErnieTokenizer), } + + +def convert_example(example, + tokenizer, + label_list, + max_seq_length=512, + is_test=False): + """convert a glue example into necessary features""" + if not is_test: + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + label = example['label'] + label = np.array([label], dtype=label_dtype) + # Convert raw text to feature + if 'sentence' in example: + example = tokenizer(example['sentence'], max_seq_len=max_seq_length) + elif 'sentence1' in example: + example = tokenizer( + example['sentence1'], + text_pair=example['sentence2'], + max_seq_len=max_seq_length) + elif 'keyword' in example: # CSL + sentence1 = " ".join(example['keyword']) + example = tokenizer( + sentence1, text_pair=example['abst'], max_seq_len=max_seq_length) + elif 'target' in example: # wsc + text, query, pronoun, query_idx, pronoun_idx = example['text'], example[ + 'target']['span1_text'], example['target']['span2_text'], example[ + 'target']['span1_index'], example['target']['span2_index'] + text_list = list(text) + assert text[pronoun_idx:(pronoun_idx + len(pronoun) + )] == pronoun, "pronoun: {}".format(pronoun) + assert text[query_idx:(query_idx + len(query) + )] == query, "query: {}".format(query) + if pronoun_idx > query_idx: + text_list.insert(query_idx, "_") + text_list.insert(query_idx + len(query) + 1, "_") + 
text_list.insert(pronoun_idx + 2, "[") + text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]") + else: + text_list.insert(pronoun_idx, "[") + text_list.insert(pronoun_idx + len(pronoun) + 1, "]") + text_list.insert(query_idx + 2, "_") + text_list.insert(query_idx + len(query) + 2 + 1, "_") + text = "".join(text_list) + example = tokenizer(text, max_seq_len=max_seq_length) + + if not is_test: + return example['input_ids'], example['token_type_ids'], label + else: + return example['input_ids'], example['token_type_ids'] + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--task_name", + default='afqmc', + type=str, + help="The name of the task to perform predict, selected in the list: " + + ", ".join(METRIC_CLASSES.keys()), ) + parser.add_argument( + "--model_type", + default='ernie', + type=str, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--tokenizer_path", + default='../general_distill/ernie-batchbatch-50w_400000/', + type=str, + required=True, + help="The directory for tokenizer.", ) + parser.add_argument( + "--model_path", + default='./quant_models/model', + type=str, + required=True, + help="The path prefix of inference model to be used.", ) + parser.add_argument( + "--device", + default="gpu", + choices=["gpu", "cpu", "xpu"], + help="Device selected for inference.", ) + parser.add_argument( + "--batch_size", + default=32, + type=int, + help="Batch size for predict.", ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--use_trt", + action='store_true', + help="Whether to use inference engin TensorRT.", ) + + parser.add_argument( + "--collect_shape", + action='store_true', + help="Whether collect shape range info.", ) + parser.add_argument( + "--int8", + action='store_true', + help="Whether int8 inference.", ) + args = parser.parse_args() + return args + + +@paddle.no_grad() +def evaluate(outputs, metric, data_loader): + metric.reset() + for i, batch in enumerate(data_loader): + input_ids, segment_ids, labels = batch + logits = paddle.to_tensor(outputs[i][0]) + correct = metric.compute(logits, labels) + metric.update(correct) + res = metric.accumulate() + print("acc: %s, " % res, end='') + + +class Predictor(object): + def __init__(self, predictor, input_handles, output_handles): + self.predictor = predictor + self.input_handles = input_handles + self.output_handles = output_handles + + @classmethod + def create_predictor(cls, args): + config = paddle.inference.Config(args.model_path + ".pdmodel", + args.model_path + ".pdiparams") + if args.device == "gpu": + # set GPU configs accordingly + config.enable_use_gpu(100, 0) + cls.device = paddle.set_device("gpu") + elif args.device == "cpu": + # set CPU configs accordingly, + # such as enable_mkldnn, set_cpu_math_library_num_threads + config.disable_gpu() + cls.device = paddle.set_device("cpu") + elif args.device == "xpu": + # set XPU configs accordingly + config.enable_xpu(100) + config.switch_use_feed_fetch_ops(False) # could be deleted + if args.use_trt: + if args.int8: + print("int8") + config.enable_tensorrt_engine( + workspace_size=1 << 30, + precision_mode=inference.PrecisionType.Int8, + max_batch_size=args.batch_size, + min_subgraph_size=5, + use_static=False, + use_calib_mode=False) + else: + 
config.enable_tensorrt_engine( + workspace_size=1 << 30, + precision_mode=inference.PrecisionType.Float32, + max_batch_size=args.batch_size, + min_subgraph_size=5, + use_static=False, + use_calib_mode=False) + print("Enable TensorRT is: {}".format( + config.tensorrt_engine_enabled())) + if args.collect_shape: + config.collect_shape_range_info( + os.path.join( + os.path.dirname(args.model_path), args.task_name + + '_shape_range_info.pbtxt')) + else: + config.enable_tuned_tensorrt_dynamic_shape( + os.path.join( + os.path.dirname(args.model_path), + args.task_name + "_shape_range_info.pbtxt"), True) + predictor = paddle.inference.create_predictor(config) + input_handles = [ + predictor.get_input_handle(name) + for name in predictor.get_input_names() + ] + output_handles = [ + predictor.get_output_handle(name) + for name in predictor.get_output_names() + ] + cls.time = 0.0 + + return cls(predictor, input_handles, output_handles) + + def predict_batch(self, data): + for input_field, input_handle in zip(data, self.input_handles): + input_handle.copy_from_cpu(input_field.numpy() if isinstance( + input_field, paddle.Tensor) else input_field) + time1 = time.time() + self.predictor.run() + paddle.fluid.core._cuda_synchronize(self.device) + self.time += time.time() - time1 + output = [ + output_handle.copy_to_cpu() for output_handle in self.output_handles + ] + + return output + + def predict(self, dataset, collate_fn, args, batch_size=1): + metric = METRIC_CLASSES[args.task_name]() + batch_sampler = paddle.io.BatchSampler( + dataset, batch_size=batch_size, shuffle=False) + data_loader = paddle.io.DataLoader( + dataset=dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=0, + return_list=True) + outputs = [] + metric.reset() + for i, data in enumerate(data_loader): + # warmup for performance test + if i < 20: + continue + if len(data) == 2: + output = self.predict_batch(data) + else: + output = self.predict_batch([data[0], data[1]]) + logits = paddle.to_tensor(output) + correct = metric.compute(logits, data[2]) + metric.update(correct) + outputs.append(output) + if len(data) > 2: + res = metric.accumulate() + print("task name: %s, acc: %s, " % (args.task_name, res), end='') + print("time: ", self.time) + + return outputs + + +def main(): + paddle.seed(42) + args = parse_args() + + args.task_name = args.task_name.lower() + args.model_type = args.model_type.lower() + + predictor = Predictor.create_predictor(args) + + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + dev_ds = load_dataset('clue', args.task_name, splits='dev') + tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path) + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=dev_ds.label_list, + max_seq_length=args.max_seq_length, + is_test=False) + + dev_ds = dev_ds.map(trans_func, lazy=True) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment + Stack(dtype="int64" if dev_ds.label_list else "float32") # label + ): fn(samples) + outputs = predictor.predict( + dev_ds, batch_size=args.batch_size, collate_fn=batchify_fn, args=args) + + +if __name__ == "__main__": + main() diff --git a/examples/model_compression/PP-MiniLM/inference/infer.sh b/examples/model_compression/PP-MiniLM/inference/infer.sh new file mode 100644 index 000000000000..b50f23e37fd7 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/inference/infer.sh @@ -0,0 +1,26 @@ +echo 原来的模型 +python 
infer.py --task_name tnews --model_path tnews/float --use_trt --collect_shape
python infer.py --task_name tnews --model_path tnews/float --use_trt
python infer.py --task_name tnews --model_path tnews/float --use_trt
python infer.py --task_name tnews --model_path tnews/float --use_trt
python infer.py --task_name tnews --model_path tnews/float --use_trt
python infer.py --task_name tnews --model_path tnews/float --use_trt


echo after pruning
python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt --collect_shape
python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt


echo INT8 inference
task=tnews  # the TNEWS model quantized with the mse strategy and a calibration set of 4, as in the README performance test
python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt --collect_shape
python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
 diff --git a/examples/model_compression/PP-MiniLM/inference/infer_all.sh b/examples/model_compression/PP-MiniLM/inference/infer_all.sh new file mode 100644 index 000000000000..ef1d9bfbfcfb --- /dev/null +++ b/examples/model_compression/PP-MiniLM/inference/infer_all.sh @@ -0,0 +1,12 @@
for task in afqmc tnews iflytek cmnli ocnli cluewsc2020 csl
do
    for bs in 4 8
    do
        for algo in abs_max avg hist mse
        do
            python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt --collect_shape
            python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt
            echo this is ${task}, ${algo}, ${bs}
        done
    done
done
 diff --git a/examples/model_compression/PP-MiniLM/ofa/export.sh b/examples/model_compression/PP-MiniLM/ofa/export.sh new file mode 100644 index 000000000000..74b7c3002db4 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/ofa/export.sh @@ -0,0 +1,7 @@
MODEL_PATH=$1
TASK_NAME=$2
python export_model.py --model_type ernie \
    --model_name_or_path ${MODEL_PATH}/${TASK_NAME}/0.75/best_model \
    --sub_model_output_dir ${MODEL_PATH}/${TASK_NAME}/0.75/sub/ \
    --static_sub_model ${MODEL_PATH}/${TASK_NAME}/0.75/sub_static/float \
    --n_gpu 1 --width_mult 0.75
 diff --git a/examples/model_compression/PP-MiniLM/ofa/export_all.sh b/examples/model_compression/PP-MiniLM/ofa/export_all.sh new file mode 100644 index 000000000000..8e295fca5c05 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/ofa/export_all.sh @@ -0,0 +1,12 @@
MODEL_PATH=ofa_models

for TASK_NAME in AFQMC TNEWS IFLYTEK CMNLI OCNLI CLUEWSC2020 CSL
do
    python export_model.py --model_type ernie \
        --model_name_or_path ${MODEL_PATH}/${TASK_NAME}/0.75/best_model \
        --sub_model_output_dir
${MODEL_PATH}/${TASK_NAME}/0.75/sub/ \ + --static_sub_model ${MODEL_PATH}/${TASK_NAME}/0.75/sub_static/float \ + --n_gpu 1 --width_mult 0.75 + +done diff --git a/examples/model_compression/PP-MiniLM/ofa/export_model.py b/examples/model_compression/PP-MiniLM/ofa/export_model.py new file mode 100644 index 000000000000..c9edb13e6301 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/ofa/export_model.py @@ -0,0 +1,226 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import math +import random +import time +import json +from functools import partial + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddlenlp.transformers import BertModel, BertForSequenceClassification, BertTokenizer +from paddlenlp.transformers import ErnieModel, ErnieForSequenceClassification, ErnieTokenizer +from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer +from paddlenlp.utils.log import logger +from paddleslim.nas.ofa import OFA, utils +from paddleslim.nas.ofa.convert_super import Convert, supernet +from paddleslim.nas.ofa.layers import BaseBlock + +MODEL_CLASSES = { + "bert": (BertForSequenceClassification, BertTokenizer), + "ernie": (ErnieForSequenceClassification, ErnieTokenizer), +} + + +def ernie_forward(self, + input_ids, + token_type_ids=None, + position_ids=None, + attention_mask=None): + wtype = self.pooler.dense.fn.weight.dtype if hasattr( + self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype + if attention_mask is None: + attention_mask = paddle.unsqueeze( + (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2]) + embedding_output = self.embeddings(input_ids, token_type_ids, position_ids) + encoded_layer = self.encoder(embedding_output, attention_mask) + pooled_output = self.pooler(encoded_layer) + + return encoded_layer, pooled_output + + +ErnieModel.forward = ernie_forward + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--sub_model_output_dir", + default=None, + type=str, + required=True, + help="The output directory where the sub model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--static_sub_model", + default=None, + type=str, + help="The output directory where the sub static model will be written. 
If not set, the static model will not be exported.", )
+    parser.add_argument(
+        "--max_seq_length",
+        default=128,
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+        "than this will be truncated, sequences shorter will be padded.", )
+    parser.add_argument(
+        "--n_gpu",
+        type=int,
+        default=1,
+        help="Number of GPUs to use; 0 means CPU.")
+    parser.add_argument(
+        '--width_mult',
+        type=float,
+        default=1.0,
+        help="Width multiplier of the sub-model to export.")
+    parser.add_argument(
+        '--depth_mult',
+        type=float,
+        default=1.0,
+        help="Depth multiplier of the sub-model to export.")
+    args = parser.parse_args()
+    return args
+
+
+def export_static_model(model, model_path, max_seq_length):
+    input_shape = [
+        paddle.static.InputSpec(
+            shape=[None, max_seq_length], dtype='int64'),
+        paddle.static.InputSpec(
+            shape=[None, max_seq_length], dtype='int64')
+    ]
+    net = paddle.jit.to_static(model, input_spec=input_shape)
+    paddle.jit.save(net, model_path)
+
+
+def do_train(args):
+    paddle.set_device("gpu" if args.n_gpu else "cpu")
+    args.model_type = args.model_type.lower()
+    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config_path = os.path.join(args.model_name_or_path, 'model_config.json')
+    cfg_dict = dict(json.loads(open(config_path).read()))
+
+    kept_layers_index = {}
+    if args.depth_mult < 1.0:
+        depth = round(cfg_dict["init_args"][0]['num_hidden_layers'] *
+                      args.depth_mult)
+        cfg_dict["init_args"][0]['num_hidden_layers'] = depth
+        for idx, i in enumerate(range(1, depth + 1)):
+            kept_layers_index[idx] = math.floor(i / args.depth_mult) - 1
+
+    os.rename(config_path, config_path + '_bak')
+    with open(config_path, "w", encoding="utf-8") as f:
+        f.write(json.dumps(cfg_dict, ensure_ascii=False))
+
+    num_labels = cfg_dict['num_classes']
+
+    model = model_class.from_pretrained(
+        args.model_name_or_path, num_classes=num_labels)
+
+    origin_model = model_class.from_pretrained(
+        args.model_name_or_path, num_classes=num_labels)
+
+    os.rename(config_path + '_bak', config_path)
+
+    sp_config = supernet(expand_ratio=[1.0, args.width_mult])
+    model = Convert(sp_config).convert(model)
+
+    ofa_model = OFA(model)
+
+    sd = paddle.load(
+        os.path.join(args.model_name_or_path, 'model_state.pdparams'))
+
+    if len(kept_layers_index) == 0:
+        ofa_model.model.set_state_dict(sd)
+    else:
+        for name, params in ofa_model.model.named_parameters():
+            if 'encoder' not in name:
+                params.set_value(sd[name])
+            else:
+                idx = int(name.strip().split('.')[3])
+                mapping_name = name.replace(
+                    '.' + str(idx) + '.',
+                    '.'
+ str(kept_layers_index[idx]) + '.') + params.set_value(sd[mapping_name]) + + best_config = utils.dynabert_config(ofa_model, args.width_mult) + for name, sublayer in ofa_model.model.named_sublayers(): + if isinstance(sublayer, paddle.nn.MultiHeadAttention): + sublayer.num_heads = int(args.width_mult * sublayer.num_heads) + + #for name, params in origin_model.named_parameters(): + # print(name, params.name) + origin_model_new = ofa_model.export( + best_config, + input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]], + input_dtypes=['int64', 'int64'], + origin_model=origin_model) + for name, sublayer in origin_model_new.named_sublayers(): + if isinstance(sublayer, paddle.nn.MultiHeadAttention): + sublayer.num_heads = int(args.width_mult * sublayer.num_heads) + + output_dir = os.path.join(args.sub_model_output_dir, + "model_width_%.5f" % args.width_mult) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = origin_model_new + model_to_save.save_pretrained(output_dir) + #print(origin_model_new.state_dict().keys()) + #print("=====================") + #for name, params in origin_model_new.named_parameters(): + # print(name, params.name) + if args.static_sub_model != None: + export_static_model(origin_model_new, args.static_sub_model, + args.max_seq_length) + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + do_train(args) diff --git a/examples/model_compression/PP-MiniLM/ofa/run_ofa.py b/examples/model_compression/PP-MiniLM/ofa/run_ofa.py new file mode 100644 index 000000000000..94a431d3bee7 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/ofa/run_ofa.py @@ -0,0 +1,541 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
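+
+# This script applies PaddleSlim OFA (DynaBERT-style width pruning) to a
+# fine-tuned PP-MiniLM task model: the model is converted into a supernet,
+# attention heads and FFN neurons are reordered by importance, and the
+# supernet is then trained at the widths given in `--width_mult_list` while
+# distilling intermediate representations and logits from the original
+# full-width model, which serves as the teacher.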
+ +import argparse +import logging +import os +import random +import time +import math +from functools import partial + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.io import DataLoader +from paddle.metric import Accuracy, Precision, Recall + +from paddlenlp.data import Stack, Tuple, Pad, Dict +from paddlenlp.datasets import load_dataset +from paddlenlp.data.sampler import SamplerHelper +from paddlenlp.transformers import TinyBertModel, TinyBertForSequenceClassification, TinyBertTokenizer +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.utils.log import logger +from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer +from paddlenlp.transformers import TinyBertForSequenceClassification, TinyBertTokenizer +from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer, ErnieModel +from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer +from paddleslim.nas.ofa import OFA, DistillConfig, utils +from paddleslim.nas.ofa.utils import nlp_utils +from paddleslim.nas.ofa.convert_super import Convert, supernet + +METRIC_CLASSES = { + "afqmc": Accuracy, + "tnews": Accuracy, + "iflytek": Accuracy, + "ocnli": Accuracy, + "cmnli": Accuracy, + "cluewsc2020": Accuracy, + "csl": Accuracy, +} + +MODEL_CLASSES = { + "bert": (BertForSequenceClassification, BertTokenizer), + "roberta": (RobertaForSequenceClassification, RobertaTokenizer), + "tinybert": (TinyBertForSequenceClassification, TinyBertTokenizer), + "ernie": (ErnieForSequenceClassification, ErnieTokenizer), +} + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--task_name", + default=None, + type=str, + required=True, + help="The name of the task to train selected in the list: " + + ", ".join(METRIC_CLASSES.keys()), ) + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--glue_dir", + default="/root/.paddlenlp/datasets/Clue/", + type=str, + required=False, + help="The Glue directory.", ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. 
Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--batch_size", + default=8, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--learning_rate", + default=5e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--weight_decay", + default=0.0, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--adam_epsilon", + default=1e-8, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--lambda_logit", + default=1.0, + type=float, + help="lambda for logit loss.") + parser.add_argument( + "--num_train_epochs", + default=3, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument( + "--warmup_steps", + default=0, + type=int, + help="Linear warmup over warmup_steps.") + parser.add_argument( + "--warmup_proportion", + default=0.1, + type=float, + help="Linear warmup proportion over total steps.") + parser.add_argument( + "--logging_steps", + type=int, + default=500, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=500, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument( + "--device", + default="gpu", + type=str, + choices=["gpu", "cpu", "xpu"], + help="The device to select to train the model, is must be cpu/gpu/xpu.") + parser.add_argument( + '--width_mult_list', + nargs='+', + type=float, + default=[1.0, 5 / 6, 2 / 3, 0.5], + help="width mult in compress") + args = parser.parse_args() + return args + + +def set_seed(args): + # Use the same data seed(for data shuffle) for all procs to guarantee data + # consistency after sharding. + random.seed(args.seed) + np.random.seed(args.seed) + # Maybe different op seeds(for dropout) for different procs is better. 
By: + # `paddle.seed(args.seed + paddle.distributed.get_rank())` + paddle.seed(args.seed) + + +@paddle.no_grad() +def evaluate(model, metric, data_loader, width_mult, student=False): + model.eval() + metric.reset() + for i, batch in enumerate(data_loader): + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids, attention_mask=[None, None]) + #print(logits) + #sys.exit() + #if student: + #print(batch) + #print(logits) + #import pdb; pdb.set_trace() # ofa_model和保存下来的超网络model不对 + + #sys.exit() + if isinstance(logits, tuple): + logits = logits[0] + correct = metric.compute(logits, labels) + metric.update(correct) + + res = metric.accumulate() + print("width_mult: %s, acc: %s, " % (str(width_mult), res), end='') + model.train() + return res + + +### monkey patch for bert forward to accept [attention_mask, head_mask] as attention_mask +def ernie_forward(self, + input_ids, + token_type_ids=None, + position_ids=None, + attention_mask=[None, None]): + wtype = self.pooler.dense.fn.weight.dtype if hasattr( + self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype + if attention_mask[0] is None: + attention_mask[0] = paddle.unsqueeze( + (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2]) + embedding_output = self.embeddings(input_ids, token_type_ids, position_ids) + encoded_layer = self.encoder(embedding_output, attention_mask) + pooled_output = self.pooler(encoded_layer) + + return encoded_layer, pooled_output + + +ErnieModel.forward = ernie_forward + + +### reorder weights according head importance and neuron importance +def reorder_neuron_head(model, head_importance, neuron_importance): + # reorder heads and ffn neurons + for layer, current_importance in enumerate(neuron_importance): + # reorder heads + idx = paddle.argsort(head_importance[layer], descending=True) + nlp_utils.reorder_head(model.ernie.encoder.layers[layer].self_attn, idx) + # reorder neurons + idx = paddle.argsort( + paddle.to_tensor(current_importance), descending=True) + nlp_utils.reorder_neuron( + model.ernie.encoder.layers[layer].linear1.fn, idx, dim=1) + nlp_utils.reorder_neuron( + model.ernie.encoder.layers[layer].linear2.fn, idx, dim=0) + + +def soft_cross_entropy(inp, target): + inp_likelihood = F.log_softmax(inp, axis=-1) + target_prob = F.softmax(target, axis=-1) + return -1. 
* paddle.mean(paddle.sum(inp_likelihood * target_prob, axis=-1)) + + +def convert_example(example, + tokenizer, + label_list, + max_seq_length=512, + is_test=False): + """convert a glue example into necessary features""" + if not is_test: + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + label = example['label'] + label = np.array([label], dtype=label_dtype) + # Convert raw text to feature + if 'sentence' in example: + example = tokenizer(example['sentence'], max_seq_len=max_seq_length) + elif 'sentence1' in example: + example = tokenizer( + example['sentence1'], + text_pair=example['sentence2'], + max_seq_len=max_seq_length) + elif 'keyword' in example: # CSL + sentence1 = " ".join(example['keyword']) + example = tokenizer( + sentence1, text_pair=example['abst'], max_seq_len=max_seq_length) + elif 'target' in example: # wsc + text, query, pronoun, query_idx, pronoun_idx = example['text'], example[ + 'target']['span1_text'], example['target']['span2_text'], example[ + 'target']['span1_index'], example['target']['span2_index'] + text_list = list(text) + # print(text) + assert text[pronoun_idx:(pronoun_idx + len(pronoun) + )] == pronoun, "pronoun: {}".format(pronoun) + assert text[query_idx:(query_idx + len(query) + )] == query, "query: {}".format(query) + if pronoun_idx > query_idx: + text_list.insert(query_idx, "_") + text_list.insert(query_idx + len(query) + 1, "_") + text_list.insert(pronoun_idx + 2, "[") + text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]") + else: + text_list.insert(pronoun_idx, "[") + text_list.insert(pronoun_idx + len(pronoun) + 1, "]") + text_list.insert(query_idx + 2, "_") + text_list.insert(query_idx + len(query) + 2 + 1, "_") + text = "".join(text_list) + # print(text) + example = tokenizer(text, max_seq_len=max_seq_length) + + if not is_test: + return example['input_ids'], example['token_type_ids'], label + else: + return example['input_ids'], example['token_type_ids'] + + +def do_train(args): + paddle.set_device(args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + + args.task_name = args.task_name.lower() + metric_class = METRIC_CLASSES[args.task_name] + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + train_ds = load_dataset('clue', args.task_name, splits='train') + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_ds.label_list, + max_seq_length=args.max_seq_length) + train_ds = train_ds.map(trans_func, lazy=True) + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_ds, batch_size=args.batch_size, shuffle=True) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment + Stack(dtype="int64" if train_ds.label_list else "float32") # label + ): fn(samples) + + train_data_loader = DataLoader( + dataset=train_ds, + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + dev_ds = load_dataset('clue', args.task_name, splits='dev') + dev_ds = dev_ds.map(trans_func, lazy=True) + dev_batch_sampler = paddle.io.BatchSampler( + dev_ds, batch_size=args.batch_size, shuffle=False) + dev_data_loader = DataLoader( + dataset=dev_ds, + batch_sampler=dev_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + 
return_list=True) + num_labels = 1 if train_ds.label_list == None else len(train_ds.label_list) + + model = model_class.from_pretrained( + args.model_name_or_path, num_classes=num_labels) + + # Step1: Initialize a dictionary to save the weights from the origin BERT model. + origin_weights = model.state_dict() + + # Step2: Convert origin model to supernet. + sp_config = supernet( + expand_ratio=[1.0]) #expand_ratio=args.width_mult_list) + model = Convert(sp_config).convert(model) + # Use weights saved in the dictionary to initialize supernet. + utils.set_state_dict(model, origin_weights) + del origin_weights + + super_sd = paddle.load( + os.path.join(args.model_name_or_path, 'model_state.pdparams')) + model.set_state_dict(super_sd) + + # Step3: Define teacher model. + #print(args.model_name_or_path, "**************") + teacher_model = model_class.from_pretrained( + args.model_name_or_path, num_classes=num_labels) + + # Step4: Config about distillation. + mapping_layers = ['ernie.embeddings'] + for idx in range(model.ernie.config['num_hidden_layers']): + mapping_layers.append('ernie.encoder.layers.{}'.format(idx)) + + default_distill_config = { + 'lambda_distill': 0.1, + 'teacher_model': teacher_model, + 'mapping_layers': mapping_layers, + } + distill_config = DistillConfig(**default_distill_config) + + # Step5: Config in supernet training. + ofa_model = OFA(model, + distill_config=distill_config, + elastic_order=['width']) + + criterion = paddle.nn.loss.CrossEntropyLoss( + ) if train_ds.label_list else paddle.nn.loss.MSELoss() + + metric = metric_class() + + #### Step6: Calculate the importance of neurons and head, + #### and then reorder them according to the importance. + head_importance, neuron_importance = nlp_utils.compute_neuron_head_importance( + args.task_name, + ofa_model.model, + dev_data_loader, + loss_fct=criterion, + num_layers=model.ernie.config['num_hidden_layers'], + num_heads=model.ernie.config['num_attention_heads']) + reorder_neuron_head(ofa_model.model, head_importance, neuron_importance) + + if paddle.distributed.get_world_size() > 1: + ofa_model.model = paddle.DataParallel(ofa_model.model) + + if args.max_steps > 0: + num_training_steps = args.max_steps + num_train_epochs = math.ceil(num_training_steps / + len(train_data_loader)) + else: + num_training_steps = len(train_data_loader) * args.num_train_epochs + num_train_epochs = args.num_train_epochs + + warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion + + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, + warmup) + + # Generate parameter names needed to perform weight decay. + # All bias and LayerNorm parameters are excluded. + decay_params = [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=args.adam_epsilon, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_params, + grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm)) + + global_step = 0 + tic_train = time.time() + best_res = 0.0 + for epoch in range(num_train_epochs): + # Step7: Set current epoch and task. 
+ ofa_model.set_epoch(epoch) + ofa_model.set_task('width') + + for step, batch in enumerate(train_data_loader): + global_step += 1 + input_ids, segment_ids, labels = batch + + for width_mult in args.width_mult_list: + # Step8: Broadcast supernet config from width_mult, + # and use this config in supernet training. + net_config = utils.dynabert_config(ofa_model, width_mult) + ofa_model.set_net_config(net_config) + logits, teacher_logits = ofa_model( + input_ids, segment_ids, attention_mask=[None, None]) + rep_loss = ofa_model.calc_distill_loss() + if args.task_name == 'sts-b': + logit_loss = paddle.zeros(shape=[1], dtype='float32') + else: + logit_loss = soft_cross_entropy(logits, + teacher_logits.detach()) + loss = rep_loss + args.lambda_logit * logit_loss + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + + if global_step % args.logging_steps == 0: + logger.info( + "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s" + % (global_step, epoch, step, loss, + args.logging_steps / (time.time() - tic_train))) + tic_train = time.time() + + if global_step % args.save_steps == 0 or global_step == num_training_steps: + tic_eval = time.time() + evaluate(teacher_model, metric, dev_data_loader, width_mult=100) + print("eval done total : %s s" % (time.time() - tic_eval)) + for idx, width_mult in enumerate(args.width_mult_list): + net_config = utils.dynabert_config(ofa_model, width_mult) + #print(net_config) + ofa_model.set_net_config(net_config) + tic_eval = time.time() + res = evaluate(ofa_model, metric, dev_data_loader, + width_mult) + print("eval done total : %s s" % (time.time() - tic_eval)) + + if best_res < res: + output_dir = args.output_dir + #output_dir = os.path.join(args.output_dir, + #"model_%d" % global_step) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + print("res saved", res) + #res = evaluate(ofa_model, metric, + #dev_data_loader, width_mult, True) + best_res = res + #sys.exit() + if global_step >= num_training_steps: + print("best_res: ", best_res) + return + print("best_res: ", best_res) + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + do_train(args) diff --git a/examples/model_compression/PP-MiniLM/ofa/run_ofa.sh b/examples/model_compression/PP-MiniLM/ofa/run_ofa.sh new file mode 100644 index 000000000000..011ef885d3fd --- /dev/null +++ b/examples/model_compression/PP-MiniLM/ofa/run_ofa.sh @@ -0,0 +1,19 @@ +export TASK_NAME=$1 +export LR=$2 +export BATCH_SIZE=$3 +export PRE_EPOCHS=$4 +export SEQ_LEN=$5 +export CUDA_VISIBLE_DEVICES=$6 +export STUDENT_DIR=$7 + +python -u ./run_ofa.py --model_type ernie \ + --model_name_or_path ${STUDENT_DIR} \ + --task_name $TASK_NAME --max_seq_length ${SEQ_LEN} \ + --batch_size ${BATCH_SIZE} \ + --learning_rate ${LR} \ + --num_train_epochs ${PRE_EPOCHS} \ + --logging_steps 100 \ + --save_steps 100 \ + --output_dir ./ofa_models/$TASK_NAME/0.75/best_model/ \ + --device gpu \ + --width_mult_list 0.75 diff --git a/examples/model_compression/PP-MiniLM/pp-minilm.png 
b/examples/model_compression/PP-MiniLM/pp-minilm.png
new file mode 100644
index 000000000000..148c7406a783
Binary files /dev/null and b/examples/model_compression/PP-MiniLM/pp-minilm.png differ
diff --git a/examples/model_compression/PP-MiniLM/quantization/quant_all.sh b/examples/model_compression/PP-MiniLM/quantization/quant_all.sh
new file mode 100644
index 000000000000..4414c3d6a621
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/quantization/quant_all.sh
@@ -0,0 +1,8 @@
+MODEL_DIR=../ofa/ofa_models/
+for task in AFQMC TNEWS IFLYTEK CMNLI OCNLI CLUEWSC2020 CSL
+
+do
+
+    python quant_post.py --task_name ${task} --input_dir ${MODEL_DIR}/${task}/0.75/sub_static
+
+done
diff --git a/examples/model_compression/PP-MiniLM/quantization/quant_post.py b/examples/model_compression/PP-MiniLM/quantization/quant_post.py
new file mode 100644
index 000000000000..c8d05612a166
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/quantization/quant_post.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+import sys
+import os
+import time
+import argparse
+from functools import partial
+import numpy as np
+
+import paddle
+from paddle.metric import Accuracy
+
+import paddlenlp
+import paddleslim
+from paddlenlp.data import Stack, Tuple, Pad, Dict
+from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+    "--task_name",
+    type=str,
+    default="afqmc",
+    required=False,
+    help="Name of the CLUE task whose dev set is used for calibration.")
+parser.add_argument(
+    "--input_dir",
+    type=str,
+    default="afqmc",
+    required=False,
+    help="Directory of the static (float) task model to be quantized.")
+
+args = parser.parse_args()
+
+METRIC_CLASSES = {
+    "afqmc": Accuracy,
+    "tnews": Accuracy,
+    "iflytek": Accuracy,
+    "ocnli": Accuracy,
+    "cmnli": Accuracy,
+    "cluewsc2020": Accuracy,
+    "csl": Accuracy,
+}
+
+MODEL_CLASSES = {"ernie": (ErnieForSequenceClassification, ErnieTokenizer), }
+
+
+def convert_example(example,
+                    tokenizer,
+                    label_list,
+                    max_seq_length=512,
+                    is_test=False):
+    """Convert a CLUE example into model inputs (and a label when not testing)."""
+    if not is_test:
+        # `label_list == None` is for regression task
+        label_dtype = "int64" if label_list else "float32"
+        # Get the label
+        label = example['label']
+        label = np.array([label], dtype=label_dtype)
+    # Convert raw text to feature
+    if 'sentence' in example:
+        example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
+    elif 'sentence1' in example:
+        example = tokenizer(
+            example['sentence1'],
+            text_pair=example['sentence2'],
+            max_seq_len=max_seq_length)
+    elif 'keyword' in example:  # CSL
+        sentence1 = " ".join(example['keyword'])
+        example = tokenizer(
+            sentence1, text_pair=example['abst'], max_seq_len=max_seq_length)
+    elif 'target' in example:  # wsc
+        text, query, pronoun, query_idx, pronoun_idx = example['text'], example[
+            'target']['span1_text'], example['target']['span2_text'], example[
+                'target']['span1_index'], example['target']['span2_index']
text_list = list(text) + assert text[pronoun_idx:(pronoun_idx + len(pronoun) + )] == pronoun, "pronoun: {}".format(pronoun) + assert text[query_idx:(query_idx + len(query) + )] == query, "query: {}".format(query) + if pronoun_idx > query_idx: + text_list.insert(query_idx, "_") + text_list.insert(query_idx + len(query) + 1, "_") + text_list.insert(pronoun_idx + 2, "[") + text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]") + else: + text_list.insert(pronoun_idx, "[") + text_list.insert(pronoun_idx + len(pronoun) + 1, "]") + text_list.insert(query_idx + 2, "_") + text_list.insert(query_idx + len(query) + 2 + 1, "_") + text = "".join(text_list) + example = tokenizer(text, max_seq_len=max_seq_length) + + if not is_test: + return example['input_ids'], example['token_type_ids'], label + else: + return example['input_ids'], example['token_type_ids'] + + +def quant_post(args, batch_size=8, algo='avg'): + paddle.enable_static() + place = paddle.set_device("gpu") + exe = paddle.static.Executor(place) + args.task_name = args.task_name.lower() + + train_ds = paddlenlp.datasets.load_dataset( + "clue", args.task_name, splits="dev") + + tokenizer = ErnieTokenizer.from_pretrained( + "../ernie-batchbatch-50w_400000/best_models/AFQMC/") + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_ds.label_list, + max_seq_length=128, + is_test=True) + train_ds = train_ds.map(trans_func, lazy=True) + + def test(): + batch_data = [[], []] + for data in train_ds: + batch_data[0].append(data[0]) + batch_data[1].append(data[1]) + if len(batch_data[0]) == batch_size: + input_ids = Pad(axis=0, pad_val=0)(batch_data[0]) + segment_ids = Pad(axis=0, pad_val=0)(batch_data[1]) + ones = np.ones_like(input_ids, dtype="int64") + seq_length = np.cumsum(ones, axis=-1) + + position_ids = seq_length - ones + attention_mask = np.expand_dims( + (input_ids == 0).astype("float32") * -1e9, axis=[1, 2]) + yield [input_ids, segment_ids] + batch_data = [[], []] + + paddleslim.quant.quant_post_static( + exe, + args.input_dir, + os.path.join(args.task_name + '_quant_models', algo + str(batch_size)), + save_model_filename='int8.pdmodel', + save_params_filename='int8.pdiparams', + algo=algo, + hist_percent=0.9999, + batch_generator=test, + model_filename='float.pdmodel', + params_filename='float.pdiparams', + quantizable_op_type=['matmul', 'matmul_v2'], + weight_bits=8, + weight_quantize_type='channel_wise_abs_max', + batch_nums=1, ) + + +if __name__ == '__main__': + paddle.enable_static() + for batch_size in [4, 8]: + for algo in ['abs_max', 'avg', 'mse', 'hist']: + quant_post(args, batch_size, algo) diff --git a/paddlenlp/transformers/distill_utils.py b/paddlenlp/transformers/distill_utils.py index 037cd8ceb849..7b296b308882 100644 --- a/paddlenlp/transformers/distill_utils.py +++ b/paddlenlp/transformers/distill_utils.py @@ -39,7 +39,6 @@ def calc_multi_relation_loss(loss_fct, Calculates loss for multiple Q-Q, K-K and V-V relation. It supports head-head relation, sample-sample relation and origin token-token relation. The final loss value could be balanced by weight `alpha` and `beta`. - Args: loss_fct (callable): Loss function for distillation. It only supports kl_div loss now. @@ -59,11 +58,9 @@ def calc_multi_relation_loss(loss_fct, beta (float): The weight for sample-sample relation. Defaults to 0.0. - Returns: Tensor: Weighted loss of token-token loss, head-head loss and sample-sample loss. 
- """ # Initialize head_num if num_relation_heads > 0 and num_relation_heads != s.shape[1]: @@ -140,7 +137,6 @@ def calc_multi_relation_loss(loss_fct, def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0): """ Calculates loss for Q-Q, K-K, V-V relation from MiniLMv2. - Args: loss_fct (callable): Loss function for distillation. It only supports kl_div loss now. @@ -154,10 +150,8 @@ def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0): The number of relation heads. 0 means `num_relation_heads` equals to origin head num. Defaults to 0. - Returns: Tensor: MiniLM loss value. - """ # Initialize head_num if num_relation_heads > 0 and num_relation_heads != s.shape[1]: @@ -197,7 +191,6 @@ def to_distill(self, expose attributes `outputs.q`, `outputs.k`, `outputs.v`, `outputs.scaled_qks`, `outputs.hidden_states`and `outputs.attentions` of the object for distillation. - It could be returned intermediate tensor using in MiniLM and TinyBERT strategy. """ @@ -435,4 +428,4 @@ def bert_forward(self, input_ids, token_type_ids=None, attention_mask=None): sequence_output, pooled_output = model(input_ids, token_type_ids, attention_mask) - return encoder.attentions, encoder.hidden_states + return encoder.attentions, encoder.hidden_states \ No newline at end of file