diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index c28f4f74b275..702ca71e779d 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -34,8 +34,6 @@ SUB_PARAM_SHAPE, VOCAB_TENSOR, UNIVERSAL_CHECKPOINT_INFO, - UNIVERSAL_CHECKPOINT_VERSION_KEY, - UNIVERSAL_CHECKPOINT_VERSION_VALUE, VOCABULARY_PARAMETER_PATTERNS, PIPELINE_REPLICATED_PARAMETER_PATTERNS, TP_REPLICATED_PARAMETER_PATTERNS, @@ -69,9 +67,6 @@ def parse_arguments(): dest='strict', action='store_false', help='Do not perform validity checks on converted checkpoint.') - parser.add_argument('--inject-missing-state', - action='store_true', - help='Inject missing checkpoint state into the checkpoint if it is absent.') args = parser.parse_args() print(f'args = {args}') return args @@ -452,15 +447,6 @@ def _get_zero_stage(optim_files): return zero_stage -def _inject_missing_state(ds_checkpoint): - if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state: - sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu')) - if UNIVERSAL_CHECKPOINT_INFO not in sd: - ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {} - ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][ - UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE - - def _check_for_required_state(ds_checkpoint): universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.' @@ -476,10 +462,7 @@ def main(args): if zero_stage <= 2: ds_checkpoint = DeepSpeedCheckpoint(args.input_folder) - if args.inject_missing_state: - _inject_missing_state(ds_checkpoint) - else: - _check_for_required_state(ds_checkpoint) + _check_for_required_state(ds_checkpoint) iteration = ds_checkpoint.get_iteration() #_create_latest_file(args.output_folder, iteration)