From 740f5e239748ed41e337fc28b94a04d65b3fa5d3 Mon Sep 17 00:00:00 2001 From: smallv0221 <33639025+smallv0221@users.noreply.github.com> Date: Tue, 17 Aug 2021 13:56:27 +0800 Subject: [PATCH] Add unimo model and fix generate api (#891) * fix unified transformer dtype problem * fix win dtype bug * Fix plato-2 and plato-mini dtype bug * Fix plato-2 tokenization * Refine some doc * Add general k support for topk sampling * fix seed * minor fix * Fix unitransformer readme * topk kernel optimization * add unimo model and fix generate api * add 3 datasets for unimo-text Co-authored-by: Jiaqi Liu Co-authored-by: liu zhengxi <380185688@qq.com> --- examples/text_generation/unimo-text/README.md | 144 +++++ .../text_generation/unimo-text/gen_utils.py | 186 ++++++ .../text_generation/unimo-text/run_gen.py | 227 ++++++++ paddlenlp/datasets/advertisegen.py | 68 +++ paddlenlp/datasets/dureader_qg.py | 68 +++ paddlenlp/datasets/lcsts_new.py | 67 +++ paddlenlp/transformers/__init__.py | 2 + paddlenlp/transformers/generation_utils.py | 35 +- paddlenlp/transformers/unimo/__init__.py | 0 paddlenlp/transformers/unimo/modeling.py | 503 ++++++++++++++++ paddlenlp/transformers/unimo/tokenizer.py | 549 ++++++++++++++++++ 11 files changed, 1831 insertions(+), 18 deletions(-) create mode 100644 examples/text_generation/unimo-text/README.md create mode 100644 examples/text_generation/unimo-text/gen_utils.py create mode 100644 examples/text_generation/unimo-text/run_gen.py create mode 100644 paddlenlp/datasets/advertisegen.py create mode 100644 paddlenlp/datasets/dureader_qg.py create mode 100644 paddlenlp/datasets/lcsts_new.py create mode 100644 paddlenlp/transformers/unimo/__init__.py create mode 100644 paddlenlp/transformers/unimo/modeling.py create mode 100644 paddlenlp/transformers/unimo/tokenizer.py diff --git a/examples/text_generation/unimo-text/README.md b/examples/text_generation/unimo-text/README.md new file mode 100644 index 000000000000..29a2e9c7533c --- /dev/null +++ b/examples/text_generation/unimo-text/README.md @@ -0,0 +1,144 @@ +# 千言:面向事实一致性的生成评测比赛baseline + +## 比赛简介 + +自然语言生成旨在让机器能够像人一样使用自然语言进行表达和交互,它是人工智能领域重要的前沿课题,近年来受到学术界和工业界广泛关注。 + +随着神经网络生成模型特别是预训练语言模型的迅速发展,机器生成文本的可读性和流畅性不断提升。然而,自动生成的文本中依然经常出现不符合原文或背景的错误事实描述,这种生成的事实一致性问题是自然语言生成进行落地应用的主要障碍之一,并逐渐受到研究学者的关注。鉴于当前国内外关于事实一致性的生成评测比赛十分匮乏,为了促进自然语言生成的技术发展和实际应用,我们计划组织面向事实一致性的生成评测比赛。 + +在此比赛中,我们将提供三个对事实一致性有较高要求的生成任务,包括文案生成、摘要生成和问题生成。同时,在系统评价中,我们将结合文本流畅性和事实一致性两项指标综合评估参赛生成系统的水平。通过这样的任务设定和评价方式,此评测将有助于研究者和开发者更多关注自然语言生成的事实一致性难题,并为大家提供学术交流平台,从而进一步提升自然语言生成的研究水平,推动相关技术的应用发展。 + +本比赛得到中国中文信息学会自然语言生成专业委员会(筹)支持,将在2021年11月7日首届中国自然语言生成大会(CCNLG-2021)召开评测研讨会,并在大会上对获奖团队颁奖。 + + +## 快速开始 + +### 数据准备 + +比赛使用三个任务数据集测试参赛系统的生成能力,包括文案生成、摘要生成和问题生成: + +- 文案生成根据结构化的商品信息生成合适的广告文案; +- 摘要生成是为输入文档生成简洁且包含关键信息的简洁文本; +- 问题生成则是根据给定段落以及答案生成适合的问题。 + + +### 模型训练 + +运行如下命令即可在样例训练集上进行finetune,并在样例验证集上进行验证 + +```shell +# GPU启动,参数`--gpus`指定训练所用的GPU卡号,可以是单卡,也可以多卡 +unset CUDA_VISIBLE_DEVICES +python -m paddle.distributed.launch --gpus "0" --log_dir ./log run_gen.py \ + --dataset_name=dureader_qg \ + --model_name_or_path=unimo-text-1.0 \ + --save_dir=./unimo/checkpoints \ + --logging_steps=100 \ + --save_steps=100000 \ + --epochs=6 \ + --batch_size=16 \ + --learning_rate=5e-5 \ + --warmup_propotion=0.02 \ + --weight_decay=0.01 \ + --max_seq_len=512 \ + --max_target_len=30 \ + --do_train \ + --do_predict \ + --device=gpu +``` + +其中参数释义如下: +- `gpus` 指示了训练所用的GPU卡号。 +- `dataset_name` 数据集名称,dureader_qg、advertisegen和lcsts_new分别对应问题生成、文案生成和摘要生成三个任务。 +- `model_name_or_path` 指示了finetune使用的具体预训练模型,可以是PaddleNLP提供的预训练模型,或者是本地的预训练模型。如果使用本地的预训练模型,可以配置本地模型的目录地址,例如: ./checkpoints/model_xx/,目录中需包含paddle预训练模型model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面其中之一。 + + | PaddleNLP提供的预训练模型 | + |---------------------------------| + | unimo-text-1.0 | + | unimo-text-1.0-large | + +- `save_dir` 表示模型的保存路径。 +- `logging_steps` 表示日志打印间隔。 +- `save_steps` 表示模型保存及评估间隔。 +- `seed` 表示随机数生成器的种子。 +- `epochs` 表示训练轮数。 +- `batch_size` 表示每次迭代**每张卡**上的样本数目。 +- `learning_rate` 表示基础学习率大小,将于learning rate scheduler产生的值相乘作为当前学习率。 +- `weight_decay` 表示AdamW优化器中使用的weight_decay的系数。 +- `warmup_propotion` 表示学习率逐渐升高到基础学习率(即上面配置的learning_rate)所需要的迭代数占总步数的比例,最早的使用可以参考[这篇论文](https://arxiv.org/pdf/1706.02677.pdf)。 +- `max_seq_len` 模型输入序列的最大长度。 +- `max_target_len` 模型训练时标签的最大长度。 +- `do_train` 是否进行训练。 +- `do_predict` 是否进行预测,在验证集上会自动评估。 +- `device` 表示使用的设备,从gpu和cpu中选择。 + +更多参数详情和参数的默认值请参考`args.py`。 + +程序运行时将会自动进行训练和验证,训练过程中会自动保存模型在指定的`save_dir`中。 +如: +```text +./checkpoints/ +├── model_8000 +│ ├── model_config.json +│ ├── model_state.pdparams +│ ├── spm.model +│ ├── tokenizer_config.json +│ └── vocab.txt +└── ... +``` + +**NOTE:** 如需恢复模型训练,`model_name_or_path`配置本地模型的目录地址即可。 + +### 模型预测 + +运行如下命令即可在样例测试集上进行测试 + +```shell +export CUDA_VISIBLE_DEVICES=0 +# GPU启动,预测仅支持单卡 +python infer.py \ + --model_name_or_path=./checkpoints/model_80000 \ + --test_data_path=./datasets/test.txt \ + --output_path=./predict.txt \ + --logging_steps=500 \ + --seed=2021 \ + --batch_size=4 \ + --min_dec_len=1 \ + --max_dec_len=64 \ + --num_samples=20 \ + --decode_strategy=sampling \ + --top_k=5 \ + --device=gpu +``` + +其中参数释义如下: +- `model_name_or_path` 指示了finetune使用的具体预训练模型,可以是PaddleNLP提供的预训练模型,或者是本地的预训练模型。如果使用本地的预训练模型,可以配置本地模型的目录地址,例如: ./checkpoints/model_xx/,目录中需包含paddle预训练模型model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面其中之一。 + + | PaddleNLP提供的预训练模型 | + |---------------------------------| + | unified_transformer-12L-cn | + | unified_transformer-12L-cn-luge | + +- `test_data_path` 表示预测集文件路径。 +- `output_path` 表示预测结果的保存路径。 +- `logging_steps` 表示日志打印间隔。 +- `seed` 表示随机数生成器的种子。 +- `batch_size` 表示每次迭代**每张卡**上的样本数目。 +- `min_dec_len` 表示预测生成的句子的最小长度。 +- `max_dec_len` 表示预测生成的句子的最大长度。 +- `num_samples` 表示每条样本生成的句子的数量。对于每条样本,模型会生成`num_samples`个句子,根据每个句子的概率得分进行排序,得分最高的句子作为最终的生成结果。 +- `decode_strategy` 表示预测解码时采取的策略,可选"sampling"、"greedy_search"和"beam_search"之一。 +- `top_k` 表示采用"sampling"解码策略时,token的概率按从大到小排序,生成的token只从前`top_k`个中进行采样。 +- `device` 表示训练使用的设备。 + +参数详情和参数的默认值请参考`args.py`。 + +程序运行结束后会将预测结果保存在`output_path`中。将预测结果准备成比赛官网要求的格式,提交评估即可得评估结果。 + +采用不同的模型在样例测试集上有如下结果: + +| model_name_or_path | F1 | BLEU1 / BLEU2 | DISTINCT1 / DISTINCT2 | +| :-----------------------------: | :---: | :-----------: | :-------------------: | +| unified_transformer-12L-cn | 10.62 | 0.070 / 0.022 | 0.065 / 0.304 | +| unified_transformer-12L-cn-luge | 33.11 | 0.245 / 0.157 | 0.074 / 0.238 | +| ./checkpoints/model_80000 | 32.38 | 0.239 / 0.150 | 0.070 / 0.219 | diff --git a/examples/text_generation/unimo-text/gen_utils.py b/examples/text_generation/unimo-text/gen_utils.py new file mode 100644 index 000000000000..ca1638279823 --- /dev/null +++ b/examples/text_generation/unimo-text/gen_utils.py @@ -0,0 +1,186 @@ +import random +from functools import partial + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler +from paddlenlp.data import Pad + + +def print_args(args): + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +def set_seed(seed): + # Use the same data seed(for data shuffle) for all procs to guarantee data + # consistency after sharding. + random.seed(seed) + np.random.seed(seed) + # Maybe different op seeds(for dropout) for different procs is better. + paddle.seed(seed + dist.get_rank()) + + +def convert_example(example, + tokenizer, + max_seq_len=512, + max_target_len=128, + max_title_len=256, + mode='train'): + """Convert all examples into necessary features.""" + source = example['source'] + title = None + if 'title' in example.keys(): + title = example['title'] + + if mode != 'test': + tokenized_example = tokenizer.gen_encode( + source, + title=title, + target=example['target'], + max_seq_len=max_seq_len, + max_target_len=max_target_len, + max_title_len=max_title_len, + return_position_ids=True, + return_length=True) + target_start = tokenized_example['input_ids'].index( + tokenizer.cls_token_id, 1) + target_end = tokenized_example['seq_len'] + # Use to gather the logits corresponding to the labels during training + tokenized_example['masked_positions'] = list( + range(target_start, target_end - 1)) + tokenized_example['labels'] = tokenized_example['input_ids'][ + target_start + 1:target_end] + + return tokenized_example + else: + tokenized_example = tokenizer.gen_encode( + source, + title=title, + max_seq_len=max_seq_len, + max_title_len=max_title_len, + add_start_token_for_decoding=True, + return_position_ids=True) + + if 'target' in example: + tokenized_example['target'] = example['target'] + return tokenized_example + + +def batchify_fn(batch_examples, pad_val, mode): + def pad_mask(batch_attention_mask): + batch_size = len(batch_attention_mask) + max_len = max(map(len, batch_attention_mask)) + attention_mask = np.ones( + (batch_size, max_len, max_len), dtype='float32') * -1e9 + for i, mask_data in enumerate(attention_mask): + seq_len = len(batch_attention_mask[i]) + mask_data[-seq_len:, -seq_len:] = np.array( + batch_attention_mask[i], dtype='float32') + # In order to ensure the correct broadcasting mechanism, expand one + # dimension to the second dimension (n_head of Transformer). + attention_mask = np.expand_dims(attention_mask, axis=1) + return attention_mask + + pad_func = Pad(pad_val=pad_val, pad_right=False, dtype='int64') + + input_ids = pad_func([example['input_ids'] for example in batch_examples]) + token_type_ids = pad_func( + [example['token_type_ids'] for example in batch_examples]) + position_ids = pad_func( + [example['position_ids'] for example in batch_examples]) + + attention_mask = pad_mask( + [example['attention_mask'] for example in batch_examples]) + + if mode != 'test': + max_len = max([example['seq_len'] for example in batch_examples]) + masked_positions = np.concatenate([ + np.array(example['masked_positions']) + + (max_len - example['seq_len']) + i * max_len + for i, example in enumerate(batch_examples) + ]) + labels = np.concatenate([ + np.array( + example['labels'], dtype='int64') for example in batch_examples + ]) + return input_ids, token_type_ids, position_ids, attention_mask, masked_positions, labels + else: + return input_ids, token_type_ids, position_ids, attention_mask + + +def create_data_loader(dataset, tokenizer, args, mode): + trans_func = partial( + convert_example, + tokenizer=tokenizer, + max_seq_len=args.max_seq_len, + max_target_len=args.max_target_len, + max_title_len=args.max_title_len, + mode=mode) + dataset = dataset.map(trans_func, lazy=True) + if mode == 'train': + batch_sampler = DistributedBatchSampler( + dataset, batch_size=args.batch_size, shuffle=True) + else: + batch_sampler = BatchSampler( + dataset, batch_size=args.batch_size // 2, shuffle=False) + collate_fn = partial(batchify_fn, pad_val=tokenizer.pad_token_id, mode=mode) + data_loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + return_list=True) + return dataset, data_loader + + +def post_process_sum(token_ids, tokenizer): + """Post-process the decoded sequence. Truncate from the first .""" + eos_pos = len(token_ids) + for i, tok_id in enumerate(token_ids): + if tok_id == tokenizer.mask_token_id: + eos_pos = i + break + token_ids = token_ids[:eos_pos] + tokens = tokenizer.convert_ids_to_tokens(token_ids) + tokens = tokenizer.merge_subword(tokens) + special_tokens = ['[UNK]'] + tokens = [token for token in tokens if token not in special_tokens] + return token_ids, tokens + + +def select_sum(ids, scores, tokenizer, max_dec_len=None, + num_return_sequences=1): + ids = ids.numpy() + scores = scores.numpy() + + if len(ids) != len(scores) or (len(ids) % num_return_sequences) != 0: + raise ValueError( + "the length of `ids` is {}, but the `num_return_sequences` is {}". + format(len(ids), num_return_sequences)) + + group = [] + tmp = [] + for pred, score in zip(ids, scores): + pred_token_ids, pred_tokens = post_process_sum(pred, tokenizer) + num_token = len(pred_token_ids) + + target = "".join(pred_tokens) + + # not ending + if max_dec_len is not None and num_token >= max_dec_len: + score -= 1e3 + + tmp.append([target, score]) + if len(tmp) == num_return_sequences: + group.append(tmp) + tmp = [] + + results = [] + for preds in group: + preds = sorted(preds, key=lambda x: -x[1]) + results.append(preds[0][0]) + return results diff --git a/examples/text_generation/unimo-text/run_gen.py b/examples/text_generation/unimo-text/run_gen.py new file mode 100644 index 000000000000..5ccbfdc9c862 --- /dev/null +++ b/examples/text_generation/unimo-text/run_gen.py @@ -0,0 +1,227 @@ +import os +import time +import math +import argparse +import json + +import paddle +import paddle.distributed as dist +import paddle.nn as nn +import paddle.nn.functional as F +from paddlenlp.transformers import LinearDecayWithWarmup +from paddle.optimizer import AdamW, SGD +from paddlenlp.ops.optimizer import AdamwOptimizer + +from paddlenlp.datasets import load_dataset +from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer, BasicTokenizer +from paddlenlp.metrics import BLEU + +from gen_utils import print_args, set_seed, create_data_loader, select_sum + + +# yapf: disable +def parse_args(): + parser = argparse.ArgumentParser(__doc__) + parser.add_argument('--dataset_name', type=str, default='dureader_qg', help='The name of the dataset to load.') + parser.add_argument('--model_name_or_path', type=str, default='unimo-text-1.0', help='The path or shortcut name of the pre-trained model.') + parser.add_argument('--save_dir', type=str, default='./checkpoints', help='The directory where the checkpoints will be saved.') + parser.add_argument('--logging_steps', type=int, default=100, help='Log every X updates steps.') + parser.add_argument('--save_steps', type=int, default=1000, help='Save checkpoint every X updates steps.') + parser.add_argument('--seed', type=int, default=1, help='Random seed for initialization.') + parser.add_argument('--batch_size', type=int, default=16, help='Batch size per GPU/CPU for training.') + parser.add_argument('--learning_rate', type=float, default=5e-5, help='The initial learning rate.') + parser.add_argument('--weight_decay', type=float, default=0.01, help='The weight decay for optimizer.') + parser.add_argument('--epochs', type=int, default=3, help='Total number of training epochs to perform.') + parser.add_argument('--warmup_propotion', type=float, default=0.02, help='The number of warmup steps.') + parser.add_argument('--max_grad_norm', type=float, default=1.0, help='The max value of grad norm.') + parser.add_argument('--beta1', type=float, default=0.9, help='beta1') + parser.add_argument('--beta2', type=float, default=0.98, help='beta2') + parser.add_argument('--epsilon', type=float, default=1e-6, help='epsilon') + parser.add_argument('--max_seq_len', type=int, default=512, help='The maximum sequence length of training.') + parser.add_argument('--max_dec_len', type=int, default=20, help='The maximum sequence length of decoding.') + parser.add_argument('--min_dec_len', type=int, default=3, help='The minimal sequence length of decoding.') + parser.add_argument('--max_target_len', type=int, default=30, help='The maximum target sequence length of training.') + parser.add_argument('--max_title_len', type=int, default=30, help='The maximum title sequence length of training.') + parser.add_argument('--num_return_sequences', type=int, default=1, help='The numbers of returned sequences for one input in generation.') + parser.add_argument('--decode_strategy', type=str, default='beam_search', help='The decode strategy in generation.') + parser.add_argument('--top_k', type=int, default=0, help='The number of highest probability vocabulary tokens to keep for top-k sampling.') + parser.add_argument('--temperature', type=float, default=1.0, help='The value used to module the next token probabilities.') + parser.add_argument('--top_p', type=float, default=1.0, help='The cumulative probability for top-p sampling.') + parser.add_argument('--num_beams', type=int, default=6, help='The number of beams for beam search.') + parser.add_argument('--length_penalty', type=float, default=1.2, help='The exponential penalty to the sequence length for beam search.') + parser.add_argument('--device', type=str, default='gpu', help='The device to select for training the model.') + parser.add_argument('--output_path', type=str, default='./predict.txt', help='The file path where the infer result will be saved.') + parser.add_argument("--do_train", action='store_true', help="Whether to train the model.") + parser.add_argument("--do_predict", action='store_true', help="Whether to eval and predict.") + + args = parser.parse_args() + return args +# yapf: enable + + +def calc_bleu(preds, targets): + assert len(preds) == len(targets), ( + 'The length of pred_responses should be equal to the length of ' + 'target_responses. But received {} and {}.'.format( + len(preds), len(targets))) + bleu4 = BLEU(n_size=4) + tokenizer = BasicTokenizer() + + for pred, target in zip(preds, targets): + pred_tokens = tokenizer.tokenize(pred) + target_token = tokenizer.tokenize(target) + + bleu4.add_inst(pred_tokens, [target_token]) + + print('\n' + '*' * 15) + print('The auto evaluation result is:') + print('BLEU-4:', bleu4.score()) + + +def save_ckpt(model, tokenizer, save_dir, name): + output_dir = os.path.join(save_dir, "model_{}".format(name)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance(model, + paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + +def run(args): + paddle.set_device(args.device) + world_size = dist.get_world_size() + + if world_size > 1: + dist.init_parallel_env() + set_seed(args.seed) + + model = UNIMOLMHeadModel.from_pretrained(args.model_name_or_path) + tokenizer = UNIMOTokenizer.from_pretrained(args.model_name_or_path) + + if world_size > 1: + model = paddle.DataParallel(model) + + train_ds, dev_ds = load_dataset(args.dataset_name, splits=['train', 'dev']) + + train_ds, train_data_loader = create_data_loader(train_ds, tokenizer, args, + 'train') + dev_ds, dev_data_loader = create_data_loader(dev_ds, tokenizer, args, + 'test') + + if args.do_train: + num_training_steps = args.epochs * len(train_data_loader) + + lr_scheduler = LinearDecayWithWarmup( + args.learning_rate, num_training_steps, args.warmup_propotion) + # Generate parameter names needed to perform weight decay. + # All bias and LayerNorm parameters are excluded. + + decay_params = [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + + optimizer = AdamW( + learning_rate=lr_scheduler, + parameters=model.parameters(), + weight_decay=args.weight_decay, + beta1=args.beta1, + beta2=args.beta2, + epsilon=args.epsilon, + apply_decay_param_fun=lambda x: x in decay_params) + + step = 0 + total_time = 0.0 + for epoch in range(args.epochs): + print('\nEpoch %d/%d' % (epoch + 1, args.epochs)) + batch_start_time = time.time() + for inputs in train_data_loader: + step += 1 + labels = inputs[-1] + logits = model(*inputs[:-1]) + labels = paddle.nn.functional.one_hot( + labels, num_classes=logits.shape[-1]) + labels = paddle.nn.functional.label_smooth(labels) + loss = F.cross_entropy(logits, labels, soft_label=True) + + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + + total_time += (time.time() - batch_start_time) + if step % args.logging_steps == 0: + ppl = paddle.exp(loss) + print( + 'step %d - loss: %.4f - ppl: %.4f - lr: %.7f - %.3fs/step' + % (step, loss, ppl, optimizer.get_lr(), + total_time / args.logging_steps)) + total_time = 0.0 + + if step % args.save_steps == 0 or step >= num_training_steps: + if dist.get_rank() == 0: + save_ckpt(model, tokenizer, args.save_dir, step) + print('Saved step {} model.\n'.format(step)) + if args.do_predict: + evaluation(model, dev_data_loader, args, tokenizer) + + batch_start_time = time.time() + + print('\nTraining completed.') + elif args.do_predict: + evaluation(model, dev_data_loader, args, tokenizer) + + +@paddle.no_grad() +def evaluation(model, data_loader, args, tokenizer): + print('\nEval begin...') + model.eval() + pred_ref = [] + total_time = 0.0 + start_time = time.time() + for step, inputs in enumerate(data_loader, 1): + input_ids, token_type_ids, position_ids, attention_mask = inputs + ids, scores = model.generate( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + max_length=args.max_dec_len, + min_length=args.min_dec_len, + decode_strategy=args.decode_strategy, + num_beams=args.num_beams, + length_penalty=args.length_penalty, + num_return_sequences=args.num_return_sequences, + bos_token_id=tokenizer.cls_token_id, + eos_token_id=tokenizer.mask_token_id) + + total_time += (time.time() - start_time) + if step % 100 == 0: + print('step %d - %.3fs/step' % (step, total_time / 100)) + total_time = 0.0 + + results = select_sum(ids, scores, tokenizer, args.max_dec_len, + args.num_return_sequences) + pred_ref.extend(results) + start_time = time.time() + + with open(args.output_path, 'w', encoding='utf-8') as fout: + for ref in pred_ref: + fout.write(ref + '\n') + + print('\nSave inference result into: %s' % args.output_path) + + if 'target' in data_loader.dataset[0].keys(): + targets = [example['target'] for example in data_loader.dataset] + calc_bleu(pred_ref, targets) + + model.train() + return + + +if __name__ == '__main__': + args = parse_args() + print_args(args) + run(args) diff --git a/paddlenlp/datasets/advertisegen.py b/paddlenlp/datasets/advertisegen.py new file mode 100644 index 000000000000..ace9035936de --- /dev/null +++ b/paddlenlp/datasets/advertisegen.py @@ -0,0 +1,68 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import os + +from paddle.dataset.common import md5file +from paddle.utils.download import get_path_from_url +from paddlenlp.utils.env import DATA_HOME +from . import DatasetBuilder + +__all__ = ['AdvertiseGen'] + + +class AdvertiseGen(DatasetBuilder): + ''' + This dataset contains 119K pairs of product specifications and the + corresponding advertising text. For more information, please refer + to `https://arxiv.org/abs/1908.06605v2`. + ''' + + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5', 'URL')) + SPLITS = { + 'train': META_INFO( + os.path.join('train.json'), 'c0cc79f912099faa6175d28d3ddafafe', + 'https://paddlenlp.bj.bcebos.com/datasets/AdvertiseGen/train.json'), + 'dev': META_INFO( + os.path.join('dev.json'), '5fda84828628a9722da5436485601df3', + 'https://paddlenlp.bj.bcebos.com/datasets/AdvertiseGen/dev.json') + } + + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash, URL = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or (data_hash and + not md5file(fullname) == data_hash): + get_path_from_url(URL, default_root) + + return fullname + + def _read(self, filename, *args): + with open(filename, "r", encoding="utf8") as f: + data_id = 0 + for line in f: + line = line.strip() + if not line: + continue + json_data = json.loads(line) + + yield { + 'source': json_data["content"], + 'target': json_data["summary"], + 'id': data_id + } + data_id += 1 diff --git a/paddlenlp/datasets/dureader_qg.py b/paddlenlp/datasets/dureader_qg.py new file mode 100644 index 000000000000..737c1971deaa --- /dev/null +++ b/paddlenlp/datasets/dureader_qg.py @@ -0,0 +1,68 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import os + +from paddle.dataset.common import md5file +from paddle.utils.download import get_path_from_url +from paddlenlp.utils.env import DATA_HOME +from . import DatasetBuilder + +__all__ = ['DuReaderQG'] + + +class DuReaderQG(DatasetBuilder): + ''' + This dataset is made form the machine reading comprehension dataset + (i.e. DuReader robust) for question generation task. + ''' + + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5', 'URL')) + SPLITS = { + 'train': META_INFO( + os.path.join('train.json'), 'a6d96bda4662e657ce644ed0e178fe70', + 'https://paddlenlp.bj.bcebos.com/datasets/DuReaderQG/train.json'), + 'dev': META_INFO( + os.path.join('dev.json'), 'a6bd22b0da0ed8e20784398f507d4acc', + 'https://paddlenlp.bj.bcebos.com/datasets/DuReaderQG/dev.json') + } + + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash, URL = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or (data_hash and + not md5file(fullname) == data_hash): + get_path_from_url(URL, default_root) + + return fullname + + def _read(self, filename, *args): + with open(filename, "r", encoding="utf8") as f: + for line in f: + line = line.strip() + if not line: + continue + + json_data = json.loads(line) + title = json_data.get('answer', None) + + yield { + 'source': json_data["context"], + 'target': json_data["question"], + 'title': title, + 'id': json_data['id'] + } diff --git a/paddlenlp/datasets/lcsts_new.py b/paddlenlp/datasets/lcsts_new.py new file mode 100644 index 000000000000..6e34b6e43627 --- /dev/null +++ b/paddlenlp/datasets/lcsts_new.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import os + +from paddle.dataset.common import md5file +from paddle.utils.download import get_path_from_url +from paddlenlp.utils.env import DATA_HOME +from . import DatasetBuilder + +__all__ = ['LCSTSNew'] + + +class LCSTSNew(DatasetBuilder): + ''' + Large-scale Chinese Short Text Summarization(LCSTS) dataset is + constructed by utilizing the naturally annotated web resources + on Sina Weibo. For more information, please refer + to `https://aclanthology.org/D15-1229.pdf`. + ''' + + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5', 'URL')) + SPLITS = { + 'train': META_INFO( + os.path.join('train.json'), '4e06fd1cfd5e7f0380499df8cbe17237', + 'https://paddlenlp.bj.bcebos.com/datasets/LCSTS_new/train.json'), + 'dev': META_INFO( + os.path.join('dev.json'), '9c39d49d25d5296bdc537409208ddc85', + 'https://paddlenlp.bj.bcebos.com/datasets/LCSTS_new/dev.json') + } + + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash, URL = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or (data_hash and + not md5file(fullname) == data_hash): + get_path_from_url(URL, default_root) + + return fullname + + def _read(self, filename, *args): + with open(filename, "r", encoding="utf8") as f: + for line in f: + line = line.strip() + if not line: + continue + json_data = json.loads(line) + + yield { + 'source': json_data["content"], + 'target': json_data["summary"], + 'id': json_data['id'] + } diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index 07cd3a3621c3..3da207aba5d4 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -55,3 +55,5 @@ from .bart.tokenizer import * from .roformer.modeling import * from .roformer.tokenizer import * +from .unimo.modeling import * +from .unimo.tokenizer import * \ No newline at end of file diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index 0d2e7a693667..200e4592e09f 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -42,11 +42,12 @@ def __len__(self): """ return len(self.beams) - def add(self, hyp, sum_logprobs): + def add(self, hyp, sum_logprobs, origin_len=0): """ Add a new hypothesis to the list. """ - score = sum_logprobs / (hyp.shape[-1]**self.length_penalty) + score = sum_logprobs / (((hyp.shape[-1] - origin_len + 5) / 6) + **self.length_penalty) if len(self) < self.num_beams or score > self.worst_score: self.beams.append((score, hyp)) if len(self) > self.num_beams: @@ -57,7 +58,7 @@ def add(self, hyp, sum_logprobs): else: self.worst_score = min(score, self.worst_score) - def is_done(self, best_sum_logprobs, cur_len): + def is_done(self, best_sum_logprobs, cur_len, origin_len=0): """ If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst one in the heap, then we @@ -68,7 +69,8 @@ def is_done(self, best_sum_logprobs, cur_len): elif self.early_stopping: return True else: - cur_score = best_sum_logprobs / cur_len**self.length_penalty + cur_score = best_sum_logprobs / ( + (cur_len - origin_len + 5) / 6)**self.length_penalty ret = self.worst_score >= cur_score return ret @@ -129,6 +131,7 @@ def process(self, next_scores, next_tokens, next_indices, + origin_len=0, pad_token_id=None, eos_token_id=None): cur_len = input_ids.shape[-1] @@ -175,7 +178,8 @@ def process(self, continue beam_hyp.add( input_ids[batch_beam_idx.numpy().item()].clone(), - next_score.numpy().item()) + next_score.numpy().item(), origin_len) + else: # add next predicted token since it is not eos_token next_beam_scores[batch_idx, beam_idx] = next_score @@ -197,7 +201,7 @@ def process(self, # Check if we are done so that we can save a pad step if all(done) if beam_hyp.is_done(next_scores[batch_idx].max().numpy().item(), - cur_len): + cur_len, origin_len): self._done[batch_idx] = 1 return { @@ -354,10 +358,7 @@ def update_model_kwargs_for_generation(outputs, model_kwargs): if "position_ids" in model_kwargs: position_ids = model_kwargs["position_ids"] model_kwargs["position_ids"] = paddle.concat( - [ - position_ids, - paddle.max(position_ids, axis=-1, keepdim=True) + 1 - ], + [position_ids, position_ids[:, -1].reshape((-1, 1)) + 1], axis=-1) # update attention_mask @@ -412,7 +413,7 @@ def generate(self, top_k=0, top_p=1.0, num_beams=1, - length_penalty=1.0, + length_penalty=0.0, early_stopping=False, bos_token_id=None, eos_token_id=None, @@ -452,11 +453,9 @@ def generate(self, num_beams (int, optional): The number of beams in the "beam_search" strategy. Default to 1. length_penalty (float, optional): The exponential penalty to the - sequence length in the "beam_search" strategy. If - :math:`length\_penalty < 1.0`, the model will generate shorter - sequences. If :math:`length\_penalty > 1.0`, the model will - generate longer sequences. Default to 1.0, which means no - penalty. + sequence length in the "beam_search" strategy. The larger this + param is, the more that the model would generate shorter + sequences. Default to 0.0, which means no penalty. early_stopping (bool, optional): Whether to stop searching in the "beam_search" strategy when at least `num_beams` sentences are finished per batch or not. Default to False. @@ -597,7 +596,7 @@ def generate(self, model_kwargs["use_cache"] = use_cache max_length += input_ids.shape[-1] - + min_length += input_ids.shape[-1] logits_processors = self.get_logits_processor(min_length, eos_token_id) if decode_strategy == 'greedy_search': @@ -672,7 +671,6 @@ def greedy_search(self, input_ids, logits_processors, max_length, # pre-process distribution logits = self.adjust_logits_during_generation(logits) logits = logits_processors(input_ids, logits) - # greedy probs = F.softmax(logits) probs = paddle.log(probs) @@ -857,6 +855,7 @@ def beam_search(self, input_ids, beam_scorer, logits_processors, max_length, next_scores, next_tokens, next_indices, + origin_len=origin_len, pad_token_id=pad_token_id, eos_token_id=eos_token_id, ) beam_scores = beam_outputs["next_beam_scores"] diff --git a/paddlenlp/transformers/unimo/__init__.py b/paddlenlp/transformers/unimo/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/paddlenlp/transformers/unimo/modeling.py b/paddlenlp/transformers/unimo/modeling.py new file mode 100644 index 000000000000..8a18d630a136 --- /dev/null +++ b/paddlenlp/transformers/unimo/modeling.py @@ -0,0 +1,503 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Modeling classes for UNIMO model.""" + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn import TransformerEncoder + +from .. import PretrainedModel, register_base_model + +__all__ = [ + "UNIMOPretrainedModel", + 'UNIMOModel', + 'UNIMOLMHeadModel', +] + + +class UNIMOPretrainedModel(PretrainedModel): + """ + An abstract class for pretrained UNIMO models. It provides + UNIMO related `model_config_file`, `resource_files_names`, + `pretrained_resource_files_map`, `pretrained_init_configuration`, + `base_model_prefix` for downloading and loading pretrained models. + + Refer to :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for + more details. + """ + + model_config_file = "model_config.json" + pretrained_init_configuration = { + "unimo-text-1.0": { + "vocab_size": 18000, + "hidden_size": 768, + "num_hidden_layers": 12, + "num_attention_heads": 12, + "intermediate_size": 3072, + "hidden_act": "relu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "normalize_before": False, + "max_position_embeddings": 513, + "type_vocab_size": 4, + "initializer_range": 0.02, + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 3, + "mask_token_id": 3, + }, + "unimo-text-1.0-large": { + "vocab_size": 12800, + "hidden_size": 1024, + "num_hidden_layers": 24, + "num_attention_heads": 16, + "intermediate_size": 4096, + "hidden_act": "relu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "normalize_before": False, + "max_position_embeddings": 512, + "type_vocab_size": 4, + "initializer_range": 0.02, + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 3, + "mask_token_id": 3, + }, + } + resource_files_names = {"model_state": "model_state.pdparams"} + pretrained_resource_files_map = { + "model_state": { + "unimo-text-1.0": + "https://paddlenlp.bj.bcebos.com/models/transformers/unimo/unimo-text-1.0.pdparams", + "unimo-text-1.0-large": + "https://paddlenlp.bj.bcebos.com/models/transformers/unimo/unimo-text-1.0-large.pdparams", + } + } + base_model_prefix = "unimo" + + def init_weights(self, layer): + # Initialization hook + if isinstance(layer, (nn.Linear, nn.Embedding)): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.initializer_range + if hasattr(self, "initializer_range") else + self.unimo.config["initializer_range"], + shape=layer.weight.shape)) + + +class UNIMOEmbeddings(nn.Layer): + #Include embeddings from word, position and token_type. + + def __init__(self, + vocab_size, + hidden_size=768, + hidden_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=4): + super(UNIMOEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(max_position_embeddings, + hidden_size) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + + def forward(self, input_ids, token_type_ids, position_ids): + input_embedings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = input_embedings + position_embeddings + token_type_embeddings + return embeddings + + +@register_base_model +class UNIMOModel(UNIMOPretrainedModel): + """ + The bare UNIMO Model outputting raw hidden-states without any + specific head on top. + + This model inherits from + :class:`~paddlenlp.transformers.model_utils.PretrainedModel`. Refer to the + superclass documentation for the generic methods. + + This model is also a `paddle.nn.Layer `__ + subclass. Use it as a regular Paddle Layer and refer to the Paddle + documentation for all matter related to general usage and behavior. + + Args: + vocab_size (int): + Vocabulary size of `inputs_ids` in :class:`UNIMOModel`. + Also is the vocab size of token embedding matrix. + hidden_size (int, optional): + Dimensionality of the embedding layers, encoder layers and pooler + layer. Defaults to 768. + num_hidden_layers (int, optional): + The number of hidden layers in the encoder. Defaults to 12. + num_attention_heads (int, optional): + The number of heads in multi-head attention(MHA). Defaults to 12. + intermediate_size (int, optional): + Dimensionality of the feed-forward layer in the encoder. Input + tensors to feed-forward layers are firstly projected from + `hidden_size` to `intermediate_size`, and then projected back to + `hidden_size`. Typically `intermediate_size` is larger than + `hidden_size`. Defaults to 3072. + hidden_act (str, optional): + The activation function in the feedforward network. Defaults to + "gelu". + hidden_dropout_prob(float, optional): + The dropout probability used in pre-process and post-precess of MHA + and FFN sub-layer. Defaults to 0.1. + attention_probs_dropout_prob (float, optional): + The dropout probability used in MHA to drop some attention target. + Defaults to 0.1. + normalize_before (bool, optional): + Indicate whether to put layer normalization into preprocessing of + MHA and FFN sub-layers. If True, pre-process is layer ormalization + and post-precess includes dropout, residual connection. Otherwise, + no pre-process and post-precess includes dropout, residual + connection, layer normalization. Defaults to True. + max_position_embeddings (int, optional): + The maximum length of input `position_ids`. Defaults to 512. + type_vocab_size (int, optional): + The size of the input `token_type_ids`. Defaults to 2. + initializer_range (float, optional): + The standard deviation of the normal initializer. Defaults to 0.02. + + .. note:: + A normal_initializer initializes weight matrices as normal + distributions. See + :meth:`UNIMOPretrainedModel.init_weights` method + for how weights are initialized in + :class:`UNIMOModel`. + unk_token_id (int, optional): + The id of special token `unk_token`. Defaults to 0. + pad_token_id (int, optional): + The id of special token `pad_token`. Defaults to 0. + bos_token_id (int, optional): + The id of special token `bos_token`. Defaults to 1. + eos_token_id (int, optional): + The id of special token `eos_token`. Defaults to 2. + mask_token_id (int, optional): + The id of special token `mask_token`. Defaults to 30000. + """ + + def __init__( + self, + vocab_size, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='relu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + normalize_before=False, + max_position_embeddings=513, + type_vocab_size=4, + initializer_range=0.02, + unk_token_id=0, + pad_token_id=0, + bos_token_id=1, + eos_token_id=3, + mask_token_id=3, ): + super(UNIMOModel, self).__init__() + self.unk_token_id = unk_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.mask_token_id = mask_token_id + self.initializer_range = initializer_range + + self.embeddings = UNIMOEmbeddings( + vocab_size, hidden_size, hidden_dropout_prob, + max_position_embeddings, type_vocab_size) + encoder_layer = nn.TransformerEncoderLayer( + hidden_size, + num_attention_heads, + intermediate_size, + dropout=hidden_dropout_prob, + activation=hidden_act, + attn_dropout=attention_probs_dropout_prob, + act_dropout=0, + normalize_before=normalize_before) + + self.encoder_norm = nn.LayerNorm(hidden_size) + self.dropout = nn.Dropout(hidden_dropout_prob) + self.encoder = nn.TransformerEncoder( + encoder_layer, + num_hidden_layers, ) + + self.apply(self.init_weights) + + def forward(self, + input_ids, + token_type_ids, + position_ids, + attention_mask, + use_cache=False, + cache=None): + r""" + The UNIMOModel forward method, overrides the special + :meth:`__call__` method. + + Args: + input_ids (Tensor): + Indices of input sequence tokens in the vocabulary. They are + numerical representations of tokens that build the input + sequence. It's data type should be `int64` and has a shape of + [batch_size, sequence_length]. + token_type_ids (Tensor): + Segment token indices to indicate first and second portions of + the inputs. Indices can be either 0 or 1: + + - 0 corresponds to a **sentence A** token, + - 1 corresponds to a **sentence B** token. + + It's data type should be `int64` and has a shape of + [batch_size, sequence_length]. + position_ids (Tensor): + The position indices of input sequence tokens. It's data type + should be `int64` and has a shape of [batch_size, sequence_length]. + attention_mask (Tensor): + A tensor used in multi-head attention to prevents attention to + some unwanted positions, usually the paddings or the subsequent + positions. It is a tensor with shape broadcasted to + [batch_size, n_head, sequence_length, sequence_length]. + + - When the data type is bool, the unwanted positions have + `False` values and the others have `True` values. + - When the data type is int, the unwanted positions have 0 + values and the others have 1 values. + - When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. + + use_cache: (bool, optional): + Whether or not use the model cache to speed up decoding. Defaults + to False. + cache (list, optional): + It is a list, and each element in the list is `incremental_cache` + produced by :meth:`paddle.nn.TransformerEncoderLayer.gen_cache` + method. See :meth:`paddle.nn.TransformerEncoder.gen_cache` + method for more details. It is only used for inference and + should be None for training. Defaults to None. + + Returns: + Tensor|tuple: If `use_cache` is False, it is a tensor + representing the output of :class:`UNIMOModel`, with + shape [batch_size, sequence_length, hidden_size]. The data type is + float32 or float64. Otherwise, it is a tuple, besides the output of + :class:`UNIMOModel`, the tuple also includes the new + cache which is same as input `cache` but `incremental_cache` in it + has an incremental length. + See :meth:`paddle.nn.MultiHeadAttention.gen_cache` method and + :meth:`paddle.nn.MultiHeadAttention.forward` method for more details. + + Example: + .. code-block:: + + from paddlenlp.transformers import UNIMOModel + from paddlenlp.transformers import UNIMOTokenizer + + model = UNIMOModel.from_pretrained('plato-mini') + tokenizer = UNIMOTokenizer.from_pretrained('plato-mini') + + source = '我爱祖国' + inputs = tokenizer.gen_encode( + source, + return_tensors=True, + is_split_into_words=False) + outputs = model(**inputs) + """ + + embedding_output = self.embeddings(input_ids, token_type_ids, + position_ids) + + embedding_output = self.encoder_norm(embedding_output) + embedding_output = self.dropout(embedding_output) + + if use_cache: + if cache is None: + cache = self.encoder.gen_cache(embedding_output) + sequence_output, cache = self.encoder(embedding_output, + attention_mask, cache) + return sequence_output, cache + else: + sequence_output = self.encoder(embedding_output, attention_mask) + return sequence_output + + +class UNIMOLMHead(nn.Layer): + def __init__(self, + hidden_size, + vocab_size, + activation, + embedding_weights=None): + super(UNIMOLMHead, self).__init__() + self.transform = nn.Linear(hidden_size, hidden_size) + self.activation = getattr(nn.functional, activation) + self.layer_norm = nn.LayerNorm(hidden_size) + self.decoder_weight = self.create_parameter( + shape=[vocab_size, hidden_size], + dtype=self.transform.weight.dtype, + is_bias=False) if embedding_weights is None else embedding_weights + self.decoder_bias = self.create_parameter( + shape=[vocab_size], dtype=self.decoder_weight.dtype, is_bias=True) + + def forward(self, hidden_states, masked_positions=None): + if masked_positions is not None: + hidden_states = paddle.reshape(hidden_states, + [-1, hidden_states.shape[-1]]) + hidden_states = paddle.tensor.gather(hidden_states, + masked_positions) + hidden_states = self.transform(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.layer_norm(hidden_states) + logits = paddle.tensor.matmul( + hidden_states, self.decoder_weight, + transpose_y=True) + self.decoder_bias + return logits + + +class UNIMOLMHeadModel(UNIMOPretrainedModel): + """ + The UNIMO Model with a language modeling head on top (linear + layer with weights tied to the input embeddings) for generation tasks. + + Args: + unimo (:class:`UNIMOModel`): + An instance of :class:`UNIMOModel`. + """ + + def __init__(self, unimo): + super(UNIMOLMHeadModel, self).__init__() + self.unimo = unimo + self.lm_head = UNIMOLMHead(self.unimo.config["hidden_size"], + self.unimo.config["vocab_size"], + self.unimo.config["hidden_act"], + self.unimo.embeddings.word_embeddings.weight) + self.apply(self.init_weights) + + def forward(self, + input_ids, + token_type_ids, + position_ids, + attention_mask, + masked_positions=None, + use_cache=False, + cache=None): + r""" + The UNIMOLMHeadModel forward method, overrides the special + :meth:`__call__` method. + + Args: + input_ids (Tensor): + See :class:`UNIMOModel`. + token_type_ids (Tensor): + See :class:`UNIMOModel`. + position_ids (Tensor): + See :class:`UNIMOModel`. + attention_mask (Tensor): + See :class:`UNIMOModel`. + use_cache: (bool, optional): + See :class:`UNIMOModel`. + cache (list, optional): + See :class:`UNIMOModel`. + + Returns: + Tensor|tuple: If `use_cache` is False, it is a tensor + representing the output of :class:`UNIMOLMHeadModel`, + with shape [batch_size, sequence_length, vocab_size]. The data type + is float32 or float64. Otherwise, it is a tuple, besides the output + of :class:`UNIMOLMHeadModel`, the tuple also includes + the new cache which is same as input `cache` but `incremental_cache` + in it has an incremental length. + See :meth:`paddle.nn.MultiHeadAttention.gen_cache` method and + :meth:`paddle.nn.MultiHeadAttention.forward` method for more details. + + Example: + .. code-block:: + + from paddlenlp.transformers import UNIMOLMHeadModel + from paddlenlp.transformers import UNIMOTokenizer + + model = UNIMOLMHeadModel.from_pretrained('unimo-text-1.0') + tokenizer = UNIMOTokenizer.from_pretrained('unimo-text-1.0') + + source = '我爱祖国' + inputs = tokenizer.gen_encode( + source, + return_tensors=True, + is_split_into_words=False) + logits = model(**inputs) + """ + + outputs = self.unimo(input_ids, token_type_ids, position_ids, + attention_mask, use_cache, cache) + sequence_output = outputs[0] if use_cache else outputs + #print('sequence_output:',sequence_output) + logits = self.lm_head(sequence_output, masked_positions) + if use_cache: + cache = outputs[1] + return logits, cache + else: + return logits + + def adjust_logits_during_generation(self, logits): + # pre-process distribution + logits[:, self.unimo.unk_token_id] = -1e9 + logits[:, self.unimo.bos_token_id] = -1e9 + return logits + + def prepare_inputs_for_generation(self, + input_ids, + token_type_ids, + position_ids, + attention_mask, + use_cache=False, + cache=None, + **kwargs): + # only last token for inputs_ids if cache is defined in kwargs + if cache is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + attention_mask = attention_mask[:, :, -1, :].unsqueeze(2) + + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "use_cache": use_cache, + "cache": cache + } + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError as e: + try: + return getattr(getattr(self, self.base_model_prefix), name) + except AttributeError: + try: + return getattr(self, self.base_model_prefix).config[name] + except KeyError: + raise e diff --git a/paddlenlp/transformers/unimo/tokenizer.py b/paddlenlp/transformers/unimo/tokenizer.py new file mode 100644 index 000000000000..89d5a284932b --- /dev/null +++ b/paddlenlp/transformers/unimo/tokenizer.py @@ -0,0 +1,549 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +import six +import shutil +import paddle +from paddle.utils import try_import +from paddlenlp.utils.env import MODEL_HOME +import numpy as np +from ...data.vocab import Vocab + +from .. import BasicTokenizer, PretrainedTokenizer, WordpieceTokenizer + +__all__ = ['UNIMOTokenizer'] + + +class UNIMOTokenizer(PretrainedTokenizer): + r""" + Constructs an ERNIE tokenizer. It uses a basic tokenizer to do punctuation + splitting, lower casing and so on, and follows a WordPiece tokenizer to + tokenize as subwords. + + Args: + vocab_file (str): + file path of the vocabulary. + do_lower_case (str, optional): + Whether the text strips accents and convert to lower case. + Defaults to `True`. + unk_token (str, optional): + The special token for unknown words. + Defaults to "[UNK]". + sep_token (str, optional): + The special token for separator token. + Defaults to "[SEP]". + pad_token (str, optional): + The special token for padding. + Defaults to "[PAD]". + cls_token (str, optional): + The special token for cls. + Defaults to "[CLS]". + mask_token (str, optional): + The special token for mask. + Defaults to "[MASK]". + + Examples: + .. code-block:: python + from paddlenlp.transformers import UNIMOTokenizer + tokenizer = UNIMOTokenizer.from_pretrained('unimo-text-1.0') + encoded_inputs = tokenizer('这是一个测试样例') + # encoded_inputs: + # { + # 'input_ids': [1, 47, 10, 7, 27, 558, 525, 314, 656, 2], + # 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # } + + + """ + resource_files_names = {"vocab_file": "vocab.txt"} # for save_pretrained + pretrained_resource_files_map = { + "vocab_file": { + "unimo-text-1.0": + "https://paddlenlp.bj.bcebos.com/models/transformers/unimo/unimo-text-1.0-vocab.txt", + "unimo-text-1.0-large": + "https://paddlenlp.bj.bcebos.com/models/transformers/unimo/unimo-text-1.0-vocab-large.txt", + } + } + pretrained_init_configuration = { + "unimo-text-1.0": { + "do_lower_case": True + }, + "unimo-text-1.0-large": { + "do_lower_case": True + }, + } + + def __init__(self, + vocab_file, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]"): + + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the " + "vocabulary from a pretrained model please use " + "`tokenizer = UNIMOTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + .format(vocab_file)) + self.vocab = self.load_vocabulary(vocab_file, unk_token=unk_token) + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer( + vocab=self.vocab, unk_token=unk_token) + + @property + def vocab_size(self): + r""" + return the size of vocabulary. + + Returns: + int: the size of vocabulary. + """ + return len(self.vocab) + + @staticmethod + def load_vocabulary(filepath, + unk_token=None, + pad_token=None, + bos_token=None, + eos_token=None, + **kwargs): + token_to_idx = {} + with open(filepath, 'r', encoding='utf-8') as f: + for line in f: + token, index = line.rstrip('\n').split('\t') + token_to_idx[token] = int(index) + vocab = Vocab.from_dict( + token_to_idx, + unk_token=unk_token, + pad_token=pad_token, + bos_token=bos_token, + eos_token=eos_token, + **kwargs) + return vocab + + def _tokenize(self, text): + r""" + End-to-end tokenization for ERNIE models. + + Args: + text (str): The text to be tokenized. + + Returns: + List[str]: A list of string representing converted tokens. + """ + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + return split_tokens + + def tokenize(self, text): + r""" + End-to-end tokenization for ERNIE models. + + Args: + text (str): The text to be tokenized. + + Returns: + List[str]: A list of string representing converted tokens. + """ + return self._tokenize(text) + + def convert_tokens_to_string(self, tokens): + r""" + Converts a sequence of tokens (list of string) in a single string. Since + the usage of WordPiece introducing `##` to concat subwords, also remove + `##` when converting. + + Args: + tokens (List[str]): A list of string representing tokens to be converted. + + Returns: + str: Converted string from tokens. + """ + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def num_special_tokens_to_add(self, pair=False): + r""" + Returns the number of added tokens when encoding a sequence with special tokens. + + Note: + This encodes inputs and checks the number of added tokens, and is therefore not efficient. + Do not put this inside your training loop. + + Args: + pair (str, optional): Returns the number of added tokens in the case of a sequence + pair if set to True, returns the number of added tokens in the case of a single sequence + if set to False. Defaults to False. + + Returns: + `int`: Number of tokens added to sequences + """ + token_ids_0 = [] + token_ids_1 = [] + return len( + self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 + if pair else None)) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + r""" + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. + + A ERNIE sequence has the following format: + :: + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + + Args: + token_ids_0 (List[int]): + List of IDs to which the special tokens will be added. + token_ids_1 (List[int], optional): + Optional second list of IDs for sequence pairs. + Defaults to `None`. + + Returns: + List[int]: List of input_id with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + _cls = [self.cls_token_id] + _sep = [self.sep_token_id] + return _cls + token_ids_0 + _sep + token_ids_1 + _sep + + def merge_subword(self, tokens): + """merge subwords""" + ret = [] + for token in tokens: + if token.startswith("##"): + real_token = token[2:] + if len(ret): + ret[-1] += real_token + else: + ret.append(real_token) + else: + ret.append(token) + + return ret + + def build_offset_mapping_with_special_tokens(self, + offset_mapping_0, + offset_mapping_1=None): + r""" + Build offset map from a pair of offset map by concatenating and adding offsets of special tokens. + + A ERNIE offset_mapping has the following format: + :: + - single sequence: ``(0,0) X (0,0)`` + - pair of sequences: `(0,0) A (0,0) B (0,0)`` + + Args: + offset_mapping_ids_0 (List[tuple]): + List of char offsets to which the special tokens will be added. + offset_mapping_ids_1 (List[tuple], optional): + Optional second list of char offsets for offset mapping pairs. + Defaults to `None`. + + Returns: + List[tuple]: List of char offsets with the appropriate offsets of special tokens. + """ + if offset_mapping_1 is None: + return [(0, 0)] + offset_mapping_0 + [(0, 0)] + + return [(0, 0)] + offset_mapping_0 + [(0, 0) + ] + offset_mapping_1 + [(0, 0)] + + def create_token_type_ids_from_sequences(self, + token_ids_0, + token_ids_1=None): + r""" + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + + A ERNIE sequence pair mask has the following format: + :: + + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (List[int]): + List of IDs. + token_ids_1 (List[int], optional): + Optional second list of IDs for sequence pairs. + Defaults to `None`. + + Returns: + List[int]: List of token_type_id according to the given sequence(s). + """ + _sep = [self.sep_token_id] + _cls = [self.cls_token_id] + if token_ids_1 is None: + return len(_cls + token_ids_0 + _sep) * [0] + return len(_cls + token_ids_0 + _sep) * [0] + len(token_ids_1 + + _sep) * [1] + + def gen_encode(self, + source, + title=None, + target=None, + max_seq_len=512, + max_title_len=128, + max_target_len=128, + return_position_ids=True, + return_token_type_ids=True, + return_attention_mask=True, + return_length=False, + add_start_token_for_decoding=False, + pad_to_max_seq_len=False, + return_tensors=False, + is_split_into_words=False, + continuous_position=False): + """ + Main method to encode the single-turn or multi-turn dialogue conversation. + It will return a dictionary containing the encoded sequence and other + relative informations which meets the input format requirements of the + UnifiedTransformer model. + See detail at + https://github.com/PaddlePaddle/Knover/tree/luge-dialogue/luge-dialogue + + Args: + source (str): The source of dialogue conversation. It + is an utterance or list of utterances to be encoded. Each + utterance is a string. + target (str, optional): The target of dialogue conversation. + It should be set when training the model. It should not be set + when running inference. Defaults to None. + title (str, optional): The title information of dialogue + conversation. It should be set if the `task_type` is "title" + or "recommend". Defaults to None. + task_type (str, optional): The type of dialogue conversation. It is + one of "chitchat", "title" and "recommend". They represent + the chitchat dialogue, title grounded dialogue and + conversational recommendation respectively. Defaults to None, + which means there is no `special_token` added in output sequence + for identifying different conversation types. + max_seq_len (int, optional): The maximum encoded sequence length. + Defaults to 512. + max_target_len (int, optional): The maximum encoded sequence + length of the input `target`. Defaults to 128. + max_title_len (int, optional): The maximum encoded sequence + length of the input `title`. Defaults to 128. + return_position_ids (bool, optional): Whether to return the + position_ids. Defaults to True. + return_token_type_ids (bool, optional): Whether to return the + token_type_ids. Defaults to True. + return_attention_mask (bool, optional): Whether to return the + attention_mask. Defaults to True. + return_length (bool, optional): Whether to return the length of the + encoded sequence. Defaults to False. + add_start_token_for_decoding (bool, optional): Whether to add the + special token "[CLS]" at the end of sequence as the begining of + the target when running inference to force the model to start + generating target sequence. Defaults to False. + pad_to_max_seq_len (bool, optional): Whether to pad the returned + sequences to the `max_seq_len`. Note that, in this method, + returned sequences will be padded on the left. Defaults to False. + return_tensors (bool, optional): Whether to convert the returned + sequences to Tensor. Defaults to False. + is_split_into_words(bool, optinal): Whether or not the input text + (`source`, `target` and `title`) has been pretokenized. + Defaults to True. + + Returns: + dict: A dictionary containing the encoded sequence and other + relative informations. + + With the corresponding fields: + + - input_ids (list[int]|Tensor): + A list of indices of input tokens to be feed to UnifiedTransformer + model. If `return_tensors` is True, it is a Tensor with shape + [1, sequence_length] and data type 'int64'. + - token_type_ids (list[int]|Tensor, optional): + A list of segment token indices to indicate whether the token + belongs to the dialogue target. If `return_tensors` is True, + it is a Tensor with shape [1, sequence_length] and data type + 'int64'. + Being returned when `return_token_type_ids` is set to True. + - position_ids (list[int]|Tensor, optional): + A list of The position indices. If `return_tensors` is True, + it is a Tensor with shape [1, sequence_length] and data type + 'int64'. + Being returned when `return_position_ids` is set to True. + - attention_mask (numpy.ndarray|Tensor, optional): + A numpy.ndarray to prevents attention to some unwanted positions, + with shape [sequence_length, sequence_length] and data type + 'float32'. If `return_tensors` is True, it is a Tensor with shape + [1, 1, sequence_length, sequence_length] and data type 'float32'. + Being returned when `return_attention_mask` is set to True. + - seq_len (int, optional): + The actual length of the `input_ids`, excluding the pad token. + Being returned when `return_length` is set to True. + + Example: + .. code-block:: + + from paddlenlp.transformers import UnifiedTransformerTokenizer + + tokenizer = UnifiedTransformerTokenizer.from_pretrained('plato-mini') + + inputs = tokenizer.dialogue_encode('我爱祖国') + for key in inputs: + print(key + ':') + print(inputs[key]) + # input_ids: [1, 6, 25445, 26907, 25475, 2] + # token_type_ids: [0, 0, 0, 0, 0, 0] + # position_ids: [0, 1, 2, 3, 4, 5] + # attention_mask: [[0. 0. 0. 0. 0. 0.] + # [0. 0. 0. 0. 0. 0.] + # [0. 0. 0. 0. 0. 0.] + # [0. 0. 0. 0. 0. 0.] + # [0. 0. 0. 0. 0. 0.] + # [0. 0. 0. 0. 0. 0.]] + """ + + # Input type checking for clearer error + assert isinstance(source, str), ( + "The input `source` must be with type `str` (single context). " + " But received: {}".format(source)) + assert target is None or isinstance(target, str), ( + "The input `target` must of be with type `str`. But received: {}". + format(target)) + assert title is None or isinstance(title, str), ( + "The input `title` must of be with type `str`. But received: {}". + format(title)) + assert max_seq_len > max_title_len + max_target_len, ( + "`max_seq_len` must be greater than the sum of `max_target_len` " + "and `max_title_len`. But received `max_seq_len` is {}, " + "`max_target_len` is {}, `max_title_len` is {}.".format( + max_seq_len, max_title_len, max_target_len)) + assert target is None or not add_start_token_for_decoding, ( + "`add_start_token_for_decoding` only works when `target` is " + "`None`. But received `add_start_token_for_decoding`: `{}`, " + "`target`: {}.".format(add_start_token_for_decoding, target)) + + title_ids = [] + if title is not None: + tokens = self._tokenize(title) + title_ids = self.convert_tokens_to_ids(tokens) + if len(title_ids) > max_title_len - 1: + title_ids = title_ids[:max_title_len - 1] + title_ids += [self.sep_token_id] + + target_ids = [] + if target is not None: + tokens = self._tokenize(target) + target_ids = [self.cls_token_id] + self.convert_tokens_to_ids( + tokens) + if len(target_ids) > max_target_len - 1: + target_ids = target_ids[:max_target_len - 1] + target_ids += [self.mask_token_id] + elif add_start_token_for_decoding: + target_ids = [self.cls_token_id] + + title_ids = [self.cls_token_id] + title_ids + + max_source_len = max_seq_len - len(title_ids) - len(target_ids) + source_ids = [] + tokens = self._tokenize(source) + source_ids = self.convert_tokens_to_ids(tokens) + + if len(source_ids) > max_source_len - 1: + source_ids = source_ids[:max_source_len - 1] + + source_ids += [self.sep_token_id] + source_ids = title_ids + source_ids + # Build output dictionnary + + encoded_inputs = {} + encoded_inputs["input_ids"] = source_ids + target_ids + # Check lengths + sequence_length = len(encoded_inputs["input_ids"]) + assert sequence_length <= max_seq_len + + # Considering that the logits at the last time step in the API of + # generative task are taken to generate the next token. In order to + # avoid the last time step being a pad, so take padding on the left. + pad_length = max_seq_len - sequence_length if pad_to_max_seq_len else 0 + if pad_length > 0: + encoded_inputs["input_ids"] = [ + self.pad_token_id + ] * pad_length + encoded_inputs["input_ids"] + if return_tensors: + # Add dimention for batch_size + encoded_inputs["input_ids"] = paddle.to_tensor(encoded_inputs[ + "input_ids"]).unsqueeze(0) + + if return_token_type_ids: + encoded_inputs["token_type_ids"] = [0] * len( + source_ids) + [1] * len(target_ids) + if pad_length > 0: + encoded_inputs["token_type_ids"] = [ + self.pad_token_id + ] * pad_length + encoded_inputs["token_type_ids"] + if return_tensors: + # Add dimention for batch_size + encoded_inputs["token_type_ids"] = paddle.to_tensor( + encoded_inputs["token_type_ids"]).unsqueeze(0) + + if return_length: + encoded_inputs["seq_len"] = sequence_length + + if return_position_ids: + if continuous_position: + encoded_inputs["position_ids"] = list(range(sequence_length)) + else: + encoded_inputs["position_ids"] = list(range(len( + source_ids))) + list(range(len(target_ids))) + if pad_length > 0: + encoded_inputs["position_ids"] = [ + self.pad_token_id + ] * pad_length + encoded_inputs["position_ids"] + if return_tensors: + # Add dimention for batch_size + encoded_inputs["position_ids"] = paddle.to_tensor( + encoded_inputs["position_ids"]).unsqueeze(0) + + if return_attention_mask: + attention_mask = np.ones( + (sequence_length, sequence_length), dtype='float32') * -1e9 + start = len(source_ids) + end = sequence_length + attention_mask[:end, :start] = 0.0 + # Generate the lower triangular matrix using the slice of matrix + tmp = np.triu( + np.ones( + [end - start, end - start], dtype='float32') * -1e9, 1) + attention_mask[start:end, start:end] = tmp + encoded_inputs["attention_mask"] = attention_mask + if pad_length > 0: + new_mask = np.ones( + (max_seq_len, max_seq_len), dtype='float32') * -1e9 + new_mask[-sequence_length:, -sequence_length:] = attention_mask + encoded_inputs["attention_mask"] = new_mask + if return_tensors: + # Add dimentions for batch_size and num_heads + encoded_inputs["attention_mask"] = paddle.to_tensor( + encoded_inputs["attention_mask"]).unsqueeze((0, 1)) + + return encoded_inputs