@@ -2423,15 +2423,28 @@ def _inner_training_loop(
             self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             self.compare_trainer_and_checkpoint_args(self.args, self.state)
             self._load_callback_state()
-            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
-            if not args.ignore_data_skip:
-                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+            if num_update_steps_per_epoch is not None:
+                epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+                if not args.ignore_data_skip:
+                    steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                    steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+                else:
+                    steps_trained_in_current_epoch = 0
             else:
-                steps_trained_in_current_epoch = 0
+                # If the dataloader does not have a length, we cannot restore the number of trained epochs.
+                # In the following loop, we repeatedly iterate over the dataloader to skip the first
+                # `steps_trained_in_current_epoch` steps and increment `epochs_trained` accordingly.
+                epochs_trained = 0
+                steps_trained_in_current_epoch = self.state.global_step * args.gradient_accumulation_steps
+                if args.ignore_data_skip:
+                    raise ValueError(
+                        "The dataloader does not have a length, so it is impossible to restore the number of trained"
+                        " epochs. Please disable the `ignore_data_skip` option."
+                    )
 
             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            if num_update_steps_per_epoch is not None:
+                logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
                 logger.info(
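When `num_update_steps_per_epoch` is known, the resume position is pure arithmetic. A minimal sketch of that arithmetic with invented numbers (the variable names mirror the trainer's state, the values do not come from the PR):

```python
# Resume arithmetic from the sized-dataloader branch above; all values are invented.
global_step = 250                 # optimizer updates already performed, from the checkpoint
num_update_steps_per_epoch = 100  # known because the dataloader has a length
gradient_accumulation_steps = 4

epochs_trained = global_step // num_update_steps_per_epoch                 # 2 full epochs done
steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch  # 50 updates into epoch 3
# Convert optimizer updates back into raw dataloader batches that must be skipped:
steps_trained_in_current_epoch *= gradient_accumulation_steps              # 200 batches

print(epochs_trained, steps_trained_in_current_epoch)  # 2 200
```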
@@ -2464,6 +2477,32 @@ def _inner_training_loop(
             if hasattr(epoch_dataloader, "set_epoch"):
                 epoch_dataloader.set_epoch(epoch)
 
+            steps_skipped = 0
+            rng_to_sync = False
+            epoch_iterator = None
+            if steps_trained_in_current_epoch > 0 and num_update_steps_per_epoch is None:
+                # Since the dataloader does not have a length, we just loop until the required number of steps.
+                # Every time we reach the end of the dataloader, we increment `epochs_trained` and reset the iterator.
+                epoch_iterator = iter(epoch_dataloader)
+                epoch_over = False
+                while steps_trained_in_current_epoch > 0:
+                    try:
+                        # If the dataloader yields N batches and N is not divisible by `args.gradient_accumulation_steps`,
+                        # the update loop ignores the last `N % args.gradient_accumulation_steps` batches of an epoch.
+                        # To replicate the same behavior when resuming training, we ignore such batches from skipped epochs.
+                        for _ in range(args.gradient_accumulation_steps):
+                            next(epoch_iterator)
+                        steps_trained_in_current_epoch -= args.gradient_accumulation_steps
+                        steps_skipped += args.gradient_accumulation_steps
+                    except StopIteration:
+                        epoch_over = True
+                        break
+                if epoch_over:
+                    epochs_trained += 1
+                    continue
+                assert steps_trained_in_current_epoch == 0
+                rng_to_sync = True
+
             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
                 self._past = None
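The new pre-skip loop is the interesting part: it burns through the iterator one accumulation chunk at a time, so trailing partial chunks are dropped exactly as the update loop would drop them. A self-contained sketch of the same idea, assuming a toy 10-batch iterable with no `__len__` (everything here is invented, not the trainer's actual code):

```python
# Standalone sketch of the skip loop for a length-less dataloader; names and sizes are invented.
def make_epoch_iterator():
    # Stand-in for an IterableDataset-backed dataloader: 10 batches, no __len__.
    return iter(range(10))

gradient_accumulation_steps = 4
steps_to_skip = 5 * gradient_accumulation_steps  # checkpoint says 5 optimizer updates -> 20 raw batches
epochs_trained = 0

while steps_to_skip > 0:
    iterator = make_epoch_iterator()
    try:
        while steps_to_skip > 0:
            # Consume one accumulation chunk at a time, so the trailing partial chunk
            # (10 % 4 == 2 batches per epoch) is dropped, matching the update loop.
            for _ in range(gradient_accumulation_steps):
                next(iterator)
            steps_to_skip -= gradient_accumulation_steps
    except StopIteration:
        epochs_trained += 1  # ran off the end of an epoch; restart with a fresh iterator

print(epochs_trained, steps_to_skip)  # 2 0 -> resume part-way through the third epoch
```

Each toy epoch contributes 2 usable updates (8 of 10 batches), so 5 checkpointed updates means two fully skipped epochs plus one update's worth of batches in the third, which is where the iterator is left pointing.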
@@ -2478,16 +2517,15 @@ def _inner_training_loop(
             if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                 self._load_rng_state(resume_from_checkpoint)
 
-            rng_to_sync = False
-            steps_skipped = 0
             if steps_trained_in_current_epoch > 0:
                 epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                 steps_skipped = steps_trained_in_current_epoch
                 steps_trained_in_current_epoch = 0
                 rng_to_sync = True
 
             step = -1
-            epoch_iterator = iter(epoch_dataloader)
+            if epoch_iterator is None:
+                epoch_iterator = iter(epoch_dataloader)
             # We chunkify the epoch iterator into gradient accumulation steps `n` batches
             remainder = steps_in_epoch % args.gradient_accumulation_steps
             if remainder == 0:
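On the sized path, the batch skipping itself is delegated to `skip_first_batches`, which the trainer imports from Accelerate. A hedged usage sketch on a plain PyTorch `DataLoader` (toy data and sizes are invented):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import skip_first_batches

dataset = TensorDataset(torch.arange(10.0))
dataloader = DataLoader(dataset, batch_size=2)  # 5 batches per epoch

# Resume as if the first 3 batches of the current epoch were already consumed.
resumed = skip_first_batches(dataloader, 3)
for (batch,) in resumed:
    print(batch)  # tensor([6., 7.]) then tensor([8., 9.])
```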
@@ -2645,13 +2683,6 @@ def _inner_training_loop(
                 if is_torch_xla_available():
                     xm.mark_step()
                 break
-            if step < 0:
-                logger.warning(
-                    "There seems not to be a single sample in your epoch_iterator, stopping training at step"
-                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
-                    f" num_steps ({max_steps}) higher than the number of available samples."
-                )
-                self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(
@@ -5348,7 +5379,7 @@ def set_initial_training_values(
         elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
             # Setting a very large number of epochs so we go as many times as necessary over the iterator.
             num_train_epochs = sys.maxsize
-            num_update_steps_per_epoch = max_steps
+            num_update_steps_per_epoch = None
             num_examples = total_train_batch_size * args.max_steps
             num_train_samples = args.max_steps * total_train_batch_size
         else:
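The companion change in `set_initial_training_values` is what makes the branching above possible: `num_update_steps_per_epoch` becomes `None` instead of a misleading `max_steps`, so it doubles as a "does the dataloader have a length?" flag. A hypothetical condensation of that decision, for illustration only (invented helper, not the method's real signature):

```python
import sys

def updates_per_epoch(dataloader_len, max_steps, gradient_accumulation_steps):
    # Hypothetical condensation of set_initial_training_values; illustration only.
    if dataloader_len is not None:
        # Sized dataloader: the number of updates per epoch is known.
        return max(dataloader_len // gradient_accumulation_steps, 1)
    if max_steps > 0:
        # Length-less dataloader: previously this returned max_steps, which made the
        # resume arithmetic silently wrong; None now explicitly signals "unknown".
        return None
    raise ValueError("args.max_steps must be set when the dataloader has no length")

print(updates_per_epoch(100, -1, 4))     # 25
print(updates_per_epoch(None, 1000, 4))  # None
```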