Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#25 from ZHUI/gpt2/fix_unexpect_stop
Browse files Browse the repository at this point in the history
[GPT-2] bugfix: fix unexpected stop problem when using multiple cards to train.
  • Loading branch information
wawltor committed Feb 22, 2021
2 parents 687de91 + f66e275 commit 54a6f7c
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions examples/language_model/gpt2/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def __call__(self, id):


def create_pretrained_dataset(args, input_path, worker_init, worker_index,
eod_id):
worker_num, eod_id):
train_data = GPT2Dataset(
file_path=input_path,
worker_index=worker_index,
num_samples=args.batch_size * args.max_steps,
num_samples=args.batch_size * args.max_steps * worker_num,
eod_id=eod_id,
seed=args.seed + worker_index)
train_batch_sampler = paddle.io.DistributedBatchSampler(
Expand Down Expand Up @@ -144,7 +144,12 @@ def do_train(args):
for f_id in range(num_files):
data_file = files[f_id]
train_data_loader = create_pretrained_dataset(
args, data_file, worker_init, worker_index, eod_id=eod_id)
args,
data_file,
worker_init,
worker_index,
worker_num,
eod_id=eod_id)
for step, batch in enumerate(train_data_loader):
global_step += 1
tokens, loss_mask, attention_mask, position_ids, labels = batch
Expand Down

0 comments on commit 54a6f7c

Please sign in to comment.