From f66e27590d579f0fab31e8695a7056eb479ab6ab Mon Sep 17 00:00:00 2001
From: Zhong Hui
Date: Mon, 22 Feb 2021 14:38:36 +0800
Subject: [PATCH] [GPT-2] bugfix: fix unexpected stop when training with
 multiple cards.

The pretraining dataset was sized for a single worker
(batch_size * max_steps samples), so under multi-card training each
worker's shard of the DistributedBatchSampler was exhausted before
max_steps and training stopped early. Scale num_samples by worker_num
and pass worker_num through to create_pretrained_dataset.
---
 examples/language_model/gpt2/run_pretrain.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/examples/language_model/gpt2/run_pretrain.py b/examples/language_model/gpt2/run_pretrain.py
index 087b8e46bee826..be2d8c52929761 100644
--- a/examples/language_model/gpt2/run_pretrain.py
+++ b/examples/language_model/gpt2/run_pretrain.py
@@ -46,11 +46,11 @@ def __call__(self, id):
 
 
 def create_pretrained_dataset(args, input_path, worker_init, worker_index,
-                              eod_id):
+                              worker_num, eod_id):
     train_data = GPT2Dataset(
         file_path=input_path,
         worker_index=worker_index,
-        num_samples=args.batch_size * args.max_steps,
+        num_samples=args.batch_size * args.max_steps * worker_num,
         eod_id=eod_id,
         seed=args.seed + worker_index)
     train_batch_sampler = paddle.io.DistributedBatchSampler(
@@ -144,7 +144,12 @@ def do_train(args):
         for f_id in range(num_files):
             data_file = files[f_id]
             train_data_loader = create_pretrained_dataset(
-                args, data_file, worker_init, worker_index, eod_id=eod_id)
+                args,
+                data_file,
+                worker_init,
+                worker_index,
+                worker_num,
+                eod_id=eod_id)
             for step, batch in enumerate(train_data_loader):
                 global_step += 1
                 tokens, loss_mask, attention_mask, position_ids, labels = batch
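
For context, a minimal sketch of the arithmetic behind the fix. It is
plain Python, not part of the patch; the helper batches_per_worker and
the sample numbers are hypothetical, and it assumes (as in the patch)
that paddle.io.DistributedBatchSampler splits the dataset evenly across
worker_num cards.

    def batches_per_worker(num_samples, batch_size, worker_num):
        # Each card sees only its shard after the even split.
        samples_per_worker = num_samples // worker_num
        return samples_per_worker // batch_size

    batch_size, max_steps, worker_num = 8, 500000, 4

    # Before the fix: dataset sized for a single card, so each of the
    # worker_num loaders runs dry at max_steps / worker_num batches and
    # training stops early.
    old = batches_per_worker(batch_size * max_steps, batch_size, worker_num)
    assert old == max_steps // worker_num

    # After the fix: num_samples is scaled by worker_num, so every card
    # can yield the full max_steps batches.
    new = batches_per_worker(batch_size * max_steps * worker_num,
                             batch_size, worker_num)
    assert new == max_steps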