@@ -505,57 +505,10 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     """
     args = get_args()
     load_dir = getattr(args, load_arg)
-
-    # TODO: remove this redundant code. the tracker is already handled in _load_base_checkpoint
-    # TODO: retire the finetune_from arguments
-    # Determine from which directory we'll try to load
-    # ======
-    if iteration is None:
-        # Read the tracker file and set the iteration.
-        tracker_filename = get_checkpoint_tracker_filename(load_dir)
-
-        # If we can directly load from load_dir, we resume an experiment
-        if os.path.isfile(tracker_filename) and load_arg != 'finetune_from':
-            args.finetune = False
-            print_rank_0(f"Resuming from {load_dir}")
-        # Finetuning from a pretrained model
-        elif os.path.isfile(tracker_filename) and load_arg == 'finetune_from':
-            assert args.finetune
-            print_rank_0(f"Finetuning from {load_dir}")
-        else:
-            assert not os.path.isfile(tracker_filename)
-            # No tracker file and we are in finetuning, try to load from the `finetune_from` dir
-            if args.finetune:
-                print_rank_0('WARNING: could not find the metadata file {} '.format(
-                    tracker_filename))
-                print_rank_0('    will try to load from `--finetune-from` instead')
-                load_dir = getattr(args, 'finetune_from')
-                tracker_filename = get_checkpoint_tracker_filename(load_dir)
-            # If no tracker file, return iteration zero.
-            if not os.path.isfile(tracker_filename):
-                print_rank_0('WARNING: could not find the metadata file {} '.format(
-                    tracker_filename))
-                print_rank_0('    will not load any checkpoints and will start from '
-                             'random')
-                return 0
-
-        assert os.path.isfile(tracker_filename)
-
-        # read the tracker file and either set the iteration or
-        # mark it as a release checkpoint.
-        iteration, release = read_metadata(tracker_filename)
-    else:
-        # Iteration given as argument: do nothing
-        release = False
-    # =======
 
     model = unwrap_model(model)
 
-    state_dict, release = \
-        _load_base_checkpoint(load_dir,
-                              rank0=False,
-                              iteration=iteration,
-                              release=release)
+    state_dict, release = _load_base_checkpoint(load_dir, rank0=False, iteration=iteration)
 
     # Checkpoint not loaded.
     if state_dict is None:
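
Note: the block removed above duplicated tracker-file handling that, per the deleted TODO, already lives inside _load_base_checkpoint. A minimal sketch of that consolidated resolution step, assuming a hypothetical helper name _resolve_iteration (not part of this commit) and reusing only helpers already referenced in this file:

import os

def _resolve_iteration(load_dir, iteration=None):
    # Hypothetical illustration of the resolution logic _load_base_checkpoint is
    # assumed to perform before loading; the real function also returns the state_dict.
    if iteration is not None:
        # Iteration supplied by the caller: nothing to resolve, not a release checkpoint.
        return iteration, False
    tracker_filename = get_checkpoint_tracker_filename(load_dir)
    if not os.path.isfile(tracker_filename):
        # No tracker file: the caller will start from random initialization.
        print_rank_0(f'WARNING: could not find the metadata file {tracker_filename}')
        return None, False
    # Read the tracker file and either set the iteration or mark it as a release checkpoint.
    return read_metadata(tracker_filename)
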
@@ -593,12 +546,11 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     if 'args' in state_dict and not args.finetune:
         checkpoint_args = state_dict['args']
         check_checkpoint_args(checkpoint_args)
-        if not args.finetune:
-            args.consumed_train_samples = getattr(checkpoint_args,
-                                                  'consumed_train_samples', 0)
-            update_num_microbatches(consumed_samples=args.consumed_train_samples)
-            args.consumed_valid_samples = getattr(checkpoint_args,
-                                                  'consumed_valid_samples', 0)
+        args.consumed_train_samples = getattr(checkpoint_args,
+                                              'consumed_train_samples', 0)
+        update_num_microbatches(consumed_samples=args.consumed_train_samples)
+        args.consumed_valid_samples = getattr(checkpoint_args,
+                                              'consumed_valid_samples', 0)
     else:
         print_rank_0('could not find arguments in the checkpoint ...')
 
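
For clarity, the second hunk only drops an inner `if not args.finetune:` guard that the enclosing condition already implies; reconstructed from the context and added lines above (shown without the function's indentation), the block after this commit reads:

if 'args' in state_dict and not args.finetune:
    checkpoint_args = state_dict['args']
    check_checkpoint_args(checkpoint_args)
    # Restore sample counters so a resumed run keeps its data and microbatch schedule.
    args.consumed_train_samples = getattr(checkpoint_args,
                                          'consumed_train_samples', 0)
    update_num_microbatches(consumed_samples=args.consumed_train_samples)
    args.consumed_valid_samples = getattr(checkpoint_args,
                                          'consumed_valid_samples', 0)
else:
    print_rank_0('could not find arguments in the checkpoint ...')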