Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#23 from ZHUI/gpt2/add_eval
Browse files Browse the repository at this point in the history
[GPT-2] Add the gpt-2 eval scripts.
  • Loading branch information
wawltor committed Feb 22, 2021
2 parents a4309db + 668544d commit 687de91
Show file tree
Hide file tree
Showing 4 changed files with 341 additions and 4 deletions.
35 changes: 32 additions & 3 deletions examples/language_model/gpt2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

```text
.
├── args.py # 训练参数配置
├── data.py # 数据处理
├── decompress.sh # 数据集解压脚本
├── generate_sample.py # 生成文本示例demo
├── lr.py # 学习率控制
├── process_data.py # 数据预处理脚本
├── README.md # 文档
├── run_pretrain.py # 预训练入口
├── run_eval.py # 评估入口
└── scripts # 训练脚本
```

Expand All @@ -29,7 +31,7 @@

```shell
pip install paddlenlp==2.0.0rc
pip install regex sentencepiece
pip install regex sentencepiece tqdm
```

### 数据准备
Expand Down Expand Up @@ -73,6 +75,7 @@ mkdir data
mv train.data.json_ids.npz data
```
### 模型训练
#### 单卡训练
Expand Down Expand Up @@ -105,7 +108,7 @@ CUDA_VISIBLE_DEVICES=0 python run_pretrain.py \

用户也可以使用提供的shell脚本直接训练`sh scripts/run.sh`.

### 单机多卡
#### 单机多卡

同样,可以执行如下命令实现八卡训练:

Expand All @@ -128,7 +131,33 @@ python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_pretrain.py \

用户也可以使用提供的shell脚本直接训练`sh scripts/run_multi.sh`.

#### 文本生成
### 模型评估

我们提供了对[WikiText](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip)[LAMBADA](https://raw.githubusercontent.com/cybertronai/bflm/master/lambada_test.jsonl)两种数据集的评估脚本, 使用如下命令启动评估:

1. WikiText数据集评估
```bash
python run_eval.py --model_name_or_path gpt2-medium-en \
--eval_path ./wikitext-103/wiki.valid.tokens \
--overlapping_eval 32 \
--init_checkpoint_path ./checkpoint_dir/model_state.pdparams \
--batch_size 8 \
--device gpu
```

2. LAMBADA数据集评估
```bash
python run_eval.py --model_name_or_path gpt2-medium-en \
--eval_path ./lambada_test.jsonl \
--cloze_eval \
--init_checkpoint_path ./checkpoint_dir/model_state.pdparams \
--batch_size 8 \
--device gpu
```
其中数据集WikiText采用的是PPL(perplexity)评估指标,LAMBADA采用的是ACC(accuracy)指标。不设置`init_checkpoint_path` 参数时,可以评估默认预训练好的模型参数。


### 文本生成

本项目提供了简单的文本生成的demo,供用户测试文本生成效果。

Expand Down
1 change: 0 additions & 1 deletion examples/language_model/gpt2/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ def __getitem__(self, index):
offset_l = self.sample_idx[idx + 1][1]
tokens = self._get_single_sample_from_idx(doc_index_f, doc_index_l,
offset_f, offset_l)
token_arr = np.array(tokens, dtype="int64")
return self._construct_sample(tokens)

def __len__(self):
Expand Down
308 changes: 308 additions & 0 deletions examples/language_model/gpt2/run_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import os
import json
import math
import time
import argparse

import numpy as np
import paddle
from paddle.io import DataLoader, Dataset
from paddlenlp.transformers import GPT2Model, GPT2ForPretraining
from paddlenlp.transformers import GPT2Tokenizer
from paddlenlp.transformers import GPT2Model
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.utils.log import logger

MODEL_CLASSES = {
"gpt2-small-en": (GPT2ForPretraining, GPT2Tokenizer),
"gpt2-medium-en": (GPT2ForPretraining, GPT2Tokenizer),
"gpt2-large-en": (GPT2ForPretraining, GPT2Tokenizer),
}

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, help="Path to pre-trained model or shortcut name selected in the list: "
+ ", ".join(sum([list(classes[-1].pretrained_init_configuration.keys()) for classes in MODEL_CLASSES.values()], [])), )
parser.add_argument("--eval_path", default=None, type=str, required=True, help="The eval file path.", )
parser.add_argument('--cloze_eval', action='store_true', help='Evaluation dataset from `--eval_path` is a cloze task')
parser.add_argument('--overlapping_eval', type=int, default=32, help='Sliding window for overlapping eval ')
parser.add_argument("--init_checkpoint_path", default=None, type=str, help="The model checkpoint path.", )
parser.add_argument( "--batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.", )
parser.add_argument('--seq_length', type=int, default=1024, help='Maximum sequence length to process for evaluation.')
parser.add_argument("--device", type=str, default="gpu", help="Select cpu, gpu, xpu devices.")
parser.add_argument("--logging_steps", type=int, default=100, help="Log every X updates steps.")
# yapf: enable


class LM_Eval_Dataset(paddle.io.Dataset):
def __init__(self, tokens, seq_len, pad_idx, overlapping_eval=None):
self.tokens = tokens
self.seq_len = seq_len
self.pad_idx = pad_idx
self.overlapping_eval = overlapping_eval
if self.overlapping_eval is None:
self.overlapping_eval = self.seq_len
self.overlapping_eval = max(1, self.overlapping_eval)

self.total_targets = len(self.tokens) - 1
# remove first sequence tokens
targets = max(self.total_targets - self.overlapping_eval, 0)
self.total_sequences = max(
math.ceil(targets / self.overlapping_eval) + 1, 1)

def __len__(self):
return self.total_sequences

def _construct_sample(self, tokens):
tokens = np.array(tokens).astype("int64").tolist()
labels = tokens[1:]
tokens = tokens[:-1]
seq_length = len(tokens)
# attention mask for the attention calulate
attention_mask = np.tri(seq_length, seq_length).reshape(
(1, seq_length, seq_length))

# the pad and eod tokens do not contribute the loss
loss_mask = np.ones(seq_length, dtype="float32")
loss_mask[np.where(np.array(tokens) == self.pad_idx)] = 0.0
position_ids = np.arange(0, seq_length, dtype="int64")

# -INF mask value as default
attention_mask = (attention_mask - 1.0) * 1e9
# Bool mask of attention
attention_mask = attention_mask.astype("float32")
return [tokens, loss_mask, attention_mask, position_ids, labels]

def __getitem__(self, idx):
start_idx = idx * self.overlapping_eval
end_idx = start_idx + self.seq_len
tokens = self.tokens[start_idx:end_idx + 1]
num_tokens = len(tokens)
if num_tokens < self.seq_len + 1:
num_pad = (self.seq_len + 1 - num_tokens)
tokens += [self.pad_idx] * num_pad
[tokens, loss_mask, attention_mask, position_ids,
labels] = self._construct_sample(tokens)
if self.overlapping_eval != self.seq_len and idx != 0:
loss_mask[:-self.overlapping_eval] *= 0

return [tokens, loss_mask, attention_mask, position_ids, labels]


class Lambada_Eval_Dataset(paddle.io.Dataset):
def __init__(self, tokens, labels, seq_len, pad_idx):
self.seq_len = seq_len
self.pad_idx = pad_idx
self.tokens = tokens
self.labels = labels

def __len__(self):
return len(self.tokens)

def _construct_sample(self, tokens):
tokens = np.array(tokens).astype("int64").tolist()
labels = tokens[1:]
tokens = tokens[:-1]

seq_length = len(tokens)
# attention mask for the attention calulate
attention_mask = np.tri(seq_length, seq_length).reshape(
(1, seq_length, seq_length))

# the pad and eod tokens do not contribute the loss
position_ids = np.arange(0, seq_length, dtype="int64")

# -INF mask value as default
attention_mask = (attention_mask - 1.0) * 1e9
# Bool mask of attention
attention_mask = attention_mask.astype("float32")
return [tokens, attention_mask, position_ids, labels]

def __getitem__(self, idx):
tokens = self.tokens[idx][:self.seq_len]
labels = self.labels[idx]
tokens = tokens + labels
num_tokens = len(tokens)
if num_tokens < self.seq_len + 1:
num_pad = (self.seq_len + 1 - num_tokens)
tokens += [self.pad_idx] * num_pad
loss_mask = np.zeros(self.seq_len, dtype="float32")
loss_mask[num_tokens - len(labels) - 1:num_tokens - 1] = 1.
[tokens, attention_mask, position_ids, labels] = self._construct_sample(
tokens)
return [tokens, loss_mask, attention_mask, position_ids, labels]


def wikitext_detokenizer(string):
# contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# number separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# double brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
return string


def get_tokens(tokenizer, text, strict=True):
if not strict:
tokens = tokenizer.encode(text)
return tokens[:-1], [tokens[-1]]
last_token = text.split()[-1]
start_idx = text.rfind(last_token)
beginning_tokens = tokenizer.encode(text[:start_idx].strip())
last_token = tokenizer.encode(' ' + last_token)
return beginning_tokens, last_token


def create_eval_dataset(args):
val_dataloader = None
eval_batch_size = args.batch_size
seq_len = args.seq_length

tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
pad_token = tokenizer.command_name_map["pad"].Id

if not args.cloze_eval:
with open(args.eval_path, "rb") as reader:
entire_data = reader.read().decode('utf-8')
num_original_tokens = len(entire_data.strip().split(" "))
entire_data = wikitext_detokenizer(entire_data)
tokenized_data = tokenizer.encode(entire_data)
num_tokenized_tokens = len(tokenized_data)
print('Original Tokens: %d, Detokenized tokens: %d' %
(num_tokenized_tokens, num_original_tokens))
val_dataset = LM_Eval_Dataset(tokenized_data, seq_len, pad_token,
args.overlapping_eval)
else:
tokenized_data = []
tokenized_label = []
with open(args.eval_path, 'r') as f:
for line in f.readlines():
text = json.loads(line)['text']
tokens, labels = get_tokens(tokenizer, text)
tokenized_data.append(tokens)
tokenized_label.append(labels)
val_dataset = Lambada_Eval_Dataset(tokenized_data, tokenized_label,
seq_len, pad_token)
num_tokenized_tokens = 0
num_original_tokens = 0

args.num_examples = len(val_dataset)
args.num_original_tokens = num_original_tokens
args.num_tokenized_tokens = num_tokenized_tokens
val_dataloader = DataLoader(
val_dataset,
batch_size=eval_batch_size,
drop_last=False,
collate_fn=Tuple(Stack(), Stack(), Stack(), Stack(), Stack()))

return val_dataloader


def do_eval(args):
assert args.device in [
"cpu", "gpu", "xpu"
], "Invalid device! Available device should be cpu, gpu, or xpu."
paddle.set_device(args.device)
model_class, tokenizer_class = MODEL_CLASSES[args.model_name_or_path]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

if args.init_checkpoint_path is not None:
model = GPT2ForPretraining(
GPT2Model(**model_class.pretrained_init_configuration[
args.model_name_or_path]))

logger.info("Load model checkpoint from %s" % args.init_checkpoint_path)
model_dict = paddle.load(os.path.join(args.init_checkpoint_path))
model.set_dict(model_dict)
else:
model = model_class.from_pretrained(args.model_name_or_path)

tic_eval = time.time()
eval_data_loader = create_eval_dataset(args)
model.eval()
total_score = 0
score_name = "loss" if not args.cloze_eval else "number correct"
with paddle.no_grad():
for step, batch in enumerate(eval_data_loader):
tokens, loss_mask, attention_mask, position_ids, labels = batch
preds = model(tokens, position_ids, attention_mask)
if not args.cloze_eval:
masked_lm_loss = paddle.nn.functional.cross_entropy(
preds, labels, reduction="none")
loss = paddle.sum(masked_lm_loss * loss_mask)
total_score += loss.numpy() / (args.num_tokenized_tokens - 1)
else:
outputs = paddle.argmax(preds, -1)
acc = paddle.cast(outputs == labels, 'float32')
acc = paddle.where(
paddle.cast(loss_mask, 'bool'), acc, paddle.ones_like(acc))
acc = paddle.sum(paddle.prod(acc, -1))
total_score += acc.numpy()
if step % args.logging_steps == 0:
logger.info("step %d, batch: %d, %s: %f, speed: %.2f step/s" %
(step, step, score_name, total_score,
args.logging_steps / (time.time() - tic_eval)))
tic_eval = time.time()

if not args.cloze_eval:
total_loss = float(total_score)
ppl = math.exp(min(20, total_loss))
token_ratio = (args.num_tokenized_tokens - 1) / (
args.num_original_tokens - 1)
adjusted_ppl = math.exp(min(20, total_loss * token_ratio))
string = ' validation results on {} | '.format(args.eval_path)
string += 'avg loss: {:.4E} | '.format(total_loss)
string += 'ppl: {:.4E} | '.format(ppl)
string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
string += 'token ratio: {} |'.format(token_ratio)
else:
num_correct = float(total_score)
acc = float(num_correct / args.num_examples)
string = ' validation results on {} | '.format(args.eval_path)
string += 'number correct: {:.4E} | '.format(num_correct)
string += 'total examples: {:.4E} | '.format(args.num_examples)
string += 'avg accuracy: {:.4E}'.format(acc)
logger.info(string)


if __name__ == "__main__":
args = parser.parse_args()
do_eval(args)
Loading

0 comments on commit 687de91

Please sign in to comment.