create minimal universal checkpoint info for client state #5526

Closed
wants to merge 2 commits into from
15 changes: 14 additions & 1 deletion deepspeed/checkpoint/utils.py
@@ -5,7 +5,7 @@

import os
import torch
from .constants import (MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX)
from .constants import (MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX, UNIVERSAL_CHECKPOINT_INFO, UNIVERSAL_CHECKPOINT_VERSION_KEY, UNIVERSAL_CHECKPOINT_VERSION_VALUE)

FYI @xylian86 - can you run the precommit formatter on this branch so it will pass our Formatting check?

pre-commit run --all-files



def get_model_ckpt_name_for_rank(base_folder, mp_rank_str):
@@ -60,3 +60,16 @@ def clone_tensors_for_torch_save(item, device=torch.device('cpu')):
        return type(item)({k: clone_tensors_for_torch_save(v, device) for k, v in item.items()})
    else:
        return item


def inject_universal_info(state_dict):
    """
    Ensure the universal checkpoint information is present in the state dictionary.
    Creates the info entry if it doesn't exist, then sets its version key.

    Args:
        state_dict (dict): The dictionary to inject universal checkpoint information into.
    """
    if UNIVERSAL_CHECKPOINT_INFO not in state_dict:
        state_dict[UNIVERSAL_CHECKPOINT_INFO] = {}
    state_dict[UNIVERSAL_CHECKPOINT_INFO][UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE
2 changes: 2 additions & 0 deletions deepspeed/runtime/engine.py
@@ -59,6 +59,7 @@
WEIGHT_QUANTIZE_VERBOSE, \
WEIGHT_QUANTIZE_KERNEL
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS
from deepspeed.checkpoint.utils import inject_universal_info
from deepspeed.runtime.sparse_tensor import SparseTensor

from deepspeed.runtime import lr_schedules
@@ -3319,6 +3320,7 @@ def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parame
                     ds_config=self.config,
                     ds_version=version)
        state.update(client_state)
        inject_universal_info(state)

Shouldn't we show a warning when we don't have the necessary info?
It will silently produce an incorrect checkpoint if the checkpoint is loaded for TP or PP.
We can say that the converted checkpoint is only for pure DP scaling.
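
For reference, the warning suggested here could look roughly like the sketch below. It is not part of the PR: the function name `inject_universal_info_with_warning`, the logger, and the literal constant values are assumptions made so the snippet runs on its own (the real constants live in `deepspeed/checkpoint/constants.py`).

```python
import logging

logger = logging.getLogger(__name__)

# Placeholder values (assumptions) standing in for deepspeed.checkpoint.constants.
UNIVERSAL_CHECKPOINT_INFO = "universal_checkpoint_info"
UNIVERSAL_CHECKPOINT_VERSION_KEY = "universal_checkpoint_version"
UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2


def inject_universal_info_with_warning(state_dict):
    """Hypothetical variant of inject_universal_info that warns before injecting.

    Returns True if the info entry was missing and a minimal one was created.
    """
    missing = UNIVERSAL_CHECKPOINT_INFO not in state_dict
    if missing:
        # Make the limitation visible instead of silently writing a minimal entry.
        logger.warning("Client state has no universal checkpoint info; injecting a minimal entry. "
                       "The converted checkpoint is only valid for pure DP scaling, not TP or PP.")
        state_dict[UNIVERSAL_CHECKPOINT_INFO] = {}
    state_dict[UNIVERSAL_CHECKPOINT_INFO][UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE
    return missing
```

Returning a flag lets the caller decide whether to surface the limitation elsewhere, for example in the save log.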

After discussing this with Tunji, we are considering another approach, which he will share here. I'll leave this comment up, but please disregard it.

@xylian86, the new approach is to do the injection in the conversion script rather than during saving. The injection should be done into ds_checkpoint before this assertion. Furthermore, the injection should be enabled by a command-line argument (disabled by default) so that users are fully aware of what is going on. The command-line arg could be called --inject-missing-state.
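
A minimal sketch of how this opt-in flow might look, under stated assumptions: only the flag name `--inject-missing-state` comes from the comment above. The helper names `parse_arguments` and `prepare_checkpoint_state`, the plain dict standing in for `ds_checkpoint`, and the literal constant values are hypothetical and do not reflect the actual conversion script.

```python
import argparse

# Placeholder values (assumptions) standing in for deepspeed.checkpoint.constants.
UNIVERSAL_CHECKPOINT_INFO = "universal_checkpoint_info"
UNIVERSAL_CHECKPOINT_VERSION_KEY = "universal_checkpoint_version"
UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2


def parse_arguments(argv=None):
    """Parse only the flag discussed in this thread; a real script has more options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--inject-missing-state",
                        action="store_true",
                        help="Inject minimal universal checkpoint info if it is missing (off by default).")
    return parser.parse_args(argv)


def prepare_checkpoint_state(state, inject_missing_state=False):
    """Optionally inject the minimal info, then run a stand-in for the existing assertion."""
    if inject_missing_state and UNIVERSAL_CHECKPOINT_INFO not in state:
        state[UNIVERSAL_CHECKPOINT_INFO] = {UNIVERSAL_CHECKPOINT_VERSION_KEY: UNIVERSAL_CHECKPOINT_VERSION_VALUE}
    # Stand-in for the script's existing check: without the flag, missing info still fails loudly.
    assert UNIVERSAL_CHECKPOINT_INFO in state, "checkpoint is missing universal checkpoint info"
    return state


if __name__ == "__main__":
    args = parse_arguments()
    checkpoint_state = {}  # hypothetical stand-in for the loaded ds_checkpoint state
    prepare_checkpoint_state(checkpoint_state, args.inject_missing_state)
```

Because the flag defaults to off, existing behaviour is unchanged: a checkpoint without the info still trips the check unless the user explicitly asks for injection.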


        if self.save_non_zero_checkpoint:
            log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])