diff --git a/examples/language_model/gpt2/run_pretrain.py b/examples/language_model/gpt2/run_pretrain.py
index 087b8e46bee82..be2d8c5292976 100644
--- a/examples/language_model/gpt2/run_pretrain.py
+++ b/examples/language_model/gpt2/run_pretrain.py
@@ -46,11 +46,11 @@ def __call__(self, id):
 
 
 def create_pretrained_dataset(args, input_path, worker_init, worker_index,
-                              eod_id):
+                              worker_num, eod_id):
     train_data = GPT2Dataset(
         file_path=input_path,
         worker_index=worker_index,
-        num_samples=args.batch_size * args.max_steps,
+        num_samples=args.batch_size * args.max_steps * worker_num,
         eod_id=eod_id,
         seed=args.seed + worker_index)
     train_batch_sampler = paddle.io.DistributedBatchSampler(
@@ -144,7 +144,12 @@ def do_train(args):
         for f_id in range(num_files):
             data_file = files[f_id]
             train_data_loader = create_pretrained_dataset(
-                args, data_file, worker_init, worker_index, eod_id=eod_id)
+                args,
+                data_file,
+                worker_init,
+                worker_index,
+                worker_num,
+                eod_id=eod_id)
             for step, batch in enumerate(train_data_loader):
                 global_step += 1
                 tokens, loss_mask, attention_mask, position_ids, labels = batch
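
For context (not part of the patch): paddle.io.DistributedBatchSampler splits the dataset evenly across workers, so sizing the dataset at batch_size * max_steps gives each worker only max_steps / worker_num batches, and training exits before reaching args.max_steps. Scaling num_samples by worker_num restores max_steps batches per worker. Below is a minimal sketch of that arithmetic using a toy dataset; the ToyDataset class and the concrete numbers are illustrative assumptions, not from the patch.

    # Sketch only: demonstrates the per-worker batch count before and after
    # the fix. ToyDataset and the numbers below are illustrative assumptions.
    import paddle


    class ToyDataset(paddle.io.Dataset):
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __getitem__(self, idx):
            return idx

        def __len__(self):
            return self.num_samples


    batch_size, max_steps, worker_num = 8, 100, 4

    # Before the fix: batch_size * max_steps samples are split across
    # worker_num replicas, so each worker sees only max_steps / worker_num
    # batches.
    sampler = paddle.io.DistributedBatchSampler(
        ToyDataset(batch_size * max_steps),
        batch_size=batch_size,
        num_replicas=worker_num,
        rank=0)
    print(len(sampler))  # 25 batches per worker -- training stops early

    # After the fix: scaling num_samples by worker_num yields max_steps
    # batches on every worker, matching args.max_steps.
    sampler = paddle.io.DistributedBatchSampler(
        ToyDataset(batch_size * max_steps * worker_num),
        batch_size=batch_size,
        num_replicas=worker_num,
        rank=0)
    print(len(sampler))  # 100 batches per worker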