
Commit c81d9e1

Correctly support resuming with dataset without length
1 parent 7c6b170

2 files changed: +86, -17 lines

src/transformers/trainer.py

Lines changed: 48 additions & 17 deletions
@@ -2423,15 +2423,28 @@ def _inner_training_loop(
             self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             self.compare_trainer_and_checkpoint_args(self.args, self.state)
             self._load_callback_state()
-            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
-            if not args.ignore_data_skip:
-                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+            if num_update_steps_per_epoch is not None:
+                epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+                if not args.ignore_data_skip:
+                    steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                    steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+                else:
+                    steps_trained_in_current_epoch = 0
             else:
-                steps_trained_in_current_epoch = 0
+                # If the dataloader does not have a length, we cannot restore the number of trained epochs.
+                # In the following loop, we repeatedly iterate over the dataloader to skip the first
+                # `steps_trained_in_current_epoch` steps and increment `epochs_trained` accordingly.
+                epochs_trained = 0
+                steps_trained_in_current_epoch = self.state.global_step * args.gradient_accumulation_steps
+                if args.ignore_data_skip:
+                    raise ValueError(
+                        "The dataloader does not have a length, so it is impossible to restore the number of trained"
+                        " epochs. Please disable the `ignore_data_skip` option."
+                    )
 
             logger.info(" Continuing training from checkpoint, will skip to saved global_step")
-            logger.info(f" Continuing training from epoch {epochs_trained}")
+            if num_update_steps_per_epoch is not None:
+                logger.info(f" Continuing training from epoch {epochs_trained}")
             logger.info(f" Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
                 logger.info(
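
The branching above splits the resume bookkeeping on whether `num_update_steps_per_epoch` is known. A minimal standalone sketch of the same arithmetic (the function name and signature are illustrative, not part of the Trainer API):

def resume_position(global_step, grad_accum, updates_per_epoch=None):
    """Return (epochs_trained, batches_to_skip) when resuming from a checkpoint."""
    if updates_per_epoch is not None:
        # Sized dataloader: both the epoch count and the position inside the
        # current epoch can be computed directly from the saved global_step.
        epochs_trained = global_step // updates_per_epoch
        batches_to_skip = (global_step % updates_per_epoch) * grad_accum
    else:
        # Length-less dataloader: only the total number of optimizer steps is
        # known, so every previously consumed batch has to be skipped by
        # iterating from the start (the loop added in the next hunk).
        epochs_trained = 0
        batches_to_skip = global_step * grad_accum
    return epochs_trained, batches_to_skip

# Example: resuming at optimizer step 10 with gradient accumulation 2 and no
# dataloader length gives resume_position(10, 2) == (0, 20).
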
@@ -2464,6 +2477,32 @@ def _inner_training_loop(
             if hasattr(epoch_dataloader, "set_epoch"):
                 epoch_dataloader.set_epoch(epoch)
 
+            steps_skipped = 0
+            rng_to_sync = False
+            epoch_iterator = None
+            if steps_trained_in_current_epoch > 0 and num_update_steps_per_epoch is None:
+                # Since the dataloader does not have a length, we just loop until the required number of steps.
+                # Every time we reach the end of the dataloader, we increment epoch and reset the iterator.
+                epoch_iterator = iter(epoch_dataloader)
+                epoch_over = False
+                while steps_trained_in_current_epoch > 0:
+                    try:
+                        # If the dataloader yields N batches and N is not divisible by `args.gradient_accumulation_steps`,
+                        # the update loop ignores the last `N % args.gradient_accumulation_steps` batches of an epoch.
+                        # To replicate the same behavior when resuming training, we ignore such batches from skipped epochs.
+                        for _ in range(args.gradient_accumulation_steps):
+                            next(epoch_iterator)
+                        steps_trained_in_current_epoch -= args.gradient_accumulation_steps
+                        steps_skipped += args.gradient_accumulation_steps
+                    except StopIteration:
+                        epoch_over = True
+                        break
+                if epoch_over:
+                    epochs_trained += 1
+                    continue
+                assert steps_trained_in_current_epoch == 0
+                rng_to_sync = True
+
             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
                 self._past = None
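
The fast-forward loop added above can be read as a small helper that consumes whole gradient-accumulation groups from an iterator and reports whether the epoch ended first. A sketch with hypothetical names (not Trainer code) that mirrors that logic:

def fast_forward(batch_iter, batches_to_skip, grad_accum):
    """Skip batches in groups of `grad_accum`; return (remaining, epoch_over).

    `epoch_over` tells the caller to bump its epoch counter and restart the
    iterator, which is what the `continue` in the hunk above does. Trailing
    batches that do not fill a full accumulation group are dropped, matching
    the behavior of the update loop.
    """
    while batches_to_skip > 0:
        try:
            for _ in range(grad_accum):
                next(batch_iter)
        except StopIteration:
            return batches_to_skip, True
        batches_to_skip -= grad_accum
    return 0, False

# Example: an epoch of 5 batches, 6 batches still to skip, grad_accum=2:
# fast_forward(iter(range(5)), 6, 2) == (2, True) -> the epoch was consumed;
# the 2 remaining batches are skipped at the start of the next epoch.
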
@@ -2478,16 +2517,15 @@ def _inner_training_loop(
             if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                 self._load_rng_state(resume_from_checkpoint)
 
-            rng_to_sync = False
-            steps_skipped = 0
             if steps_trained_in_current_epoch > 0:
                 epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                 steps_skipped = steps_trained_in_current_epoch
                 steps_trained_in_current_epoch = 0
                 rng_to_sync = True
 
             step = -1
-            epoch_iterator = iter(epoch_dataloader)
+            if epoch_iterator is None:
+                epoch_iterator = iter(epoch_dataloader)
             # We chunkify the epoch iterator into gradient accumulation steps `n` batches
             remainder = steps_in_epoch % args.gradient_accumulation_steps
             if remainder == 0:
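
For the sized-dataloader path kept above, mid-epoch skipping is still delegated to Accelerate's `skip_first_batches`. A minimal usage sketch on a plain PyTorch DataLoader (toy data, assuming `skip_first_batches` accepts an unprepared DataLoader):

import torch
from accelerate import skip_first_batches
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
dataloader = DataLoader(dataset, batch_size=2)  # 5 batches of 2 samples

# Drop the first 3 batches, as the Trainer does when resuming mid-epoch
# on a dataloader with a known length.
resumed = skip_first_batches(dataloader, num_batches=3)
for batch in resumed:
    print(batch)  # only the 4th and 5th batches: [6., 7.] and [8., 9.]
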
@@ -2645,13 +2683,6 @@ def _inner_training_loop(
                     if is_torch_xla_available():
                         xm.mark_step()
                     break
-            if step < 0:
-                logger.warning(
-                    "There seems not to be a single sample in your epoch_iterator, stopping training at step"
-                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
-                    f" num_steps ({max_steps}) higher than the number of available samples."
-                )
-                self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(
@@ -5376,7 +5407,7 @@ def set_initial_training_values(
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
-            num_update_steps_per_epoch = max_steps
+            num_update_steps_per_epoch = None
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
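
The `set_initial_training_values` change is what feeds the new branches above: when only `max_steps` bounds training, `num_update_steps_per_epoch` is now reported as `None` instead of being aliased to `max_steps`. A simplified standalone sketch of just this branch (the real method takes more inputs and handles more cases):

import sys

def max_steps_branch(max_steps, total_train_batch_size):
    # Dataloader has no working size and max_steps > 0.
    num_train_epochs = sys.maxsize      # iterate over the data as often as needed
    num_update_steps_per_epoch = None   # previously max_steps, which broke the resume logic
    num_examples = total_train_batch_size * max_steps
    num_train_samples = max_steps * total_train_batch_size
    return num_train_epochs, num_update_steps_per_epoch, num_examples, num_train_samples
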

tests/trainer/test_trainer.py

Lines changed: 38 additions & 0 deletions
@@ -3378,6 +3378,44 @@ def test_resume_training_with_frozen_params(self):
         self.assertEqual(b, b1)
         self.check_trainer_state_are_the_same(state, state1)
 
+    @parameterized.expand([(9, 1), (10, 1), (11, 1), (20, 1), (21, 1), (9, 3), (9, 2)])
+    def test_resume_training_with_iterable_dataset(self, dataset_length, gradient_accumulation_steps):
+        with tempfile.TemporaryDirectory() as tmpdir:
+
+            def get_trainer():
+                config = RegressionModelConfig()
+                train_dataset = SampleIterableDataset(length=dataset_length)
+                model = RegressionRandomPreTrainedModel(config)
+                args = RegressionTrainingArguments(
+                    output_dir=tmpdir,
+                    learning_rate=0.1,
+                    max_steps=20,
+                    save_steps=10,
+                    per_device_train_batch_size=1,
+                    gradient_accumulation_steps=gradient_accumulation_steps,
+                )
+                return Trainer(model=model, args=args, train_dataset=train_dataset)
+
+            # Train from scratch.
+            trainer = get_trainer()
+            trainer.train()
+            self.assertEqual(trainer.state.global_step, 20)
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            state = dataclasses.asdict(trainer.state)
+
+            # Train from a checkpoint.
+            checkpoint = os.path.join(tmpdir, "checkpoint-10")
+            trainer = get_trainer()
+            trainer.train(resume_from_checkpoint=checkpoint)
+            self.assertEqual(trainer.state.global_step, 20)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+
+            # Check that the resumed model is the same as the original one.
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
     def test_load_best_model_at_end(self):
         total = int(self.n_epochs * 64 / self.batch_size)
         with tempfile.TemporaryDirectory() as tmpdir:
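
The test drives the fix through the public Trainer API with the repo's regression test helpers. A rough end-to-end sketch of the same flow with a user-defined length-less dataset (the model, dataset, and output paths below are illustrative placeholders, not part of the repository):

import torch
from torch.utils.data import IterableDataset
from transformers import Trainer, TrainingArguments


class StreamingPairs(IterableDataset):
    """Toy length-less dataset: 15 (x, 2x) pairs per pass, no __len__."""

    def __iter__(self):
        for i in range(15):
            x = torch.tensor([float(i)])
            yield {"input_ids": x, "labels": 2.0 * x}


class TinyRegressor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, input_ids=None, labels=None):
        preds = self.linear(input_ids)
        loss = torch.nn.functional.mse_loss(preds, labels)
        return {"loss": loss, "logits": preds}


args = TrainingArguments(
    output_dir="iterable_ckpts",  # placeholder path
    max_steps=20,                 # required: the dataset has no length
    save_steps=10,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    report_to="none",
)

# Initial run: saves checkpoint-10 and checkpoint-20.
Trainer(model=TinyRegressor(), args=args, train_dataset=StreamingPairs()).train()

# Resumed run: with this commit, the Trainer replays the first 20 batches
# (including the epoch rollover after 14 usable batches) before training on.
Trainer(model=TinyRegressor(), args=args, train_dataset=StreamingPairs()).train(
    resume_from_checkpoint="iterable_ckpts/checkpoint-10"
)
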
