Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unimo model and fix generate api #891

Merged
merged 24 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
43c4edd
fix unified transformer dtype problem
smallv0221 Jul 9, 2021
18b6860
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
smallv0221 Jul 9, 2021
6739c85
fix win dtype bug
smallv0221 Jul 14, 2021
9ba42f9
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
smallv0221 Jul 14, 2021
135e0e1
Fix plato-2 and plato-mini dtype bug
smallv0221 Jul 19, 2021
fd2a5a5
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
smallv0221 Jul 19, 2021
8567c64
Fix plato-2 tokenization
smallv0221 Jul 19, 2021
cea6ba3
Merge branch 'develop' into yxp0707
LiuChiachi Jul 19, 2021
67f5595
Refine some doc
smallv0221 Jul 20, 2021
3423001
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
smallv0221 Jul 20, 2021
4163691
Merge branch 'yxp0707' of https://github.com/smallv0221/PaddleNLP int…
smallv0221 Jul 20, 2021
9ecfcd5
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
smallv0221 Jul 25, 2021
41527aa
Add general k support for topk sampling
smallv0221 Jul 25, 2021
411e57c
fix seed
smallv0221 Jul 26, 2021
528fd2d
minor fix
smallv0221 Jul 26, 2021
9a42e57
Fix unitransformer readme
smallv0221 Jul 27, 2021
d5f6f21
Merge branch 'develop' into yxp0707
FrostML Jul 27, 2021
e5f684f
topk kernel optimization
smallv0221 Jul 28, 2021
93922b1
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
smallv0221 Jul 28, 2021
e0e7e3c
Merge branch 'yxp0707' of https://github.com/smallv0221/PaddleNLP int…
smallv0221 Jul 28, 2021
fcaf7bb
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
smallv0221 Aug 4, 2021
9677cf9
add unimo model and fix generate api
smallv0221 Aug 17, 2021
649fd39
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
smallv0221 Aug 17, 2021
136813e
add 3 datasets for unimo-text
smallv0221 Aug 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions examples/text_generation/unimo-text/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# 千言:面向事实一致性的生成评测比赛baseline

## 比赛简介

自然语言生成旨在让机器能够像人一样使用自然语言进行表达和交互,它是人工智能领域重要的前沿课题,近年来受到学术界和工业界广泛关注。

随着神经网络生成模型特别是预训练语言模型的迅速发展,机器生成文本的可读性和流畅性不断提升。然而,自动生成的文本中依然经常出现不符合原文或背景的错误事实描述,这种生成的事实一致性问题是自然语言生成进行落地应用的主要障碍之一,并逐渐受到研究学者的关注。鉴于当前国内外关于事实一致性的生成评测比赛十分匮乏,为了促进自然语言生成的技术发展和实际应用,我们计划组织面向事实一致性的生成评测比赛。

在此比赛中,我们将提供三个对事实一致性有较高要求的生成任务,包括文案生成、摘要生成和问题生成。同时,在系统评价中,我们将结合文本流畅性和事实一致性两项指标综合评估参赛生成系统的水平。通过这样的任务设定和评价方式,此评测将有助于研究者和开发者更多关注自然语言生成的事实一致性难题,并为大家提供学术交流平台,从而进一步提升自然语言生成的研究水平,推动相关技术的应用发展。

本比赛得到中国中文信息学会自然语言生成专业委员会(筹)支持,将在2021年11月7日首届中国自然语言生成大会(CCNLG-2021)召开评测研讨会,并在大会上对获奖团队颁奖。


## 快速开始

### 数据准备

比赛使用三个任务数据集测试参赛系统的生成能力,包括文案生成、摘要生成和问题生成:

- 文案生成根据结构化的商品信息生成合适的广告文案;
- 摘要生成是为输入文档生成简洁且包含关键信息的简洁文本;
- 问题生成则是根据给定段落以及答案生成适合的问题。


### 模型训练

运行如下命令即可在样例训练集上进行finetune,并在样例验证集上进行验证

```shell
# GPU启动,参数`--gpus`指定训练所用的GPU卡号,可以是单卡,也可以多卡
unset CUDA_VISIBLE_DEVICES
python -m paddle.distributed.launch --gpus "0" --log_dir ./log run_gen.py \
--dataset_name=dureader_qg \
--model_name_or_path=unimo-text-1.0 \
--save_dir=./unimo/checkpoints \
--logging_steps=100 \
--save_steps=100000 \
--epochs=6 \
--batch_size=16 \
--learning_rate=5e-5 \
--warmup_propotion=0.02 \
--weight_decay=0.01 \
--max_seq_len=512 \
--max_target_len=30 \
--do_train \
--do_predict \
--device=gpu
```

其中参数释义如下:
- `gpus` 指示了训练所用的GPU卡号。
- `dataset_name` 数据集名称,dureader_qg、advertisegen和lcsts_new分别对应问题生成、文案生成和摘要生成三个任务。
- `model_name_or_path` 指示了finetune使用的具体预训练模型,可以是PaddleNLP提供的预训练模型,或者是本地的预训练模型。如果使用本地的预训练模型,可以配置本地模型的目录地址,例如: ./checkpoints/model_xx/,目录中需包含paddle预训练模型model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面其中之一。

