Skip to content

Commit

Permalink
Add an argument to enable the injection of missing state during the c…
Browse files Browse the repository at this point in the history
…onversion of universal checkpoints (#5608)

This PR solves the
[Issue-5430](#5430).

The PR enables the universal checkpoint feature for other platforms like
HuggingFace Trainer without requiring changes to the HuggingFace code.
It does this by adding an argument that allows the injection of minimal
necessary information into the state before this
[assertion](https://github.com/microsoft/DeepSpeed/blob/ebf82e8f3ad6d51d49d115e54a11ae4597ff36fb/deepspeed/checkpoint/ds_to_universal.py#L358).

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Abhishek Kulkarni <abkulkarni@microsoft.com>
  • Loading branch information
5 people authored Jun 27, 2024
1 parent b421e8c commit f0e3f01
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
SUB_PARAM_SHAPE,
VOCAB_TENSOR,
UNIVERSAL_CHECKPOINT_INFO,
UNIVERSAL_CHECKPOINT_VERSION_KEY,
UNIVERSAL_CHECKPOINT_VERSION_VALUE,
VOCABULARY_PARAMETER_PATTERNS,
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
TP_REPLICATED_PARAMETER_PATTERNS,
Expand Down Expand Up @@ -67,6 +69,9 @@ def parse_arguments():
dest='strict',
action='store_false',
help='Do not perform validity checks on converted checkpoint.')
parser.add_argument('--inject-missing-state',
action='store_true',
help='Inject missing checkpoint state into the checkpoint if it is absent.')
args = parser.parse_args()
print(f'args = {args}')
return args
Expand Down Expand Up @@ -447,6 +452,15 @@ def _get_zero_stage(optim_files):
return zero_stage


def _inject_missing_state(ds_checkpoint):
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
if UNIVERSAL_CHECKPOINT_INFO not in sd:
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE


def _check_for_required_state(ds_checkpoint):
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'
Expand All @@ -462,7 +476,10 @@ def main(args):

if zero_stage <= 2:
ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
_check_for_required_state(ds_checkpoint)
if args.inject_missing_state:
_inject_missing_state(ds_checkpoint)
else:
_check_for_required_state(ds_checkpoint)

iteration = ds_checkpoint.get_iteration()
#_create_latest_file(args.output_folder, iteration)
Expand Down

0 comments on commit f0e3f01

Please sign in to comment.