Skip to content

Commit 4871622

Browse files
Add model Prohetnet (#1698)
* add Prohetnet model * update prohetnet * update format * pre commit * add prophetnet example * update tokenizer.py,run_train.sh,train_prophetnet.py * remove evaluate/gigaword/__init__.py Co-authored-by: smallv0221 <33639025+smallv0221@users.noreply.github.com>
1 parent 8139863 commit 4871622

File tree

16 files changed

+4850
-0
lines changed

16 files changed

+4850
-0
lines changed
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
# Prophetnet
2+
3+
## 模型简介
4+
5+
ProphetNet(先知网络)是一种新型的 seq2seq 预训练模型。在训练时,Prophetnet 每一时刻将会学习同时预测未来的 N 个字符,这种自监督学习目标可以使得模型考虑未来更远的字符,防止模型对强局部相关(strong
6+
local correlation)过拟合。
7+
8+
本项目是 Prophetnet 在 PaddlePaddle 2.2 上开源实现的文本摘要的例子,包含了在 CNN/DailyMail 数据集,Gigaword 数据集上微调和生成的代码。
9+
10+
### 项目依赖
11+
12+
```
13+
pip install -r requirements.txt
14+
python -m pip install paddlepaddle-gpu==2.2.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
15+
pip install paddlenlp==2.2.3
16+
```
17+
18+
### 代码结构说明
19+
20+
以下是本项目主要代码结构及说明:
21+
22+
```text
23+
├── train_prophetnet.py # 模型finetune主程序入口
24+
├── generate.py # 模型生成主程序入口
25+
├── eval.py # 生成结果评估入口
26+
├── uncase_tokenize_data.py # 数据预处理
27+
├── uncompress_data.sh # 数据解压脚本
28+
├── run_train.sh # 模型训练脚本
29+
├── run_eval.sh # 模型评估脚本
30+
├── requirements.txt # 环境依赖文件
31+
└── README.md # 文档说明
32+
```
33+
34+
### 数据准备
35+
36+
GLGE 数据集下载:[链接](https://drive.google.com/file/d/1F4zppa9Gqrh6iNyVsZJkxfbm5waalqEA/view)
37+
38+
GLGE 测试集下载:[链接](https://drive.google.com/file/d/11lDXIG87dChIfukq3x2Wx4r5_duCRm_J/view)
39+
40+
将glge_public.tar与glge_hidden_v1.1.tar.gz放入到项目根目录下。
41+
42+
```
43+
bash uncompress_data.sh
44+
```
45+
46+
### 下载预训练权重与词表
47+
48+
模型权重和词表[下载链接](https://pan.baidu.com/s/1FOnd01rNvDJoONYegacq1Q), 提取码:o28q,下载后放入项目根目录。
49+
50+
### 数据预处理
51+
52+
```
53+
python uncase_tokenize_data.py --dataset <DATASET>
54+
```
55+
56+
说明:
57+
58+
- `<DATASET>`可选`cnndm`, `gigaword`.
59+
60+
### 模型训练
61+
62+
```
63+
bash run_train.sh <DATASET>
64+
```
65+
66+
或直接运行finetune程序
67+
68+
- cnndm:
69+
70+
```
71+
python train_prophetnet.py \
72+
--dataset=cnndm \
73+
--pretrained_model_path=./model_state.pdparams \
74+
--batch_size=4 \
75+
--epochs=4 \
76+
--lr=0.0001 \
77+
--warmup_init_lr=1e-07 \
78+
--warmup_updates=1000 \
79+
--clip_norm=0.1 \
80+
--num_workers=4 \
81+
--output_dir=./ckpt/cnndm
82+
```
83+
84+
- gigaword:
85+
86+
```
87+
python train_prophetnet.py \
88+
--dataset=gigaword \
89+
--pretrained_model_path=./model_state.pdparams \
90+
--batch_size=16 \
91+
--epochs=6 \
92+
--lr=0.0001 \
93+
--warmup_init_lr=1e-07 \
94+
--warmup_updates=1000 \
95+
--clip_norm=0.1 \
96+
--num_workers=8 \
97+
--output_dir=./ckpt/gigaword
98+
```
99+
100+
其中参数释义如下:
101+
102+
- `dataset` 指定数据集,可选cnndm和gigaword
103+
104+
- `pretrained_model_path` 本地预训练模型初始化权重文件路径,例如: ./model_state.pdparams。
105+
106+
- `batch_size` 表示训练样本批大小。
107+
108+
- `epochs` 表示训练轮数。
109+
110+
- `lr` 表示学习率
111+
112+
- `warmup_init_lr` 表示预热学习率
113+
114+
- `warmup_updates` 表示预热学习步数
115+
116+
- `clip_norm` 表示梯度裁剪
117+
118+
- `num_workers` 指定数据加载规模
119+
120+
- `output_idr` 指定微调结果权重存放路径
121+
122+
已经finetune好的模型权重:
123+
124+
- cnndm : [链接](https://pan.baidu.com/s/1cemrUDxkqEW9raoasJ_VKw), 提取码:1egi
125+
126+
- gigaword : [链接](https://pan.baidu.com/s/1qRH2FStT3vNQtDjZLkYJBQ), 提取码:on5v
127+
128+
### 模型评估
129+
130+
使用prophetNet源码的[评估脚本](https://pan.baidu.com/s/1FOnd01rNvDJoONYegacq1Q), 此脚本依赖于pyrouge,需要提前安装rouge。
131+
132+
```
133+
pip install git+https://github.com/pltrdy/pyrouge
134+
```
135+
136+
```
137+
bash run_eval.sh <DATASET>
138+
```
139+
140+
或直接运行模型生成程序
141+
142+
- cnndm:
143+
144+
```
145+
python generate.py \
146+
--dataset=cnndm \
147+
--vocab_file=./prophetnet.tokenizer \
148+
--output_path=./generate/cnndm/generate.txt \
149+
--min_target_length=45 \
150+
--max_target_length=110 \
151+
--decode_strategy=beam_search \
152+
--num_beams=4 \
153+
--length_penalty=1.2 \
154+
--batch_size=16 \
155+
--ignore_pad_token_for_loss=True \
156+
--early_stopping=True \
157+
--logging_steps=100 \
158+
--device=gpu
159+
160+
python eval.py --dataset cnndm --generated ./generate/cnndm/generate.txt
161+
```
162+
163+
- gigaword:
164+
165+
```
166+
python generate.py \
167+
--dataset=gigaword \
168+
--vocab_file=./prophetnet.tokenizer \
169+
--output_path=./generate/gigaword/generate.txt \
170+
--min_target_length=1 \
171+
--max_target_length=200 \
172+
--decode_strategy=beam_search \
173+
--num_beams=4 \
174+
--length_penalty=1.6 \
175+
--batch_size=16 \
176+
--ignore_pad_token_for_loss=True \
177+
--early_stopping=True \
178+
--logging_steps=100 \
179+
--device=gpu
180+
181+
python eval.py --dataset gigaword --generated ./generate/gigaword/generate.txt
182+
```
183+
184+
其中参数释义如下:
185+
186+
- `dataset` 指定数据集,可选cnndm和gigaword
187+
188+
- `vocab_file` 指定词表文件
189+
190+
- `output_path` 指定生成结果存放路径
191+
192+
- `min_target_length` 指定解码最短长度
193+
194+
- `max_target_length` 指定解码最大长度
195+
196+
- `decode_strategy` 指定解码策略
197+
198+
- `num_beams` 指定beam_search解码宽度
199+
200+
- `length_penalty` 指定beam_search解码的长度指数惩罚
201+
202+
- `batch_size` 指定评估样本批大小
203+
204+
- `ignore_pad_token_for_loss` 表示计算loss时忽略padding
205+
206+
- `early_stopping` 指定生成结束符是否停止预测
207+
208+
- `logging_steps` 指定日志打印间隔
209+
210+
- `device` 指定使用设备
211+
212+
### 微调测试精度
213+
214+
> #### 在CNN/DM数据集的测试效果如下表。
215+
216+
|网络 |opt|batch_size|数据集|ROUGE_1|ROUGE_2|ROUGE_L|
217+
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
218+
|prophetnet-large-uncased|Adam|4|CNN/DM|44.17|21.24|41.36|
219+
220+
> #### 在gigaword数据集的测试效果如下表。
221+
222+
|网络 |opt|batch_size|数据集|ROUGE_1|ROUGE_2|ROUGE_L|
223+
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
224+
|prophetnet-large-uncased|Adam|16|gigaword|38.92|19.81|36.06|
225+
226+
### 实验环境
227+
228+
- GPU RTX3090 * 1, CPU Intel i7-11700k
229+
- Ubuntu 18.04
230+
231+
### 参考文献
232+
233+
1. Qi W, Yan Y, Gong Y, et al. Prophetnet: Predicting future n-gram for sequence-to-sequence pre-training[J]. arXiv
234+
preprint arXiv:2001.04063, 2020.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import argparse
2+
import os
3+
import re
4+
import sys
5+
from os import listdir
6+
from os.path import isfile, join
7+
8+
parser = argparse.ArgumentParser()
9+
parser.add_argument(
10+
"--dataset",
11+
type=str,
12+
help="choose from all, or 1 of 8 dataset like cnndm, gigaword etc.")
13+
parser.add_argument("--generated", type=str, help="generated output file.")
14+
15+
args = parser.parse_args()
16+
17+
data_root_path = 'data'
18+
19+
support_dataset = ['cnndm', 'gigaword']
20+
files2rouge_template = '.*ROUGE-1 Average_F: (?P<rouge1_f>\d+(\.\d*)?|\.\d+).*ROUGE-2 Average_F: (?P<rouge2_f>\d+(\.\d*)?|\.\d+).*ROUGE-L Average_F: (?P<rougeL_f>\d+(\.\d*)?|\.\d+).*'
21+
# gigaword_template='.*ROUGE-1: (?P<rouge1_f>\d+(\.\d*)?|\.\d+).*ROUGE-2: (?P<rouge2_f>\d+(\.\d*)?|\.\d+).*ROUGE-L: (?P<rougeL_f>\d+(\.\d*)?|\.\d+).*'
22+
qg_template = '.*Bleu_4: (?P<bleu4>\d+(\.\d*)?|\.\d+).*METEOR: (?P<meteor>\d+(\.\d*)?|\.\d+).*ROUGE_L: (?P<rougeL>\d+(\.\d*)?|\.\d+).*'
23+
personachat_template = '.*?(?P<d1>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?).*?(?P<d2>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?).*Bleu_1: (?P<bleu1>\d+(\.\d*)?|\.\d+).*Bleu_2: (?P<bleu2>\d+(\.\d*)?|\.\d+).*'
24+
25+
26+
def scale_up(d):
27+
return {k: float(d[k]) * 100 for k in d.keys()}
28+
29+
30+
def eval_one_dataset():
31+
golden_file = f"{data_root_path}/{args.dataset}_data/test.tgt"
32+
33+
eval_template = {
34+
'cnndm':
35+
f"python ./evaluate/cnndm/postprocess_cnn_dm.py --generated {generated_file} --golden {golden_file}",
36+
'gigaword':
37+
f"python ./evaluate/gigaword/eval.py --perl --pred {generated_file} --gold {golden_file}",
38+
}
39+
40+
cmd = eval_template[args.dataset]
41+
try:
42+
output = os.popen(cmd).read()
43+
if args.dataset in ['cnndm', 'gigaword']:
44+
d = re.search(files2rouge_template,
45+
output.replace("\n", " ")).groupdict()
46+
d = scale_up(d)
47+
print(
48+
f"{args.dataset}\trouge1/rouge2/rougeL\t{d['rouge1_f']:.2f}/{d['rouge2_f']:.2f}/{d['rougeL_f']:.2f}"
49+
)
50+
except:
51+
print("Unexpected error:", sys.exc_info()[0])
52+
print(f"{args.dataset} evaluate failed!")
53+
54+
55+
if args.dataset != 'all':
56+
generated_file = args.generated
57+
eval_one_dataset()
58+
else:
59+
output_root_path = args.generated
60+
onlyfolders = [
61+
f for f in listdir(output_root_path)
62+
if not isfile(join(args.generated, f))
63+
]
64+
for dataset in support_dataset:
65+
for folder in onlyfolders:
66+
if folder.startswith(dataset):
67+
for hypo_file in listdir(args.generated + '/' + folder):
68+
if 'hypo' in hypo_file or 'score' in hypo_file:
69+
generated_file = args.generated + '/' + folder + '/' + hypo_file
70+
print(f"{dataset}\tpredict_file:{generated_file}")
71+
args.dataset = dataset
72+
args.gnerated = generated_file
73+
eval_one_dataset()

0 commit comments

Comments
 (0)