@@ -2205,7 +2205,7 @@ def _inner_training_loop(
             max_steps = args.max_steps
             # Setting a very large number of epochs so we go as many times as necessary over the iterator.
             num_train_epochs = sys.maxsize
-            num_update_steps_per_epoch = max_steps
+            num_update_steps_per_epoch = None
             num_examples = total_train_batch_size * args.max_steps
             num_train_samples = args.max_steps * total_train_batch_size
             if args.include_tokens_per_second:
@@ -2355,15 +2355,28 @@ def _inner_training_loop(
             self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             self.compare_trainer_and_checkpoint_args(self.args, self.state)
             self._load_callback_state()
-            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
-            if not args.ignore_data_skip:
-                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+            if num_update_steps_per_epoch is not None:
+                epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+                if not args.ignore_data_skip:
+                    steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                    steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+                else:
+                    steps_trained_in_current_epoch = 0
             else:
-                steps_trained_in_current_epoch = 0
+                # If the dataloader does not have a length, we cannot restore the number of trained epochs.
+                # In the following loop, we repeatedly iterate over the dataloader to skip the first
+                # `steps_trained_in_current_epoch` steps and increment `epochs_trained` accordingly.
+                epochs_trained = 0
+                steps_trained_in_current_epoch = self.state.global_step * args.gradient_accumulation_steps
+                if args.ignore_data_skip:
+                    raise ValueError(
+                        "The dataloader does not have a length, so it is impossible to restore the number of trained"
+                        " epochs. Please disable the `ignore_data_skip` option."
+                    )

             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            if num_update_steps_per_epoch is not None:
+                logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
                 logger.info(
@@ -2410,6 +2423,26 @@ def _inner_training_loop(
             if hasattr(epoch_dataloader, "set_epoch"):
                 epoch_dataloader.set_epoch(epoch)

+            steps_skipped = 0
+            rng_to_sync = False
+            if steps_trained_in_current_epoch > 0 and num_update_steps_per_epoch is None:
+                # Since the dataloader does not have a length, we just loop until the required number of steps.
+                # Every time we reach the end of the dataloader, we increment epoch and reset the iterator.
+                epoch_iterator = iter(epoch_iterator)
+                epoch_over = False
+                while steps_trained_in_current_epoch > 0:
+                    try:
+                        next(epoch_iterator)
+                        steps_trained_in_current_epoch -= 1
+                        steps_skipped += 1
+                    except StopIteration:
+                        epoch_over = True
+                        break
+                if epoch_over:
+                    continue
+                assert steps_trained_in_current_epoch == 0
+                rng_to_sync = True
+
             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
                 self._past = None
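The block added in the hunk above replays the epoch iterator to burn through already-consumed batches, bumping the epoch counter each time the iterator runs dry. Below is a minimal standalone sketch of the same idea; the `skip_to_global_step` helper and its flat replay loop are illustrative assumptions, not Trainer API.

```python
# Standalone sketch, independent of the HF Trainer: replay a length-less
# dataloader to skip batches consumed before a checkpoint was saved.
# `skip_to_global_step` is a hypothetical helper, not a transformers API.
from typing import Iterable, Iterator, Tuple


def skip_to_global_step(
    dataloader: Iterable, global_step: int, grad_accum_steps: int
) -> Tuple[int, Iterator]:
    """Return (epochs_trained, iterator positioned at the first unseen batch)."""
    steps_to_skip = global_step * grad_accum_steps  # optimizer steps -> raw batches
    epochs_trained = 0
    while True:
        it = iter(dataloader)
        while steps_to_skip > 0:
            try:
                next(it)  # discard an already-consumed batch
                steps_to_skip -= 1
            except StopIteration:  # iterator exhausted: one more full epoch seen
                break
        if steps_to_skip == 0:
            return epochs_trained, it
        epochs_trained += 1


# 4 batches per "epoch", resuming at optimizer step 5 with accumulation 2:
# 10 batches get skipped, i.e. 2 full epochs plus 2 batches into the third.
epochs, it = skip_to_global_step([10, 11, 12, 13], global_step=5, grad_accum_steps=2)
assert (epochs, next(it)) == (2, 12)
```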
@@ -2424,8 +2457,6 @@ def _inner_training_loop(
             if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                 self._load_rng_state(resume_from_checkpoint)

-            rng_to_sync = False
-            steps_skipped = 0
             if steps_trained_in_current_epoch > 0:
                 epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                 steps_skipped = steps_trained_in_current_epoch
@@ -2575,13 +2606,6 @@ def _inner_training_loop(
                     if is_torch_xla_available():
                         xm.mark_step()
                     break
-            if step < 0:
-                logger.warning(
-                    "There seems not to be a single sample in your epoch_iterator, stopping training at step"
-                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
-                    f" num_steps ({max_steps}) higher than the number of available samples."
-                )
-                self.control.should_training_stop = True

             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
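For dataloaders that do have a length, the resume path in the hunk at 2429/2460 above still delegates batch skipping to accelerate's `skip_first_batches`. A minimal usage sketch follows; the dataset and sizes are made-up example values.

```python
# Minimal usage sketch of accelerate's skip_first_batches, which the
# sized-dataloader resume path relies on. Dataset and sizes are made up.
from accelerate import skip_first_batches
from torch.utils.data import DataLoader

loader = DataLoader(list(range(10)), batch_size=2)  # 5 batches per epoch
resumed = skip_first_batches(loader, num_batches=3)  # resume at the 4th batch

print([batch.tolist() for batch in resumed])  # [[6, 7], [8, 9]]
```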