Add MPNet Model #869
Merged
11 commits
5a088b4 add mpnet (JunnYu)
5331f7d update (JunnYu)
3b214db update tokenizer and update readme (JunnYu)
1686759 Merge branch 'develop' into add_mpnet (JunnYu)
e073781 update readme & add docs (JunnYu)
5232e53 rm unused figure (JunnYu)
74c5691 update (JunnYu)
5a98758 Merge branch 'develop' into add_mpnet (JunnYu)
37aef51 update (JunnYu)
67953ef update copyright (JunnYu)
2ff2a03 Merge branch 'develop' into add_mpnet
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
# MPNet with PaddleNLP | ||
|
||
[MPNet: Masked and Permuted Pre-training for Language Understanding - Microsoft Research](https://www.microsoft.com/en-us/research/publication/mpnet-masked-and-permuted-pre-training-for-language-understanding/) | ||
|
||
**Abstract:** | ||
BERT adopts masked language modeling (MLM) for pre-training and is one of the most successful pre-training models. Since BERT neglects dependency among predicted tokens, XLNet introduces permuted language modeling (PLM) for pretraining to address this problem. However, XLNet does not leverage the full position information of a sentence and thus suffers from position discrepancy between pre-training and fine-tuning. In this paper, we propose MPNet, a novel pre-training method that inherits the advantages of BERT and XLNet and avoids their limitations. MPNet leverages the dependency among predicted tokens through permuted language modeling (vs. MLM in BERT), and takes auxiliary position information as input to make the model see a full sentence and thus reducing the position discrepancy (vs. PLM in XLNet). We pre-train MPNet on a large-scale dataset (over 160GB text corpora) and fine-tune on a variety of down-streaming tasks (GLUE, SQuAD, etc). Experimental results show that MPNet outperforms MLM and PLM by a large margin, and achieves better results on these tasks compared with previous state-of-the-art pre-trained methods (e.g., BERT, XLNet, RoBERTa) under the same model setting. The code and the pre-trained models are available at: https://github.com/microsoft/MPNet. | ||
|
||
This project is an open-source implementation of MPNet on Paddle 2.x. | ||
|
||
## Results reported in the original paper | ||
<p align="center"> | ||
<img src="./figure/QQP.png" width="100%" /> | ||
</p> | ||
<p align="center"> | ||
<img src="./figure/SQuAD.png" width="100%" /> | ||
</p> | ||
|
||
## Quick start | ||
|
||
### Model accuracy alignment | ||
Review comment: Please remove. |
||
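Run `python compare.py` to compare the outputs of the Hugging Face and Paddle implementations. The mean absolute error is on the order of 10^-6 and the maximum absolute error is on the order of 10^-5 (the exact values change with the input). | ||
```bash | ||
python compare.py | ||
# mean difference: tensor(6.5154e-06) | ||
# max difference: tensor(4.1485e-05) | ||
``` | ||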
|
||
### Fine-tuning on downstream tasks | ||
|
||
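#### 1. GLUE | ||
The QQP dataset is used as the example here; to run the other GLUE datasets, see `train.sh`. (Hyperparameters follow the [README](https://github.com/microsoft/MPNet/blob/master/MPNet/README.glue.md) of the original paper's repository.) | ||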
|
||
##### (1) Fine-tuning: | ||
```shell | ||
unset CUDA_VISIBLE_DEVICES | ||
cd glue | ||
python -m paddle.distributed.launch --gpus "0" run_glue.py \ | ||
--model_type mpnet \ | ||
--model_name_or_path mpnet-base \ | ||
--task_name qqp \ | ||
--max_seq_length 128 \ | ||
--batch_size 32 \ | ||
--learning_rate 1e-5 \ | ||
--scheduler_type linear \ | ||
--weight_decay 0.1 \ | ||
--warmup_steps 5666 \ | ||
--max_steps 113272 \ | ||
--logging_steps 500 \ | ||
--save_steps 2000 \ | ||
--seed 42 \ | ||
--output_dir qqp/ \ | ||
--device gpu | ||
``` | ||
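The parameters are as follows: | ||
- `model_type`: the model type; BERT, ELECTRA, ERNIE, CONVBERT and MPNET are currently supported. | ||
- `model_name_or_path`: the model name or path; for MPNet, only `mpnet-base` is currently supported. | ||
- `task_name`: the fine-tuning task; CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE and WNLI are currently supported. | ||
- `max_seq_length`: the maximum sequence length; longer inputs are truncated. | ||
- `batch_size`: the number of samples **per card** in each iteration. | ||
- `learning_rate`: the base learning rate; it is multiplied by the value produced by the learning rate scheduler to give the current learning rate. | ||
- `scheduler_type`: the scheduler type; either `linear` or `cosine`. | ||
- `warmup_steps`: the number of warmup steps. | ||
- `max_steps`: the maximum number of training steps. | ||
- `logging_steps`: the logging interval. | ||
- `save_steps`: the interval at which the model is saved and evaluated. | ||
- `output_dir`: the directory where the model is saved. | ||
- `device`: the device type. Defaults to GPU; CPU, GPU and XPU are available. For multi-GPU training, set it to GPU and set the environment variable CUDA_VISIBLE_DEVICES to the ids of the GPUs to use. | ||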
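
The way `learning_rate`, `warmup_steps` and `max_steps` combine under `--scheduler_type linear` can be sketched as follows. This is a minimal illustration that assumes a linear warmup from 0 followed by a linear decay to 0; the scheduler actually used by `run_glue.py` is not shown in this README and may differ in detail. The function names are hypothetical, and the default step counts are taken from the QQP command above.

```python
def linear_schedule_factor(step, warmup_steps, max_steps):
    """Multiplier in [0, 1] from an assumed linear warmup/decay schedule."""
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 up to 1.
        return step / max(1, warmup_steps)
    # Decay: ramp linearly from 1 down to 0 at max_steps.
    return max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))


def current_lr(base_lr, step, warmup_steps=5666, max_steps=113272):
    """Base learning rate multiplied by the scheduler value."""
    return base_lr * linear_schedule_factor(step, warmup_steps, max_steps)


if __name__ == "__main__":
    # Start, mid-warmup, end of warmup, mid-decay, end of training.
    for step in (0, 2833, 5666, 59469, 113272):
        print(step, current_lr(1e-5, step))
```

Under these assumptions the learning rate peaks at the base value exactly when warmup ends and returns to zero at `max_steps`.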
|
||
##### (2) Prediction: | ||
```bash | ||
cd glue | ||
python run_predict.py --task_name qqp --ckpt_path qqp/best-qqp_ft_model_106000.pdparams | ||
``` | ||
|
||
##### (3) Compress the `template` folder into a zip file, then submit it to the [GLUE leaderboard](https://gluebenchmark.com/leaderboard): | ||
|
||
###### GLUE leaderboard results: | ||
<p align="center"> | ||
<img src="figure/glue.jpg" width="100%" /> | ||
</p> | ||
|
||
|
||
###### GLUE dev set results: | ||
|
||
| task | cola | sst-2 | mrpc | sts-b | qqp | mnli | qnli | rte | avg | | ||
|--------------------------------|-------|-------|-------------|------------------|-------------|------|-------|-------|-------| | ||
| **metric** | **mcc** | **acc** | **acc/f1** | **pearson/spearman** | **acc/f1** | **acc(m/mm)** | **acc** | **acc** | | | ||
| Paper | **65.0** | **95.5** | **91.8**/- | 91.1/- | **91.9**/- | **88.5**/- | 93.3 | 85.8 | **87.9** | | ||
| Mine | 64.4 | 95.4 | 90.4/93.1 | **91.6**/91.3 | **91.9**/89.0 | 87.7/88.2 | **93.6** | **86.6** | 87.7 | | ||
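
The `avg` column appears to be the mean of the first-listed metric of each task (for MNLI, the matched accuracy). This is an observation from the numbers in the table, not something the README states; the short check below reproduces both averages under that assumption.

```python
# First-listed metric per task, copied from the dev-set table.
# Order: cola, sst-2, mrpc, sts-b, qqp, mnli, qnli, rte
paper = [65.0, 95.5, 91.8, 91.1, 91.9, 88.5, 93.3, 85.8]
mine = [64.4, 95.4, 90.4, 91.6, 91.9, 87.7, 93.6, 86.6]


def average(scores):
    """Mean of the per-task scores, rounded to one decimal like the table."""
    return round(sum(scores) / len(scores), 1)


print("Paper:", average(paper))
print("Mine:", average(mine))
```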
|
||
###### GLUE test set comparison: | ||
|
||
| task | cola | sst-2 | mrpc | sts-b | qqp | mnli | qnli | rte | avg | | ||
|--------------------------------|-------|-------|-------|-------|-----|-------|-------|-------|----------| | ||
| **metric** | **mcc** | **acc** | **acc/f1** | **pearson/spearman** | **acc/f1** | **acc(m/mm)** | **acc** | **acc** | | | ||
| Paper | **64.0** | **96.0** | 89.1/- | 90.7/- | **89.9**/- | **88.5**/- | 93.1 | 81.0 | **86.5** | | ||
| Mine | 60.5 | 95.9 | **91.6**/88.9 | **90.8**/90.3 | 89.7/72.5 | 87.6/86.6 | **93.3** | **82.4** | **86.5** | | ||
|
||
#### 2. SQuAD v1.1 | ||
|
||
Fine-tune on the SQuAD v1.1 dataset using the pre-trained model provided by Paddle: | ||
|
||
```bash | ||
unset CUDA_VISIBLE_DEVICES | ||
cd squad | ||
python -m paddle.distributed.launch --gpus "0" run_squad.py \ | ||
--model_type mpnet \ | ||
--model_name_or_path mpnet-base \ | ||
--max_seq_length 512 \ | ||
--batch_size 16 \ | ||
--learning_rate 2e-5 \ | ||
--num_train_epochs 4 \ | ||
--scheduler_type linear \ | ||
--logging_steps 25 \ | ||
--save_steps 25 \ | ||
--warmup_proportion 0.1 \ | ||
--weight_decay 0.1 \ | ||
--output_dir squad1.1/ \ | ||
--device gpu \ | ||
--do_train \ | ||
--seed 42 \ | ||
--do_predict | ||
``` | ||
|
||
The model is evaluated automatically during training; the best result is shown below: | ||
|
||
```python | ||
{ | ||
"exact": 86.84957426679281, | ||
"f1": 92.82031917884066, | ||
"total": 10570, | ||
"HasAns_exact": 86.84957426679281, | ||
"HasAns_f1": 92.82031917884066, | ||
"HasAns_total": 10570 | ||
} | ||
``` | ||
|
||
#### 3. SQuAD v2.0 | ||
For SQuAD v2.0, launch fine-tuning as follows: | ||
|
||
```bash | ||
unset CUDA_VISIBLE_DEVICES | ||
cd squad | ||
python -m paddle.distributed.launch --gpus "0" run_squad.py \ | ||
--model_type mpnet \ | ||
--model_name_or_path mpnet-base \ | ||
--max_seq_length 512 \ | ||
--batch_size 16 \ | ||
--learning_rate 2e-5 \ | ||
--num_train_epochs 4 \ | ||
--scheduler_type linear \ | ||
--logging_steps 200 \ | ||
--save_steps 200 \ | ||
--warmup_proportion 0.1 \ | ||
--weight_decay 0.1 \ | ||
--output_dir squad2/ \ | ||
--device gpu \ | ||
--do_train \ | ||
--seed 42 \ | ||
--do_predict \ | ||
--version_2_with_negative | ||
``` | ||
|
||
* `version_2_with_negative`: flag that selects the SQuAD v2.0 dataset and evaluation metrics. | ||
|
||
The model is evaluated automatically during training; the best result is shown below: | ||
|
||
```python | ||
{ | ||
"exact": 82.27912069401162, | ||
"f1": 85.2774124891565, | ||
"total": 11873, | ||
"HasAns_exact": 80.34750337381917, | ||
"HasAns_f1": 86.35268530427743, | ||
"HasAns_total": 5928, | ||
"NoAns_exact": 84.20521446593776, | ||
"NoAns_f1": 84.20521446593776, | ||
"NoAns_total": 5945, | ||
"best_exact": 82.86869367472417, | ||
"best_exact_thresh": -2.450321674346924, | ||
"best_f1": 85.67634263296013, | ||
"best_f1_thresh": -2.450321674346924 | ||
} | ||
``` | ||
|
||
# Tips: | ||
- For the SQuAD tasks: according to this [issue](https://github.com/microsoft/MPNet/issues/3), the paper reports `best_exact` and `best_f1`. | ||
- For the GLUE tasks: according to this [issue](https://github.com/microsoft/MPNet/issues/7), some tasks use warm-start initialization. | ||
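
As background for `best_exact` and `best_exact_thresh`: SQuAD v2.0 evaluation typically sweeps a no-answer threshold and reports the best score found along the way. The sketch below is a simplified illustration of that idea, not the code used by `run_squad.py`. The function name, the tuple layout and the toy scores are all hypothetical, and it assumes the model always proposes a non-empty span.

```python
def best_exact_threshold(examples):
    """Sweep a no-answer threshold; return (best exact score in %, threshold).

    Each example is (na_score, exact, has_answer):
      na_score   -- higher means the model leans towards "no answer"
      exact      -- 1 if the predicted span equals the gold answer, else 0
      has_answer -- True if the question is answerable
    A question is answered when na_score <= threshold.
    """
    # Threshold below every score: everything is predicted "no answer",
    # so only the unanswerable questions count as correct.
    current = sum(1 for _, _, has_answer in examples if not has_answer)
    best, best_thresh = current, float("-inf")
    for na_score, exact, has_answer in sorted(examples):
        # Raising the threshold to na_score makes the model answer this question.
        current += exact if has_answer else -1
        if current > best:
            best, best_thresh = current, na_score
    return 100.0 * best / len(examples), best_thresh


# Hypothetical toy data: three answerable and two unanswerable questions.
toy = [(-5.0, 1, True), (-3.0, 1, True), (-1.0, 0, True),
       (0.5, 0, False), (2.0, 0, False)]
print(best_exact_threshold(toy))
```

Because the threshold is tuned on the evaluation data itself, `best_exact` can only be greater than or equal to the plain `exact` score, which matches the numbers reported above.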
|
||
# Reference | ||
|
||
```bibtex | ||
@article{song2020mpnet, | ||
title={MPNet: Masked and Permuted Pre-training for Language Understanding}, | ||
author={Song, Kaitao and Tan, Xu and Qin, Tao and Lu, Jianfeng and Liu, Tie-Yan}, | ||
journal={arXiv preprint arXiv:2004.09297}, | ||
year={2020} | ||
} | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Review comment: Please delete |
||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from paddlenlp.transformers.mpnet.modeling import MPNetForMaskedLM as PDMPNetForMaskedLM | ||
from transformers.models.mpnet.modeling_mpnet import (MPNetForMaskedLM as | ||
PTMPNetForMaskedLM, ) | ||
import torch | ||
import paddle | ||
|
||
paddle.set_device("cpu") | ||
|
||
pd_model = PDMPNetForMaskedLM.from_pretrained("mpnet-base") | ||
pd_model.eval() | ||
pt_model = PTMPNetForMaskedLM.from_pretrained("mpnet-base") | ||
pt_model.eval() | ||
|
||
with paddle.no_grad(): | ||
    pd_outputs = pd_model( | ||
        paddle.to_tensor([[523, 123, 6123, 523, 5213, 632], | ||
                          [5232, 1231, 6133, 5253, 5555, 6212]]))[0] | ||
|
||
with torch.no_grad(): | ||
    pt_outputs = pt_model( | ||
        torch.tensor([[523, 123, 6123, 523, 5213, 632], | ||
                      [5232, 1231, 6133, 5253, 5555, 6212]]))[0] | ||
|
||
|
||
def compare(a, b): | ||
    a = torch.tensor(a.numpy()).float() | ||
    b = torch.tensor(b.numpy()).float() | ||
    meandif = (a - b).abs().mean() | ||
    maxdif = (a - b).abs().max() | ||
    print("mean difference:", meandif) | ||
    print("max difference:", maxdif) | ||
|
||
|
||
compare(pd_outputs, pt_outputs) | ||
# mean difference: tensor(6.5154e-06) | ||
# max difference: tensor(4.1485e-05) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from collections import OrderedDict | ||
import argparse | ||
|
||
huggingface_to_paddle = { | ||
    ".attn.": ".", | ||
    "intermediate.dense": "ffn", | ||
    "output.dense": "ffn_output", | ||
    ".output.LayerNorm.": ".layer_norm.", | ||
    ".LayerNorm.": ".layer_norm.", | ||
    "lm_head.decoder.bias": "lm_head.decoder_bias", | ||
} | ||
|
||
skip_weights = ["lm_head.decoder.weight", "lm_head.bias"] | ||
dont_transpose = [ | ||
    "_embeddings.weight", | ||
    ".LayerNorm.weight", | ||
    ".layer_norm.weight", | ||
    "relative_attention_bias.weight", | ||
] | ||
|
||
|
||
def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path, | ||
                                         paddle_dump_path): | ||
    import torch | ||
    import paddle | ||
|
||
    pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu") | ||
    paddle_state_dict = OrderedDict() | ||
    for k, v in pytorch_state_dict.items(): | ||
        transpose = False | ||
        if k in skip_weights: | ||
            continue | ||
        # torch.nn.Linear stores weights as [out, in]; paddle.nn.Linear uses [in, out]. | ||
        if k.endswith(".weight"): | ||
            if not any([w in k for w in dont_transpose]): | ||
                if v.ndim == 2: | ||
                    v = v.transpose(0, 1) | ||
                    transpose = True | ||
        oldk = k | ||
        # Rename Hugging Face parameter names to their Paddle equivalents. | ||
        for huggingface_name, paddle_name in huggingface_to_paddle.items(): | ||
            k = k.replace(huggingface_name, paddle_name) | ||
|
||
        print(f"Converting: {oldk} => {k} | is_transpose {transpose}") | ||
        paddle_state_dict[k] = v.data.numpy() | ||
|
||
    paddle.save(paddle_state_dict, paddle_dump_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
    parser = argparse.ArgumentParser() | ||
    parser.add_argument( | ||
        "--pytorch_checkpoint_path", | ||
        default="weights/hg/mpnet-base/pytorch_model.bin", | ||
        type=str, | ||
        required=False, | ||
        help="Path to the PyTorch checkpoint.", ) | ||
    parser.add_argument( | ||
        "--paddle_dump_path", | ||
        default="weights/pd/mpnet-base/model_state.pdparams", | ||
        type=str, | ||
        required=False, | ||
        help="Path to the output Paddle model.", ) | ||
    args = parser.parse_args() | ||
    convert_pytorch_checkpoint_to_paddle(args.pytorch_checkpoint_path, | ||
                                         args.paddle_dump_path) |
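
To make the renaming and transpose rules of the conversion script easier to follow, here is a small dry-run sketch that works on parameter names and shapes only, so it needs neither PyTorch nor Paddle. The mapping tables are copied from the script above; the helper `plan_conversion` and the example keys and shapes are hypothetical and only meant as illustration. The transpose reflects that `torch.nn.Linear` stores its weight as `[out_features, in_features]` while `paddle.nn.Linear` uses `[in_features, out_features]`.

```python
huggingface_to_paddle = {
    ".attn.": ".",
    "intermediate.dense": "ffn",
    "output.dense": "ffn_output",
    ".output.LayerNorm.": ".layer_norm.",
    ".LayerNorm.": ".layer_norm.",
    "lm_head.decoder.bias": "lm_head.decoder_bias",
}
skip_weights = ["lm_head.decoder.weight", "lm_head.bias"]
dont_transpose = [
    "_embeddings.weight",
    ".LayerNorm.weight",
    ".layer_norm.weight",
    "relative_attention_bias.weight",
]


def plan_conversion(name, shape):
    """Return (paddle_name, paddle_shape, transposed), or None if skipped."""
    if name in skip_weights:
        return None
    # Only 2-D ".weight" tensors outside the exclusion list are transposed.
    transposed = (name.endswith(".weight") and len(shape) == 2
                  and not any(w in name for w in dont_transpose))
    if transposed:
        shape = shape[::-1]
    # Apply the substring renames in dictionary order, as the script does.
    for hf_name, pd_name in huggingface_to_paddle.items():
        name = name.replace(hf_name, pd_name)
    return name, tuple(shape), transposed


# Illustrative keys and shapes (assumed, not read from a real checkpoint).
examples = {
    "mpnet.encoder.layer.0.attention.attn.q.weight": (768, 768),
    "mpnet.encoder.layer.0.intermediate.dense.weight": (3072, 768),
    "mpnet.encoder.layer.0.output.LayerNorm.weight": (768,),
    "mpnet.embeddings.word_embeddings.weight": (30527, 768),
    "lm_head.decoder.weight": (30527, 768),
}
for key, shape in examples.items():
    print(key, "=>", plan_conversion(key, shape))
```

A dry run like this is a cheap way to spot a key that is renamed or transposed unexpectedly before loading the full checkpoint.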
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# task name ["cola","sst-2","mrpc","sts-b","qqp","mnli", "rte", "qnli"] | ||
|
||
python run_predict.py --task_name qqp --ckpt_path qqp/best-qqp_ft_model_106000.pdparams |
Please translate to Chinese.