Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an argument to enable the injection of missing state during the conversion of universal checkpoints #5608

Merged
merged 10 commits into from
Jun 27, 2024
19 changes: 18 additions & 1 deletion deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
SUB_PARAM_SHAPE,
VOCAB_TENSOR,
UNIVERSAL_CHECKPOINT_INFO,
UNIVERSAL_CHECKPOINT_VERSION_KEY,
UNIVERSAL_CHECKPOINT_VERSION_VALUE,
VOCABULARY_PARAMETER_PATTERNS,
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
TP_REPLICATED_PARAMETER_PATTERNS,
Expand Down Expand Up @@ -67,6 +69,9 @@ def parse_arguments():
dest='strict',
action='store_false',
help='Do not perform validity checks on converted checkpoint.')
parser.add_argument('--inject-missing-state',
action='store_true',
help='Inject missing checkpoint state into the checkpoint if it is absent.')
args = parser.parse_args()
print(f'args = {args}')
return args
Expand Down Expand Up @@ -447,6 +452,15 @@ def _get_zero_stage(optim_files):
return zero_stage


def _inject_missing_state(ds_checkpoint):
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
if UNIVERSAL_CHECKPOINT_INFO not in sd:
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE


def _check_for_required_state(ds_checkpoint):
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'
Expand All @@ -462,7 +476,10 @@ def main(args):

if zero_stage <= 2:
ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
_check_for_required_state(ds_checkpoint)
if args.inject_missing_state:
_inject_missing_state(ds_checkpoint)
else:
_check_for_required_state(ds_checkpoint)

iteration = ds_checkpoint.get_iteration()
#_create_latest_file(args.output_folder, iteration)
Expand Down
Loading