
Commit bc56f1c

Correctly support resuming with dataset without length
1 parent 405b562 commit bc56f1c

File tree

2 files changed: +78 -16 lines changed

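For context, a "dataset without length" here is a torch IterableDataset that does not implement __len__: the dataloader length is unknown, so both training and resuming from a checkpoint have to be driven by max_steps. A minimal sketch of that situation (the class name and sample count below are illustrative, not taken from this commit):

import torch
from torch.utils.data import IterableDataset


class StreamingRegressionDataset(IterableDataset):  # illustrative name, not from the commit
    """Streams samples one by one and deliberately defines no __len__."""

    def __iter__(self):
        for i in range(100):  # the Trainer cannot know this bound up front
            x = float(i)
            yield {"input_x": torch.tensor([x]), "labels": torch.tensor([2.0 * x])}


# With such a dataset, len(dataloader) is unavailable, so TrainingArguments(max_steps=...)
# drives the training loop; resuming on that code path is what this commit fixes.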

src/transformers/trainer.py

Lines changed: 40 additions & 16 deletions
@@ -2205,7 +2205,7 @@ def _inner_training_loop(
             max_steps = args.max_steps
             # Setting a very large number of epochs so we go as many times as necessary over the iterator.
             num_train_epochs = sys.maxsize
-            num_update_steps_per_epoch = max_steps
+            num_update_steps_per_epoch = None
             num_examples = total_train_batch_size * args.max_steps
             num_train_samples = args.max_steps * total_train_batch_size
             if args.include_tokens_per_second:
@@ -2355,15 +2355,28 @@ def _inner_training_loop(
             self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             self.compare_trainer_and_checkpoint_args(self.args, self.state)
             self._load_callback_state()
-            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
-            if not args.ignore_data_skip:
-                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+            if num_update_steps_per_epoch is not None:
+                epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+                if not args.ignore_data_skip:
+                    steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                    steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+                else:
+                    steps_trained_in_current_epoch = 0
             else:
-                steps_trained_in_current_epoch = 0
+                # If the dataloader does not have a length, we cannot restore the number of trained epochs.
+                # In the following loop, we repeatedly iterate over the dataloader to skip the first
+                # `steps_trained_in_current_epoch` steps and increment `epochs_trained` accordingly.
+                epochs_trained = 0
+                steps_trained_in_current_epoch = self.state.global_step * args.gradient_accumulation_steps
+                if args.ignore_data_skip:
+                    raise ValueError(
+                        "The dataloader does not have a length, so it is impossible to restore the number of trained"
+                        " epochs. Please disable the `ignore_data_skip` option."
+                    )
 
             logger.info(" Continuing training from checkpoint, will skip to saved global_step")
-            logger.info(f" Continuing training from epoch {epochs_trained}")
+            if num_update_steps_per_epoch is not None:
+                logger.info(f" Continuing training from epoch {epochs_trained}")
             logger.info(f" Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
                 logger.info(
@@ -2410,6 +2423,26 @@ def _inner_training_loop(
             if hasattr(epoch_dataloader, "set_epoch"):
                 epoch_dataloader.set_epoch(epoch)
 
+            steps_skipped = 0
+            rng_to_sync = False
+            if steps_trained_in_current_epoch > 0 and num_update_steps_per_epoch is None:
+                # Since the dataloader does not have a length, we just loop until the required number of steps.
+                # Every time we reach the end of the dataloader, we increment epoch and reset the iterator.
+                epoch_iterator = iter(epoch_iterator)
+                epoch_over = False
+                while steps_trained_in_current_epoch > 0:
+                    try:
+                        next(epoch_iterator)
+                        steps_trained_in_current_epoch -= 1
+                        steps_skipped += 1
+                    except StopIteration:
+                        epoch_over = True
+                        break
+                if epoch_over:
+                    continue
+                assert steps_trained_in_current_epoch == 0
+                rng_to_sync = True
+
             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
                 self._past = None
@@ -2424,8 +2457,6 @@ def _inner_training_loop(
             if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                 self._load_rng_state(resume_from_checkpoint)
 
-            rng_to_sync = False
-            steps_skipped = 0
             if steps_trained_in_current_epoch > 0:
                 epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                 steps_skipped = steps_trained_in_current_epoch
@@ -2575,13 +2606,6 @@ def _inner_training_loop(
                     if is_torch_xla_available():
                         xm.mark_step()
                     break
-            if step < 0:
-                logger.warning(
-                    "There seems not to be a single sample in your epoch_iterator, stopping training at step"
-                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
-                    f" num_steps ({max_steps}) higher than the number of available samples."
-                )
-                self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
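The heart of the change is the third hunk above: when the dataloader has no length, the only way to fast-forward to the checkpointed position is to draw and discard batches, rolling over into the next epoch whenever the iterator runs dry. A standalone, simplified sketch of that strategy outside the Trainer (the function name and signature are mine, not part of this commit):

from typing import Iterable, Iterator, Tuple


def skip_unsized_batches(dataloader: Iterable, batches_to_skip: int) -> Tuple[int, Iterator]:
    """Fast-forward an unsized dataloader; assumes it yields at least one batch per pass.

    Returns (extra_epochs_consumed, iterator positioned right after the skipped batches).
    """
    extra_epochs = 0
    iterator = iter(dataloader)
    while batches_to_skip > 0:
        try:
            next(iterator)          # discard a batch that was already trained on
            batches_to_skip -= 1
        except StopIteration:
            extra_epochs += 1       # crossed an epoch boundary while skipping
            iterator = iter(dataloader)  # start the next pass over the stream
    return extra_epochs, iterator

Resuming from checkpoint-10 over a 9-batch stream, for example, discards one full pass plus one batch and continues from the second batch of the second pass; the committed code achieves the same effect by `continue`-ing to the next iteration of the epoch loop instead of re-creating the iterator in place.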

tests/trainer/test_trainer.py

Lines changed: 38 additions & 0 deletions
@@ -2958,6 +2958,44 @@ def test_resume_training_with_frozen_params(self):
         self.assertEqual(b, b1)
         self.check_trainer_state_are_the_same(state, state1)
 
+    @parameterized.expand([(9, 1), (10, 1), (11, 1), (20, 1), (21, 1), (9, 2)])
+    def test_resume_training_with_iterable_dataset(self, dataset_length, gradient_accumulation_steps):
+        with tempfile.TemporaryDirectory() as tmpdir:
+
+            def get_trainer():
+                config = RegressionModelConfig()
+                train_dataset = SampleIterableDataset(length=dataset_length)
+                model = RegressionRandomPreTrainedModel(config)
+                args = RegressionTrainingArguments(
+                    output_dir=tmpdir,
+                    learning_rate=0.1,
+                    max_steps=20,
+                    save_steps=10,
+                    per_device_train_batch_size=1,
+                    gradient_accumulation_steps=gradient_accumulation_steps,
+                )
+                return Trainer(model=model, args=args, train_dataset=train_dataset)
+
+            # Train from scratch.
+            trainer = get_trainer()
+            trainer.train()
+            self.assertEqual(trainer.state.global_step, 20)
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            state = dataclasses.asdict(trainer.state)
+
+            # Train from a checkpoint.
+            checkpoint = os.path.join(tmpdir, "checkpoint-10")
+            trainer = get_trainer()
+            trainer.train(resume_from_checkpoint=checkpoint)
+            self.assertEqual(trainer.state.global_step, 20)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+
+            # Check that the resumed model is the same as the original one.
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
     def test_load_best_model_at_end(self):
         total = int(self.n_epochs * 64 / self.batch_size)
         with tempfile.TemporaryDirectory() as tmpdir:
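The parameter pairs cover the interesting resume positions for a checkpoint saved at global step 10: mid-epoch (11, 20 and 21 samples), exactly on an epoch boundary (10), just past a boundary (9), and a couple of epochs in once gradient accumulation is added (9 samples with accumulation 2). A rough sanity check of where each case lands in the stream, assuming per_device_train_batch_size=1 so one optimizer step consumes gradient_accumulation_steps batches (my arithmetic, not code from the commit):

CHECKPOINT_STEP = 10  # save_steps=10 in the test, so resumption starts from "checkpoint-10"
for dataset_length, grad_accum in [(9, 1), (10, 1), (11, 1), (20, 1), (21, 1), (9, 2)]:
    batches_consumed = CHECKPOINT_STEP * grad_accum
    full_passes, into_next_pass = divmod(batches_consumed, dataset_length)
    print(f"length={dataset_length}, accumulation={grad_accum}: "
          f"skip {full_passes} full pass(es) + {into_next_pass} batch(es) on resume")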