| PaddleNLP提供的预训练模型 |
|---------------------------------|
| unimo-text-1.0 |
| unimo-text-1.0-large |

- `save_dir` 表示模型的保存路径。
- `logging_steps` 表示日志打印间隔。
- `save_steps` 表示模型保存及评估间隔。
- `seed` 表示随机数生成器的种子。
- `epochs` 表示训练轮数。
- `batch_size` 表示每次迭代**每张卡**上的样本数目。
- `learning_rate` 表示基础学习率大小,将于learning rate scheduler产生的值相乘作为当前学习率。
- `weight_decay` 表示AdamW优化器中使用的weight_decay的系数。
- `warmup_propotion` 表示学习率逐渐升高到基础学习率(即上面配置的learning_rate)所需要的迭代数占总步数的比例,最早的使用可以参考[这篇论文](https://arxiv.org/pdf/1706.02677.pdf)。
- `max_seq_len` 模型输入序列的最大长度。
- `max_target_len` 模型训练时标签的最大长度。
- `do_train` 是否进行训练。
- `do_predict` 是否进行预测,在验证集上会自动评估。
- `device` 表示使用的设备,从gpu和cpu中选择。

更多参数详情和参数的默认值请参考`args.py`。

程序运行时将会自动进行训练和验证,训练过程中会自动保存模型在指定的`save_dir`中。
如:
```text
./checkpoints/
├── model_8000
│ ├── model_config.json
│ ├── model_state.pdparams
│ ├── spm.model
│ ├── tokenizer_config.json
│ └── vocab.txt
└── ...
```

**NOTE:** 如需恢复模型训练,`model_name_or_path`配置本地模型的目录地址即可。

### 模型预测

运行如下命令即可在样例测试集上进行测试

```shell
export CUDA_VISIBLE_DEVICES=0
# GPU启动,预测仅支持单卡
python infer.py \
--model_name_or_path=./checkpoints/model_80000 \
--test_data_path=./datasets/test.txt \
--output_path=./predict.txt \
--logging_steps=500 \
--seed=2021 \
--batch_size=4 \
--min_dec_len=1 \
--max_dec_len=64 \
--num_samples=20 \
--decode_strategy=sampling \
--top_k=5 \
--device=gpu
```

其中参数释义如下:
- `model_name_or_path` 指示了finetune使用的具体预训练模型,可以是PaddleNLP提供的预训练模型,或者是本地的预训练模型。如果使用本地的预训练模型,可以配置本地模型的目录地址,例如: ./checkpoints/model_xx/,目录中需包含paddle预训练模型model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面其中之一。

| PaddleNLP提供的预训练模型 |
|---------------------------------|
| unified_transformer-12L-cn |
| unified_transformer-12L-cn-luge |

- `test_data_path` 表示预测集文件路径。
- `output_path` 表示预测结果的保存路径。
- `logging_steps` 表示日志打印间隔。
- `seed` 表示随机数生成器的种子。
- `batch_size` 表示每次迭代**每张卡**上的样本数目。
- `min_dec_len` 表示预测生成的句子的最小长度。
- `max_dec_len` 表示预测生成的句子的最大长度。
- `num_samples` 表示每条样本生成的句子的数量。对于每条样本,模型会生成`num_samples`个句子,根据每个句子的概率得分进行排序,得分最高的句子作为最终的生成结果。
- `decode_strategy` 表示预测解码时采取的策略,可选"sampling"、"greedy_search"和"beam_search"之一。
- `top_k` 表示采用"sampling"解码策略时,token的概率按从大到小排序,生成的token只从前`top_k`个中进行采样。
- `device` 表示训练使用的设备。

参数详情和参数的默认值请参考`args.py`。

程序运行结束后会将预测结果保存在`output_path`中。将预测结果准备成比赛官网要求的格式,提交评估即可得评估结果。

采用不同的模型在样例测试集上有如下结果:

| model_name_or_path | F1 | BLEU1 / BLEU2 | DISTINCT1 / DISTINCT2 |
| :-----------------------------: | :---: | :-----------: | :-------------------: |
| unified_transformer-12L-cn | 10.62 | 0.070 / 0.022 | 0.065 / 0.304 |
| unified_transformer-12L-cn-luge | 33.11 | 0.245 / 0.157 | 0.074 / 0.238 |
| ./checkpoints/model_80000 | 32.38 | 0.239 / 0.150 | 0.070 / 0.219 |
186 changes: 186 additions & 0 deletions examples/text_generation/unimo-text/gen_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import random
from functools import partial

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
from paddlenlp.data import Pad


def print_args(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')


def set_seed(seed):
# Use the same data seed(for data shuffle) for all procs to guarantee data
# consistency after sharding.
random.seed(seed)
np.random.seed(seed)
# Maybe different op seeds(for dropout) for different procs is better.
paddle.seed(seed + dist.get_rank())


def convert_example(example,
tokenizer,
max_seq_len=512,
max_target_len=128,
max_title_len=256,
mode='train'):
"""Convert all examples into necessary features."""
source = example['source']
title = None
if 'title' in example.keys():
title = example['title']

if mode != 'test':
tokenized_example = tokenizer.gen_encode(
source,
title=title,
target=example['target'],
max_seq_len=max_seq_len,
max_target_len=max_target_len,
max_title_len=max_title_len,
return_position_ids=True,
return_length=True)
target_start = tokenized_example['input_ids'].index(
tokenizer.cls_token_id, 1)
target_end = tokenized_example['seq_len']
# Use to gather the logits corresponding to the labels during training
tokenized_example['masked_positions'] = list(
range(target_start, target_end - 1))
tokenized_example['labels'] = tokenized_example['input_ids'][
target_start + 1:target_end]

return tokenized_example
else:
tokenized_example = tokenizer.gen_encode(
source,
title=title,
max_seq_len=max_seq_len,
max_title_len=max_title_len,
add_start_token_for_decoding=True,
return_position_ids=True)

if 'target' in example:
tokenized_example['target'] = example['target']
return tokenized_example


def batchify_fn(batch_examples, pad_val, mode):
def pad_mask(batch_attention_mask):
batch_size = len(batch_attention_mask)
max_len = max(map(len, batch_attention_mask))
attention_mask = np.ones(
(batch_size, max_len, max_len), dtype='float32') * -1e9
for i, mask_data in enumerate(attention_mask):
seq_len = len(batch_attention_mask[i])
mask_data[-seq_len:, -seq_len:] = np.array(
batch_attention_mask[i], dtype='float32')
# In order to ensure the correct broadcasting mechanism, expand one
# dimension to the second dimension (n_head of Transformer).
attention_mask = np.expand_dims(attention_mask, axis=1)
return attention_mask

pad_func = Pad(pad_val=pad_val, pad_right=False, dtype='int64')

input_ids = pad_func([example['input_ids'] for example in batch_examples])
token_type_ids = pad_func(
[example['token_type_ids'] for example in batch_examples])
position_ids = pad_func(
[example['position_ids'] for example in batch_examples])

attention_mask = pad_mask(
[example['attention_mask'] for example in batch_examples])

if mode != 'test':
max_len = max([example['seq_len'] for example in batch_examples])
masked_positions = np.concatenate([
np.array(example['masked_positions']) +
(max_len - example['seq_len']) + i * max_len
for i, example in enumerate(batch_examples)
])
labels = np.concatenate([
np.array(
example['labels'], dtype='int64') for example in batch_examples
])
return input_ids, token_type_ids, position_ids, attention_mask, masked_positions, labels
else:
return input_ids, token_type_ids, position_ids, attention_mask


def create_data_loader(dataset, tokenizer, args, mode):
trans_func = partial(
convert_example,
tokenizer=tokenizer,
max_seq_len=args.max_seq_len,
max_target_len=args.max_target_len,
max_title_len=args.max_title_len,
mode=mode)
dataset = dataset.map(trans_func, lazy=True)
if mode == 'train':
batch_sampler = DistributedBatchSampler(
dataset, batch_size=args.batch_size, shuffle=True)
else:
batch_sampler = BatchSampler(
dataset, batch_size=args.batch_size // 2, shuffle=False)
collate_fn = partial(batchify_fn, pad_val=tokenizer.pad_token_id, mode=mode)
data_loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
return_list=True)
return dataset, data_loader


def post_process_sum(token_ids, tokenizer):
"""Post-process the decoded sequence. Truncate from the first <eos>."""
eos_pos = len(token_ids)
for i, tok_id in enumerate(token_ids):
if tok_id == tokenizer.mask_token_id:
eos_pos = i
break
token_ids = token_ids[:eos_pos]
tokens = tokenizer.convert_ids_to_tokens(token_ids)
tokens = tokenizer.merge_subword(tokens)
special_tokens = ['[UNK]']
tokens = [token for token in tokens if token not in special_tokens]
return token_ids, tokens


def select_sum(ids, scores, tokenizer, max_dec_len=None,
num_return_sequences=1):
ids = ids.numpy()
scores = scores.numpy()

if len(ids) != len(scores) or (len(ids) % num_return_sequences) != 0:
raise ValueError(
"the length of `ids` is {}, but the `num_return_sequences` is {}".
format(len(ids), num_return_sequences))

group = []
tmp = []
for pred, score in zip(ids, scores):
pred_token_ids, pred_tokens = post_process_sum(pred, tokenizer)
num_token = len(pred_token_ids)

target = "".join(pred_tokens)

# not ending
if max_dec_len is not None and num_token >= max_dec_len:
score -= 1e3

tmp.append([target, score])
if len(tmp) == num_return_sequences:
group.append(tmp)
tmp = []

results = []
for preds in group:
preds = sorted(preds, key=lambda x: -x[1])
results.append(preds[0][0])
return results
Loading