zero3 checkpoint frozen params #3205

Merged 26 commits from olruwase/issue_3090 on Apr 20, 2023
Changes from all 26 commits:
81afa0a
zero3 checkpoint frozen params
tjruwase Apr 12, 2023
3de90c3
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 12, 2023
e672211
Remove debug prints
tjruwase Apr 12, 2023
4c7de69
Merge branch 'olruwase/issue_3090' of github.com:microsoft/DeepSpeed …
tjruwase Apr 12, 2023
29fdbea
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 13, 2023
378e1ee
Move to cpu
tjruwase Apr 13, 2023
7fbe4bf
Merge branch 'olruwase/issue_3090' of github.com:microsoft/DeepSpeed …
tjruwase Apr 13, 2023
8d2f72b
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 14, 2023
438f6ff
WIP
tjruwase Apr 17, 2023
96d55a3
Merge branch 'olruwase/issue_3090' of github.com:microsoft/DeepSpeed …
tjruwase Apr 17, 2023
feac428
WIP
tjruwase Apr 18, 2023
fb1a4d5
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 18, 2023
8dc6e8a
WIP
tjruwase Apr 18, 2023
3ef70ec
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 18, 2023
5485b10
Cleanup
tjruwase Apr 18, 2023
835f0f2
Merge branch 'olruwase/issue_3090' of github.com:microsoft/DeepSpeed …
tjruwase Apr 18, 2023
71fbd61
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 18, 2023
be7b54b
Cleanup
tjruwase Apr 18, 2023
7308cb2
Merge branch 'olruwase/issue_3090' of github.com:microsoft/DeepSpeed …
tjruwase Apr 18, 2023
c616875
Extend unit test for frozen params
tjruwase Apr 18, 2023
224c370
API fix
tjruwase Apr 18, 2023
9ba69b0
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 19, 2023
d9711b8
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 19, 2023
73227c8
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 19, 2023
4400d72
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 20, 2023
dfa3eba
Merge branch 'master' into olruwase/issue_3090
tjruwase Apr 20, 2023
2 changes: 2 additions & 0 deletions deepspeed/checkpoint/constants.py
@@ -27,6 +27,8 @@
 PARAM = 'param'
 PARAM_SHAPES = 'param_shapes'
 BUFFER_NAMES = 'buffer_names'
+FROZEN_PARAM_SHAPES = 'frozen_param_shapes'
+FROZEN_PARAM_FRAGMENTS = 'frozen_param_fragments'

 #########################################
 # Checkpoint naming constants
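As a point of reference, here is a small hypothetical sketch (not part of the PR) of how these two new keys are intended to appear in a checkpoint's client state: shapes and detached CPU fragments of frozen (requires_grad=False) parameters, keyed by their state_dict names. The name 'encoder.weight' and the tensor below are illustrative only.

import torch
from collections import OrderedDict

FROZEN_PARAM_SHAPES = 'frozen_param_shapes'
FROZEN_PARAM_FRAGMENTS = 'frozen_param_fragments'

# Shape and a detached CPU copy of a frozen parameter (or of its partitioned
# ds_tensor fragment under ZeRO-3), stored under the new keys.
frozen_weight = torch.zeros(4, 4)
client_state = OrderedDict()
client_state[FROZEN_PARAM_SHAPES] = {'encoder.weight': frozen_weight.shape}
client_state[FROZEN_PARAM_FRAGMENTS] = {'encoder.weight': frozen_weight.detach().cpu()}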
3 changes: 2 additions & 1 deletion deepspeed/runtime/bf16_optimizer.py
@@ -94,7 +94,8 @@ def _setup_for_real_optimizer(self):
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])

             # grab the original list
-            self.bf16_groups.append(param_group['params'])
+            trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
+            self.bf16_groups.append(trainable_parameters)

             # create flat bf16 params
             self.bf16_groups_flat.append(
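The bf16 optimizer change above keeps only trainable parameters when building the flattened groups; anything with requires_grad=False is skipped. A minimal stand-alone sketch of that filtering, using a hypothetical toy module rather than DeepSpeed code:

import torch

# Freeze one parameter, then filter the param group the same way the hunk
# above does before the group is flattened.
layer = torch.nn.Linear(4, 4)
layer.bias.requires_grad = False  # frozen: excluded from the flat bf16 group

param_group = {'params': list(layer.parameters())}
trainable_parameters = [param for param in param_group['params'] if param.requires_grad]

assert all(param.requires_grad for param in trainable_parameters)
assert len(trainable_parameters) == 1  # only the weight remains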
49 changes: 44 additions & 5 deletions deepspeed/runtime/engine.py
@@ -55,7 +55,7 @@
     WEIGHT_QUANTIZE_ROUNDING, \
     WEIGHT_QUANTIZE_VERBOSE, \
     WEIGHT_QUANTIZE_KERNEL
-from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT
+from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS
 from deepspeed.runtime.sparse_tensor import SparseTensor

 from deepspeed.runtime import lr_schedules
@@ -2414,14 +2414,28 @@ def load_moe_state_dict(checkpoint_path,
                         state_dict.update(expert_state_dict)
                     moe_layer_id += 1

-    def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None):
+    def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None):
+        module_state_dict = checkpoint['module']
         if custom_load_fn:
-            custom_load_fn(src=state_dict, dst=self.module)
+            custom_load_fn(src=module_state_dict, dst=self.module)
         else:
             self.module.load_state_dict(
-                state_dict,  # TODO
+                module_state_dict,  # TODO
                 strict=strict)

+        if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None:
+            saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS]
+            for param in self.module.parameters():
+                if param.requires_grad:
+                    continue
+                if param not in self.param_names:
+                    raise ValueError(f"failed to find frozen {param} in named params")
+                name = self.param_names[param]
+                if hasattr(param, 'ds_id'):
+                    param.ds_tensor.data.copy_(saved_frozen_params[name].data)
+                else:
+                    param.data.copy_(saved_frozen_params[name].data)
+
     def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode):
         return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}'

@@ -2601,7 +2615,7 @@ def _load_checkpoint(self,
                                                 num_experts=self.num_experts,
                                                 checkpoint_engine=self.checkpoint_engine)
         if not self.load_universal_checkpoint():
-            self.load_module_state_dict(state_dict=checkpoint['module'],
+            self.load_module_state_dict(checkpoint=checkpoint,
                                         strict=load_module_strict,
                                         custom_load_fn=custom_load_fn)

@@ -3011,6 +3025,8 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):

         zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()

+        save_frozen_param = self.zero_optimization_partition_gradients()
+
         # A hack to save the checkpointing directory. Pipeline parallelism overrides
         # module_state_dict() and uses this path to save the model. module_state_dict()
         # then instead just returns None. The module_state_dict() implementation in
@@ -3023,6 +3039,10 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
                      buffer_names=self._get_buffer_names(),
                      optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None,
                      param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None,
+                     frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func)
+                     if save_frozen_param else None,
+                     frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func)
+                     if save_frozen_param else None,
                      lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
                      data_sampler=self.training_dataloader.data_sampler.state_dict() if
                      (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None,
@@ -3062,6 +3082,25 @@ def get_layer_named_buffers(module, prefix=""):

         return buffer_names

+    def _get_param_shape_func(self, param):
+        return param.ds_shape if hasattr(param, 'ds_id') else param.shape
+
+    def _get_param_fragment_func(self, param):
+        return param.ds_tensor.detach().cpu() if hasattr(param, 'ds_id') else param.detach().cpu()
+
+    def _get_zero_frozen_param_attributes(self, attr_func):
+        frozen_param_fragments = OrderedDict()
+
+        for param in self.module.parameters():
+            if param.requires_grad:
+                continue
+            if param not in self.param_names:
+                raise ValueError(f"failed to find frozen {param} in named params")
+            name = self.param_names[param]
+            frozen_param_fragments[name] = attr_func(param)
+
+        return frozen_param_fragments
+
     def _get_zero_param_shapes(self):
         """Returns a dict of name to shape mapping, only for the flattened fp32 weights saved by the
         optimizer. the names are exactly as in state_dict. The order is absolutely important, since
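Taken together, the engine changes let a ZeRO checkpoint round-trip frozen weights: _save_checkpoint stores their shapes and CPU fragments, and load_module_state_dict copies the fragments back (into ds_tensor partitions when the parameter is ZeRO-3 partitioned). A rough usage sketch follows; it assumes a working distributed DeepSpeed environment (e.g. launched with the deepspeed launcher), and the model, config values, and paths are illustrative, not taken from the PR or its tests.

import torch
import deepspeed

# Toy model with a frozen first layer; only trainable params go to the optimizer.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
for param in model[0].parameters():
    param.requires_grad = False

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config=ds_config,
)

# Frozen shapes/fragments are written alongside the usual module state ...
engine.save_checkpoint("./ckpt", tag="frozen_params_test")
# ... and copied back into the frozen parameters on load.
engine.load_checkpoint("./ckpt", tag="frozen_params_test")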
3 changes: 2 additions & 1 deletion deepspeed/runtime/pipe/engine.py
@@ -1248,7 +1248,7 @@ def module_state_dict(self):
         self.module.save_state_dict(self._curr_ckpt_path, checkpoint_engine=self.checkpoint_engine)
         return None

-    def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None):
+    def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None):
         """Override hack to instead use a directory path.

         This is important because pipeline models checkpoint by layer instead of rank.
@@ -1260,6 +1260,7 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None):
             strict (bool, optional): Strict state loading. Defaults to True.
         """
         assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism"
+        state_dict = checkpoint['module']
         if (state_dict is not None) and (not isinstance(state_dict, str)):
             super().load_module_state_dict(state_dict, strict)
             return