Commit 3dbd929

remove --finetune-from argument to make checkpoint loading logic simpler
1 parent 37353b1 commit 3dbd929

2 files changed: +6, -57 lines

megatron/arguments.py

Lines changed: 0 additions & 3 deletions
@@ -959,9 +959,6 @@ def _add_checkpointing_args(parser):
     group.add_argument('--use-checkpoint-args', action='store_true',
                        help='Override any command line arguments with arguments '
                        'from the checkpoint')
-    group.add_argument('--finetune-from', type=str, default=None,
-                       help='Directory containing a model checkpoint for finetuning.'
-                       'Will be loaded if the `--load` directory contains no checkpoint')
     group.add_argument('--exit-on-missing-checkpoint', action='store_true',
                        help="If '--load' is set, but checkpoint is not found "
                             "(e.g., path typo), then exit instead of random "

megatron/checkpointing.py

Lines changed: 6 additions & 54 deletions
@@ -505,57 +505,10 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     """
     args = get_args()
     load_dir = getattr(args, load_arg)
-
-    # TODO: remove this redundant code. the tracker is already handled in _load_base_checkpoint
-    # TODO: retire the finetune_from arguments
-    # Determine from which directory we'll try to load
-    # ======
-    if iteration is None:
-        # Read the tracker file and set the iteration.
-        tracker_filename = get_checkpoint_tracker_filename(load_dir)
-
-        # If we can directly load from load_dir, we resume an experiment
-        if os.path.isfile(tracker_filename) and load_arg != 'finetune_from':
-            args.finetune = False
-            print_rank_0(f"Resuming from {load_dir}")
-        # Finetuning from a pretrained model
-        elif os.path.isfile(tracker_filename) and load_arg == 'finetune_from':
-            assert args.finetune
-            print_rank_0(f"Finetuning from {load_dir}")
-        else:
-            assert not os.path.isfile(tracker_filename)
-            # No tracker file and we are in finetuning, try to load from the `finetune_from` dir
-            if args.finetune:
-                print_rank_0('WARNING: could not find the metadata file {} '.format(
-                    tracker_filename))
-                print_rank_0('    will try to load from `--finetune-from` instead')
-                load_dir = getattr(args, 'finetune_from')
-                tracker_filename = get_checkpoint_tracker_filename(load_dir)
-            # If no tracker file, return iteration zero.
-            if not os.path.isfile(tracker_filename):
-                print_rank_0('WARNING: could not find the metadata file {} '.format(
-                    tracker_filename))
-                print_rank_0('    will not load any checkpoints and will start from '
-                             'random')
-                return 0
-
-        assert os.path.isfile(tracker_filename)
-
-        # Read the tracker file and either set the iteration or
-        # mark it as a release checkpoint.
-        iteration, release = read_metadata(tracker_filename)
-    else:
-        # Iteration given as argument: do nothing
-        release = False
-    # =======
 
     model = unwrap_model(model)
 
-    state_dict, release = \
-        _load_base_checkpoint(load_dir,
-                              rank0=False,
-                              iteration=iteration,
-                              release=release)
+    state_dict, release = _load_base_checkpoint(load_dir, rank0=False, iteration=iteration)
 
     # Checkpoint not loaded.
     if state_dict is None:
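The block removed above duplicated tracker handling that _load_base_checkpoint already performs (per the first TODO). A minimal sketch of that pre-existing logic, reconstructed from the helpers visible in the diff (get_checkpoint_tracker_filename, read_metadata, print_rank_0, all from megatron/checkpointing.py); get_checkpoint_name and the torch.load call are assumed details, not the verbatim implementation:

    import os
    import torch

    def _load_base_checkpoint(load_dir, rank0=False, iteration=None):
        """Sketch only -- not the actual megatron/checkpointing.py code."""
        release = False
        if iteration is None:
            # Resolve the iteration from the tracker file when none is given,
            # which is exactly what the removed block re-implemented.
            tracker_filename = get_checkpoint_tracker_filename(load_dir)
            if not os.path.isfile(tracker_filename):
                # No tracker file: warn and let the caller start from scratch.
                print_rank_0('WARNING: could not find the metadata file {} '.format(
                    tracker_filename))
                return None, False
            # The tracker file names either an iteration or a release checkpoint.
            iteration, release = read_metadata(tracker_filename)
        # Assumed detail: map (dir, iteration, release) to a file and load it.
        checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        return state_dict, release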
@@ -593,12 +546,11 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     if 'args' in state_dict and not args.finetune:
         checkpoint_args = state_dict['args']
         check_checkpoint_args(checkpoint_args)
-        if not args.finetune:
-            args.consumed_train_samples = getattr(checkpoint_args,
-                                                  'consumed_train_samples', 0)
-            update_num_microbatches(consumed_samples=args.consumed_train_samples)
-            args.consumed_valid_samples = getattr(checkpoint_args,
-                                                  'consumed_valid_samples', 0)
+        args.consumed_train_samples = getattr(checkpoint_args,
+                                              'consumed_train_samples', 0)
+        update_num_microbatches(consumed_samples=args.consumed_train_samples)
+        args.consumed_valid_samples = getattr(checkpoint_args,
+                                              'consumed_valid_samples', 0)
     else:
         print_rank_0('could not find arguments in the checkpoint ...')
 
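Note on this hunk: the enclosing condition (if 'args' in state_dict and not args.finetune:) already checks not args.finetune, so the inner check was redundant; removing it and de-indenting the body preserves behavior.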
