Skip to content

Commit

Permalink
Revert "Add an argument to enable the injection of missing state duri…
Browse files Browse the repository at this point in the history
…ng the conversion of universal checkpoints (#5608)"

This reverts commit f0e3f01.
  • Loading branch information
loadams committed Jul 2, 2024
1 parent 3d34727 commit 3087dd0
Showing 1 changed file with 1 addition and 18 deletions.
19 changes: 1 addition & 18 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
SUB_PARAM_SHAPE,
VOCAB_TENSOR,
UNIVERSAL_CHECKPOINT_INFO,
UNIVERSAL_CHECKPOINT_VERSION_KEY,
UNIVERSAL_CHECKPOINT_VERSION_VALUE,
VOCABULARY_PARAMETER_PATTERNS,
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
TP_REPLICATED_PARAMETER_PATTERNS,
Expand Down Expand Up @@ -69,9 +67,6 @@ def parse_arguments():
dest='strict',
action='store_false',
help='Do not perform validity checks on converted checkpoint.')
parser.add_argument('--inject-missing-state',
action='store_true',
help='Inject missing checkpoint state into the checkpoint if it is absent.')
args = parser.parse_args()
print(f'args = {args}')
return args
Expand Down Expand Up @@ -452,15 +447,6 @@ def _get_zero_stage(optim_files):
return zero_stage


def _inject_missing_state(ds_checkpoint):
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
if UNIVERSAL_CHECKPOINT_INFO not in sd:
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE


def _check_for_required_state(ds_checkpoint):
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'
Expand All @@ -476,10 +462,7 @@ def main(args):

if zero_stage <= 2:
ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
if args.inject_missing_state:
_inject_missing_state(ds_checkpoint)
else:
_check_for_required_state(ds_checkpoint)
_check_for_required_state(ds_checkpoint)

iteration = ds_checkpoint.get_iteration()
#_create_latest_file(args.output_folder, iteration)
Expand Down

0 comments on commit 3087dd0

Please sign in to comment.