diff --git a/examples/model_compression/PP-MiniLM/README.md b/examples/model_compression/PP-MiniLM/README.md
index 67f750f14ff6..211d5d45a4af 100644
--- a/examples/model_compression/PP-MiniLM/README.md
+++ b/examples/model_compression/PP-MiniLM/README.md
@@ -255,6 +255,8 @@ cd ..

 Because dynamic shape is enabled, the shape ranges have to be collected first. Paddle Inference provides an interface for this: the shape ranges of all intermediate tensors are collected offline from sample inputs, and the input shape ranges of the TRT subgraphs can then be set directly from the tuned results, which completes the automatic shape-range setup. Once collection is done, simply point to the statistics file to enable the tuned_dynamic_shape feature. In this example, first run infer.py with the --collect_shape flag, then run infer.py again without it. For example:

+This example runs prediction with the quantized model using the inference/infer.py script on a single NVIDIA Tesla T4 GPU, with CUDA 11.1, cuDNN 8.1 and TensorRT 7.2.
+
 INT8 inference script:

 ```shell
@@ -277,27 +279,25 @@ python infer.py --task_name ${task} --model_path $MODEL_PATH --use_trt
 ```

 ### Performance test
-This example runs prediction with the quantized model using the inference/infer.py script on a single NVIDIA Tesla T4 GPU, with CUDA 11.1, cuDNN 8.1 and TensorRT 7.2.
-The performance test uses the model trained on the TNEWS dataset. The three rows of the table below are the fine-tuned model, the model after OFA pruning and distillation, and the quantized model (quantization method: mse, calibration set size: 4); we measure the total prediction time on the dev set (excluding the first 20 steps).
+The performance-test environment is the same as above: a single NVIDIA Tesla T4 GPU with CUDA 11.1, cuDNN 8.1 and TensorRT 7.2. It uses the model trained on the TNEWS dataset. The three rows of the table below are the fine-tuned model, the model after OFA pruning and distillation, and the quantized model (quantization method: mse, calibration set size: 4); we measure the total prediction time on the dev set (excluding the first 20 steps).

-With PaddleSlim pruning and quantization, inference is 255.86% faster than the original BERT-base model; pruning alone contributes an 87.98% speedup.
-
-|                                    | Average latency (s) | Speedup |
-| ---------------------------------- | ------------------- | ------- |
-| BERT                               | 20.64               | 0       |
-| FP32                               | 12.61               | 63.68%  |
-| FP32 + pruning                     | 10.98               | 87.98%  |
-| FP32 + pruning + INT8 quantization | 5.80                | 255.86% |
-
-
-INT8 inference script:
+Running the performance-test script reports the latency of the FP32, pruned and quantized models. For each model, average the durations printed by the 5 runs that are not passed --collect_shape:

 ```shell
-sh infer.sh
+sh infer_perf.sh
 ```

 ```shell
 cd ..
 ```
+
+With PaddleSlim pruning and quantization, inference is 255.86% faster than the original BERT-base model; pruning alone contributes an 87.98% speedup.
+
+|                                    | Average latency (s) | Speedup |
+| ---------------------------------- | ------------------- | ------- |
+| BERT                               | 20.64               | 0       |
+| FP32                               | 12.61               | 63.68%  |
+| FP32 + pruning                     | 10.98               | 87.98%  |
+| FP32 + pruning + INT8 quantization | 5.80                | 255.86% |
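The two-pass --collect_shape workflow that the README describes maps onto a handful of Paddle Inference calls, the same ones used by infer.py and infer_perf.py in this patch. Below is a minimal sketch rather than code from the repository; the model prefix and shape-range file name are placeholders, and in the real scripts they are derived from --model_path and --task_name. Run it once with collect_shape=True over the dev data, then again with collect_shape=False:

```python
import os

import paddle
from paddle import inference


def build_predictor(model_prefix, collect_shape=False, int8=True, batch_size=32):
    """Minimal sketch of the TensorRT setup used by inference/infer.py."""
    config = paddle.inference.Config(model_prefix + ".pdmodel",
                                     model_prefix + ".pdiparams")
    config.enable_use_gpu(100, 0)
    config.switch_use_feed_fetch_ops(False)
    precision = (inference.PrecisionType.Int8
                 if int8 else inference.PrecisionType.Float32)
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        precision_mode=precision,
        max_batch_size=batch_size,
        min_subgraph_size=5,
        use_static=False,
        use_calib_mode=False)
    # Placeholder path; the scripts build it from --model_path and --task_name.
    shape_file = os.path.join(os.path.dirname(model_prefix),
                              "tnews_shape_range_info.pbtxt")
    if collect_shape:
        # First pass: record the shape ranges of all intermediate tensors.
        config.collect_shape_range_info(shape_file)
    else:
        # Second pass: reuse the recorded ranges (tuned dynamic shape).
        config.enable_tuned_tensorrt_dynamic_shape(shape_file, True)
    return paddle.inference.create_predictor(config)
```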
diff --git a/examples/model_compression/PP-MiniLM/general_distill/general_distill.py b/examples/model_compression/PP-MiniLM/general_distill/general_distill.py
index df26030eeb02..81f04f5e889f 100644
--- a/examples/model_compression/PP-MiniLM/general_distill/general_distill.py
+++ b/examples/model_compression/PP-MiniLM/general_distill/general_distill.py
@@ -33,7 +33,7 @@ from paddlenlp.transformers import LinearDecayWithWarmup
 from paddlenlp.transformers import RobertaModel, RobertaTokenizer
 from paddlenlp.transformers import ErnieModel, ErnieForSequenceClassification, ErnieTokenizer
-from paddlenlp.transformers.distill_utils import to_distill, calc_minilm_loss_multi_relation
+from paddlenlp.transformers.distill_utils import to_distill, calc_multi_relation_loss

 MODEL_CLASSES = {
     "roberta": (RobertaModel, RobertaTokenizer),
@@ -245,6 +245,7 @@ def __init__(self, input_file, tokenizer, max_seq_length):
                 line = line[:max_seq_length]
                 tokenized_example = tokenizer(line, max_seq_len=max_seq_length)
                 input_ids.append(tokenized_example['input_ids'])
+
         self.inputs = np.asarray(input_ids)
         f.close()

@@ -396,7 +397,7 @@ def do_train(args):
             input_ids = batch[0]
             attention_mask = paddle.unsqueeze(
                 (input_ids == pad_token_id
-                 ).astype(paddle.get_default_dtype()) * -1e9,
+                 ).astype(paddle.get_default_dtype()) * -1e4,
                 axis=[1, 2])
             with paddle.amp.auto_cast(
                     args.use_amp,
@@ -408,35 +409,27 @@ def do_train(args):
                 q_t, q_s = teacher.outputs.q, student.outputs.q
                 batch_size = q_t.shape[0]
                 pad_seq_len = q_t.shape[2]
-                loss_qr1, loss_qr2, loss_qr3 = calc_minilm_loss_multi_relation(
+                loss_q = calc_multi_relation_loss(
                     kl_loss_fct, q_s, q_t, attention_mask,
                     args.num_relation_heads, args.alpha, args.beta)
                 del q_t, q_s
                 # K-K relation
                 k_t, k_s = teacher.outputs.k, student.outputs.k
-                loss_kr1, loss_kr2, loss_kr3 = calc_minilm_loss_multi_relation(
+                loss_k = calc_multi_relation_loss(
                     kl_loss_fct, k_s, k_t, attention_mask,
                     args.num_relation_heads, args.alpha, args.beta)
                 del k_t, k_s
                 # V-V relation
                 v_t, v_s = teacher.outputs.v, student.outputs.v
-                loss_vr1, loss_vr2, loss_vr3 = calc_minilm_loss_multi_relation(
+                loss_v = calc_multi_relation_loss(
                     kl_loss_fct, v_s, v_t, attention_mask,
                     args.num_relation_heads, args.alpha, args.beta)
                 del v_t, v_s
-                loss1 = (loss_qr1 + loss_kr1 + loss_vr1)
-                loss1 /= args.num_relation_heads * pad_seq_len * batch_size
-
-                loss2 = loss_qr2 + loss_kr2 + loss_vr2
-                loss2 /= args.num_relation_heads * pad_seq_len * batch_size
-
-                loss3 = loss_qr3 + loss_kr3 + loss_vr3
-                loss3 /= args.num_relation_heads * pad_seq_len * batch_size
-                loss = (1 - args.alpha - args.beta
-                        ) * loss1 + loss2 * args.alpha + loss3 * args.beta
+                loss = loss_q + loss_k + loss_v
+                loss /= args.num_relation_heads * pad_seq_len * batch_size

                 if args.use_amp:
                     scaler.scale(loss).backward()
@@ -453,10 +446,10 @@ def do_train(args):
             train_cost_avg.record(train_run_cost)
             if global_step % args.logging_steps == 0:
                 logger.info(
-                    "global step: %d, epoch: %d, batch: %d, loss: %f, loss1: %f, loss2: %f, loss3: %f,"
+                    "global step: %d, epoch: %d, batch: %d, loss: %f, "
                     "lr: %f, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
-                    % (global_step, epoch, step, loss, loss1, loss2, loss3,
-                       optimizer.get_lr(), train_cost_avg.get_average(),
+                    % (global_step, epoch, step, loss, optimizer.get_lr(),
+                       train_cost_avg.get_average(),
                        total_samples / args.logging_steps,
                        total_samples / (args.logging_steps * train_cost_avg.get_average())))
                 total_samples = 0
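Two things change above: the additive attention-mask constant drops from -1e9 to -1e4, a value that stays representable in float16 when --use_amp is enabled, and the loss1/loss2/loss3 bookkeeping is removed. Judging from the new call sites and the calc_multi_relation_loss docstring later in this diff, each call now returns the alpha/beta-weighted combination of token-token, head-head and sample-sample losses for one relation type, so the training loop only sums the Q, K and V terms and normalizes once. A condensed sketch of the new call pattern (a sketch only; student_out/teacher_out stand in for the student.outputs/teacher.outputs objects in general_distill.py):

```python
from paddlenlp.transformers.distill_utils import calc_multi_relation_loss


def qkv_relation_loss(student_out, teacher_out, attention_mask, kl_loss_fct, args):
    """Condensed view of the Q-Q / K-K / V-V loss computation in general_distill.py."""
    # Each call is assumed to return, for one relation type, the already-weighted sum
    #   (1 - alpha - beta) * token_token + alpha * head_head + beta * sample_sample
    # as described by the calc_multi_relation_loss docstring.
    loss_q = calc_multi_relation_loss(kl_loss_fct, student_out.q, teacher_out.q,
                                      attention_mask, args.num_relation_heads,
                                      args.alpha, args.beta)
    loss_k = calc_multi_relation_loss(kl_loss_fct, student_out.k, teacher_out.k,
                                      attention_mask, args.num_relation_heads,
                                      args.alpha, args.beta)
    loss_v = calc_multi_relation_loss(kl_loss_fct, student_out.v, teacher_out.v,
                                      attention_mask, args.num_relation_heads,
                                      args.alpha, args.beta)
    batch_size, pad_seq_len = teacher_out.q.shape[0], teacher_out.q.shape[2]
    # A single normalization replaces the three separate loss1/loss2/loss3 blocks.
    return (loss_q + loss_k + loss_v) / (
        args.num_relation_heads * pad_seq_len * batch_size)
```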
diff --git a/examples/model_compression/PP-MiniLM/inference/infer.py b/examples/model_compression/PP-MiniLM/inference/infer.py
index 9cc8422119a5..e080b248e10c 100644
--- a/examples/model_compression/PP-MiniLM/inference/infer.py
+++ b/examples/model_compression/PP-MiniLM/inference/infer.py
@@ -14,7 +14,6 @@

 import argparse
 import os
-import time
 from functools import partial

 import numpy as np
@@ -112,7 +111,6 @@ def parse_args():
         "--tokenizer_path",
         default='../general_distill/ernie-batchbatch-50w_400000/',
         type=str,
-        required=True,
         help="The directory for tokenizer.", )
     parser.add_argument(
         "--model_path",
@@ -190,7 +188,6 @@ def create_predictor(cls, args):
         config.switch_use_feed_fetch_ops(False)  # could be deleted
         if args.use_trt:
             if args.int8:
-                print("int8")
                 config.enable_tensorrt_engine(
                     workspace_size=1 << 30,
                     precision_mode=inference.PrecisionType.Int8,
@@ -227,7 +224,6 @@ def create_predictor(cls, args):
             predictor.get_output_handle(name)
             for name in predictor.get_output_names()
         ]
-        cls.time = 0.0

         return cls(predictor, input_handles, output_handles)

@@ -235,10 +231,8 @@ def predict_batch(self, data):
         for input_field, input_handle in zip(data, self.input_handles):
             input_handle.copy_from_cpu(input_field.numpy() if isinstance(
                 input_field, paddle.Tensor) else input_field)
-        time1 = time.time()
         self.predictor.run()
         paddle.fluid.core._cuda_synchronize(self.device)
-        self.time += time.time() - time1
         output = [
             output_handle.copy_to_cpu() for output_handle in self.output_handles
         ]
@@ -258,9 +252,6 @@ def predict(self, dataset, collate_fn, args, batch_size=1):
         outputs = []
         metric.reset()
         for i, data in enumerate(data_loader):
-            # warmup for performance test
-            if i < 20:
-                continue
             if len(data) == 2:
                 output = self.predict_batch(data)
             else:
@@ -272,7 +263,6 @@ def predict(self, dataset, collate_fn, args, batch_size=1):
         if len(data) > 2:
             res = metric.accumulate()
             print("task name: %s, acc: %s, " % (args.task_name, res), end='')
-        print("time: ", self.time)

         return outputs

diff --git a/examples/model_compression/PP-MiniLM/inference/infer.sh b/examples/model_compression/PP-MiniLM/inference/infer.sh
deleted file mode 100644
index b50f23e37fd7..000000000000
--- a/examples/model_compression/PP-MiniLM/inference/infer.sh
+++ /dev/null
@@ -1,26 +0,0 @@
-echo 原来的模型
-python infer.py --task_name tnews --model_path tnews/float --use_trt --collect_shape
-python infer.py --task_name tnews --model_path tnews/float --use_trt
-python infer.py --task_name tnews --model_path tnews/float --use_trt
-python infer.py --task_name tnews --model_path tnews/float --use_trt
-python infer.py --task_name tnews --model_path tnews/float --use_trt
-python infer.py --task_name tnews --model_path tnews/float --use_trt
-
-
-echo 裁剪后
-python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt --collect_shape
-python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
-python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
-python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
-python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
-python infer.py --task_name tnews --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
-
-
-echo int8推理
-python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt --collect_shape
-python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
-python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
-python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
-python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
-python infer.py --task_name tnews --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
-
diff --git a/examples/model_compression/PP-MiniLM/inference/infer_perf.py b/examples/model_compression/PP-MiniLM/inference/infer_perf.py
new file mode 100644
index 000000000000..4613abf8094f
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/inference/infer_perf.py
@@ -0,0 +1,260 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import time
+from functools import partial
+import numpy as np
+
+import paddle
+from paddle import inference
+
+from paddlenlp.datasets import load_dataset
+from paddlenlp.data import Stack, Tuple, Pad
+from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
+
+MODEL_CLASSES = {"ernie": (ErnieForSequenceClassification, ErnieTokenizer), }
+
+
+def convert_example(example,
+                    tokenizer,
+                    label_list,
+                    max_seq_length=512,
+                    is_test=False):
+    """convert a glue example into necessary features"""
+    if not is_test:
+        # `label_list == None` is for regression task
+        label_dtype = "int64" if label_list else "float32"
+        # Get the label
+        label = example['label']
+        label = np.array([label], dtype=label_dtype)
+    # Convert raw text to feature
+    if 'sentence' in example:
+        example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
+    elif 'sentence1' in example:
+        example = tokenizer(
+            example['sentence1'],
+            text_pair=example['sentence2'],
+            max_seq_len=max_seq_length)
+    elif 'keyword' in example:  # CSL
+        sentence1 = " ".join(example['keyword'])
+        example = tokenizer(
+            sentence1, text_pair=example['abst'], max_seq_len=max_seq_length)
+    elif 'target' in example:  # wsc
+        text, query, pronoun, query_idx, pronoun_idx = example['text'], example[
+            'target']['span1_text'], example['target']['span2_text'], example[
+                'target']['span1_index'], example['target']['span2_index']
+        text_list = list(text)
+        assert text[pronoun_idx:(pronoun_idx + len(pronoun)
+                                 )] == pronoun, "pronoun: {}".format(pronoun)
+        assert text[query_idx:(query_idx + len(query)
+                               )] == query, "query: {}".format(query)
+        if pronoun_idx > query_idx:
+            text_list.insert(query_idx, "_")
+            text_list.insert(query_idx + len(query) + 1, "_")
+            text_list.insert(pronoun_idx + 2, "[")
+            text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
+        else:
+            text_list.insert(pronoun_idx, "[")
+            text_list.insert(pronoun_idx + len(pronoun) + 1, "]")
+            text_list.insert(query_idx + 2, "_")
+            text_list.insert(query_idx + len(query) + 2 + 1, "_")
+        text = "".join(text_list)
+        example = tokenizer(text, max_seq_len=max_seq_length)
+
+    if not is_test:
+        return example['input_ids'], example['token_type_ids'], label
+    else:
+        return example['input_ids'], example['token_type_ids']
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    # Required parameters
+    parser.add_argument("--task_name", default='afqmc', type=str)
+    parser.add_argument(
+        "--model_type",
+        default='ernie',
+        type=str,
+        help="Model type selected in the list: " +
+        ", ".join(MODEL_CLASSES.keys()), )
+    parser.add_argument(
+        "--model_path",
+        default='./quant_models/model',
+        type=str,
+        required=True,
+        help="The path prefix of inference model to be used.", )
+    parser.add_argument(
+        "--tokenizer_path",
+        default='../general_distill/ernie-batchbatch-50w_400000/',
+        type=str,
+        help="The directory for tokenizer.", )
+    parser.add_argument(
+        "--device",
+        default="gpu",
+        choices=["gpu", "cpu", "xpu"],
+        help="Device selected for inference.", )
+    parser.add_argument(
+        "--batch_size",
+        default=32,
+        type=int,
+        help="Batch size for predict.", )
+    parser.add_argument(
+        "--max_seq_length",
+        default=128,
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+        "than this will be truncated, sequences shorter will be padded.", )
+    parser.add_argument(
+        "--use_trt",
+        action='store_true',
+        help="Whether to use inference engine TensorRT.", )
+
+    parser.add_argument(
+        "--collect_shape",
+        action='store_true',
+        help="Whether collect shape range info.", )
+    parser.add_argument(
+        "--int8",
+        action='store_true',
+        help="Whether int8 inference.", )
+    args = parser.parse_args()
+    return args
+
+
+class Predictor(object):
+    def __init__(self, predictor, input_handles, output_handles):
+        self.predictor = predictor
+        self.input_handles = input_handles
+        self.output_handles = output_handles
+
+    @classmethod
+    def create_predictor(cls, args):
+        config = paddle.inference.Config(args.model_path + ".pdmodel",
+                                         args.model_path + ".pdiparams")
+        if args.device == "gpu":
+            # set GPU configs accordingly
+            config.enable_use_gpu(100, 0)
+            cls.device = paddle.set_device("gpu")
+        elif args.device == "cpu":
+            # set CPU configs accordingly,
+            # such as enable_mkldnn, set_cpu_math_library_num_threads
+            config.disable_gpu()
+            cls.device = paddle.set_device("cpu")
+        elif args.device == "xpu":
+            # set XPU configs accordingly
+            config.enable_xpu(100)
+        config.switch_use_feed_fetch_ops(False)  # could be deleted
+        if args.use_trt:
+            if args.int8:
+                config.enable_tensorrt_engine(
+                    workspace_size=1 << 30,
+                    precision_mode=inference.PrecisionType.Int8,
+                    max_batch_size=args.batch_size,
+                    min_subgraph_size=5,
+                    use_static=False,
+                    use_calib_mode=False)
+            else:
+                config.enable_tensorrt_engine(
+                    workspace_size=1 << 30,
+                    precision_mode=inference.PrecisionType.Float32,
+                    max_batch_size=args.batch_size,
+                    min_subgraph_size=5,
+                    use_static=False,
+                    use_calib_mode=False)
+            print("Enable TensorRT is: {}".format(
+                config.tensorrt_engine_enabled()))
+            if args.collect_shape:
+                config.collect_shape_range_info(
+                    os.path.dirname(args.model_path) + "/" + args.task_name +
+                    '_shape_range_info.pbtxt')
+            else:
+                config.enable_tuned_tensorrt_dynamic_shape(
+                    os.path.dirname(args.model_path) + "/" + args.task_name +
+                    "_shape_range_info.pbtxt", True)
+        predictor = paddle.inference.create_predictor(config)
+        input_handles = [
+            predictor.get_input_handle(name)
+            for name in predictor.get_input_names()
+        ]
+        output_handles = [
+            predictor.get_output_handle(name)
+            for name in predictor.get_output_names()
+        ]
+
+        return cls(predictor, input_handles, output_handles)
+
+    def predict_batch(self, data, prin=False):
+        for input_field, input_handle in zip(data, self.input_handles):
+            input_handle.copy_from_cpu(input_field.numpy() if isinstance(
+                input_field, paddle.Tensor) else input_field)
+        self.predictor.run()
+        output = [
+            output_handle.copy_to_cpu() for output_handle in self.output_handles
+        ]
+        return output
+
+    def predict(self, dataset, collate_fn, args, batch_size=1):
+        batch_sampler = paddle.io.BatchSampler(
+            dataset, batch_size=batch_size, shuffle=False)
+        data_loader = paddle.io.DataLoader(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn,
+            num_workers=0,
+            return_list=True)
+        time1 = time.time()
+        for i, data in enumerate(data_loader):
+            if i < 20:  # skip warmup steps.
+                continue
+            output = self.predict_batch([data[0], data[1]])
+            logits = paddle.to_tensor(output)
+
+        print("time: ", time.time() - time1)
+
+
+def main():
+    paddle.seed(42)
+    args = parse_args()
+
+    args.task_name = args.task_name.lower()
+    args.model_type = args.model_type.lower()
+
+    predictor = Predictor.create_predictor(args)
+
+    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+
+    dev_ds = load_dataset('clue', args.task_name, splits='dev')
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path)
+    trans_func = partial(
+        convert_example,
+        tokenizer=tokenizer,
+        label_list=dev_ds.label_list,
+        max_seq_length=args.max_seq_length,
+        is_test=False)
+
+    dev_ds = dev_ds.map(trans_func, lazy=True)
+    batchify_fn = lambda samples, fn=Tuple(
+        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
+        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
+        Stack(dtype="int64" if dev_ds.label_list else "float32")  # label
+    ): fn(samples)
+    outputs = predictor.predict(
+        dev_ds, batch_size=args.batch_size, collate_fn=batchify_fn, args=args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/model_compression/PP-MiniLM/inference/infer_perf.sh b/examples/model_compression/PP-MiniLM/inference/infer_perf.sh
new file mode 100644
index 000000000000..00f09908821b
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/inference/infer_perf.sh
@@ -0,0 +1,27 @@
+task=tnews
+echo Inference of original FP32 model
+python infer_perf.py --task_name ${task} --model_path tnews/float --use_trt --collect_shape
+python infer_perf.py --task_name ${task} --model_path tnews/float --use_trt
+python infer_perf.py --task_name ${task} --model_path tnews/float --use_trt
+python infer_perf.py --task_name ${task} --model_path tnews/float --use_trt
+python infer_perf.py --task_name ${task} --model_path tnews/float --use_trt
+python infer_perf.py --task_name ${task} --model_path tnews/float --use_trt
+
+
+echo After OFA
+python infer_perf.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt --collect_shape
+python infer_perf.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
+python infer_perf.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
+python infer_perf.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
+python infer_perf.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
+python infer_perf.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt
+
+
+echo After quantization
+python infer_perf.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt --collect_shape
+python infer_perf.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
+python infer_perf.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
+python infer_perf.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
+python infer_perf.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
+python infer_perf.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt
+
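To turn the printed timings into the averages quoted in the README, capture the script output, for example with `sh infer_perf.sh > perf.log 2>&1`, and average the `time: ` lines that infer_perf.py prints. Below is a small helper sketch, not part of the patch; it assumes each model is run six times as in infer_perf.sh, with the first run of each group being the --collect_shape pass, which is dropped from the average:

```python
import re


def average_times(log_path="perf.log", runs_per_model=6):
    """Average the "time: " values printed by infer_perf.py, per model."""
    with open(log_path) as f:
        times = [float(m.group(1))
                 for m in re.finditer(r"time:\s*([0-9.]+)", f.read())]
    averages = []
    for start in range(0, len(times), runs_per_model):
        timed_runs = times[start + 1:start + runs_per_model]  # drop collect_shape pass
        averages.append(sum(timed_runs) / len(timed_runs))
    return averages


if __name__ == "__main__":
    # With the script above, the expected order is: FP32, FP32 + pruning, INT8.
    print(average_times())
```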
diff --git a/paddlenlp/transformers/distill_utils.py b/paddlenlp/transformers/distill_utils.py
index 7b296b308882..3f67c0d022b1 100644
--- a/paddlenlp/transformers/distill_utils.py
+++ b/paddlenlp/transformers/distill_utils.py
@@ -39,6 +39,7 @@ def calc_multi_relation_loss(loss_fct,
     Calculates loss for multiple Q-Q, K-K and V-V relation. It supports
     head-head relation, sample-sample relation and origin token-token relation.
     The final loss value could be balanced by weight `alpha` and `beta`.
+
     Args:
         loss_fct (callable):
             Loss function for distillation. It only supports kl_div loss now.
@@ -58,9 +59,11 @@ def calc_multi_relation_loss(loss_fct,
         beta (float):
             The weight for sample-sample relation.
             Defaults to 0.0.
+
     Returns:
         Tensor: Weighted loss of token-token loss, head-head loss and
             sample-sample loss.
+
     """
     # Initialize head_num
     if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
@@ -150,8 +153,10 @@ def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0):
             The number of relation heads. 0 means `num_relation_heads` equals
             to origin head num.
             Defaults to 0.
+
     Returns:
         Tensor: MiniLM loss value.
+
     """
     # Initialize head_num
     if num_relation_heads > 0 and num_relation_heads != s.shape[1]:
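For readers unfamiliar with the MiniLM-style relation distillation these docstrings describe, the sketch below illustrates roughly what a single Q-Q relation loss computes: a scaled dot-product self-relation over the sequence, shifted by the additive padding mask, normalized with softmax, and compared between student and teacher with KL divergence. This is an illustrative approximation, not the actual calc_minilm_loss implementation; the helper name and the exact scaling factor are assumptions:

```python
import paddle
import paddle.nn.functional as F


def qq_relation_loss_sketch(q_s, q_t, attn_mask, kl_loss_fct):
    """Hypothetical sketch of a single Q-Q relation loss in the MiniLM style.

    q_s / q_t: student / teacher query tensors shaped
    [batch_size, num_heads, seq_len, head_dim]; attn_mask is the additive mask
    built in general_distill.py (large negative values at padding positions).
    kl_loss_fct is expected to take (log-probabilities, probabilities),
    e.g. paddle.nn.KLDivLoss(reduction='sum').
    """
    head_dim = q_t.shape[-1]
    # Scaled dot-product self-relation over the sequence dimension.
    rel_t = paddle.matmul(q_t, q_t, transpose_y=True) / head_dim**0.5
    rel_s = paddle.matmul(q_s, q_s, transpose_y=True) / head_dim**0.5
    # Mask padding positions, then normalize into relation distributions.
    prob_t = F.softmax(rel_t + attn_mask, axis=-1)
    log_prob_s = F.log_softmax(rel_s + attn_mask, axis=-1)
    # KL divergence between the teacher and student relation distributions.
    return kl_loss_fct(log_prob_s, prob_t)
```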