Option to exclude frozen weights for checkpoint save (deepspeedai#3953)
* Option to exclude frozen weights for checkpoint save

* Extend unit test

* Support PP training
tjruwase authored Jul 20, 2023
1 parent ceccfa3 commit 0a0819b
Showing 4 changed files with 98 additions and 20 deletions.
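For context, a minimal usage sketch of the new option (the toy model, config, and paths below are illustrative and not part of this PR; only the exclude_frozen_parameters argument comes from the change itself):

import torch.nn as nn
import deepspeed

# Hypothetical toy model: freeze the first layer to emulate a frozen backbone.
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
for p in model[0].parameters():
    p.requires_grad = False

ds_config = {"train_micro_batch_size_per_gpu": 1, "optimizer": {"type": "Adam"}}
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)

engine.save_checkpoint("./ckpt_all")                                        # default: frozen weights included
engine.save_checkpoint("./ckpt_trainable", exclude_frozen_parameters=True)  # new: frozen weights excluded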
33 changes: 24 additions & 9 deletions deepspeed/runtime/engine.py
@@ -2439,8 +2439,15 @@ def all_gather_scalar(self, value, dp_group):
dist.all_gather(tensor_list, value, group=dp_group)
return tensor_list

def module_state_dict(self, destination=None, prefix="", keep_vars=False):
def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False):
sd = self.module.state_dict(destination, prefix, keep_vars)

# Remove frozen parameter weights from state_dict if specified
if exclude_frozen_parameters:
for n, p in self.module.named_parameters():
if not p.requires_grad:
del sd[n]

if self.random_ltd_enabled():
sd = remove_random_ltd_state_dict(sd)
return sd
@@ -2896,7 +2903,7 @@ def _checkpoint_tag_validation(self, tag):
elif not valid:
logger.warning(msg)

def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False):
"""Save training checkpoint
Arguments:
@@ -2905,6 +2912,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
used if not provided. Tag name must be the same across all ranks.
client_state: Optional. State dictionary used for saving required training states in the client code.
save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state.
Important: all processes must call this method and not just the process with rank 0. It is
because each process needs to save its master weights and scheduler+optimizer states. This
method will hang waiting to synchronize with other processes if it's called just for the
@@ -2937,15 +2945,21 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
if self.has_moe_layers:
self.save_non_zero_checkpoint = False
self._create_checkpoint_file(save_dir, tag, False)
self._save_moe_checkpoint(save_dir, tag, client_state=client_state)
self._save_moe_checkpoint(save_dir,
tag,
client_state=client_state,
exclude_frozen_parameters=exclude_frozen_parameters)

# We distribute the task of saving layer checkpoint files among
# data parallel instances, so all procs should call _save_checkpoint.
# All procs then call module_state_dict(), but only procs of data
# parallel rank 0 save the general model params.
if not self.has_moe_layers:
self._create_checkpoint_file(save_dir, tag, False)
self._save_checkpoint(save_dir, tag, client_state=client_state)
self._save_checkpoint(save_dir,
tag,
client_state=client_state,
exclude_frozen_parameters=exclude_frozen_parameters)

if self.save_zero_checkpoint:
self._create_zero_checkpoint_files(save_dir, tag)
@@ -2974,7 +2988,7 @@ def _get_non_moe_state_dict(self, full_state_dict):

return full_state_dict

def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False):
save_path = self._get_ckpt_name(save_dir, tag)
# A hack to save the checkpointing directory. Pipeline parallelism overrides
# module_state_dict() and uses this path to save the model. module_state_dict()
@@ -3049,7 +3063,8 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
self.checkpoint_engine.save(optimizer_state, file_path)

# get non-moe parameters
model_state_dict = self._get_non_moe_state_dict(self.module_state_dict())
model_state_dict = self._get_non_moe_state_dict(
self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters))

if expp_rank == 0:
# TODO: update num experts info,.. in checkpoint
@@ -3106,20 +3121,20 @@ def _create_zero_checkpoint_files(self, save_dir, tag):

return success

def _save_checkpoint(self, save_dir, tag, client_state={}):
def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False):

save_path = self._get_ckpt_name(save_dir, tag)

zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()

save_frozen_param = self.zero_optimization_partition_gradients()
save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters

# A hack to save the checkpointing directory. Pipeline parallelism overrides
# module_state_dict() and uses this path to save the model. module_state_dict()
# then instead just returns None. The module_state_dict() implementation in
# PipelineEngine expects the save path to be set in self._curr_ckpt_path.
self._curr_ckpt_path = os.path.join(save_dir, tag)
module = self.module_state_dict()
module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)
self._curr_ckpt_path = None

state = dict(module=module,
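The engine change above amounts to filtering requires_grad=False entries out of the module state dict before it is saved. A standalone sketch of that filtering, using a hypothetical toy model rather than a DeepSpeed engine:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in model[0].parameters():
    p.requires_grad = False  # freeze layer 0

sd = model.state_dict()
for n, p in model.named_parameters():
    if not p.requires_grad:
        del sd[n]  # same filtering as module_state_dict(exclude_frozen_parameters=True)

print(sorted(sd.keys()))  # ['1.bias', '1.weight'] -- only trainable parameters remain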
6 changes: 4 additions & 2 deletions deepspeed/runtime/pipe/engine.py
@@ -1237,7 +1237,7 @@ def mem_status(self, msg, print_rank=-1, reset_max=False):
f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)')

def module_state_dict(self):
def module_state_dict(self, exclude_frozen_parameters=False):
"""Override hack to save a pipe model and return the directory path of the save.
This method should only be called by DeepSpeed's ``save_checkpoint()``. The
@@ -1251,7 +1251,9 @@ def module_state_dict(self):
assert self._curr_ckpt_path is not None, \
"PipelineEngine expects module_state_dict() to be called from save_checkpoint()"

self.module.save_state_dict(self._curr_ckpt_path, checkpoint_engine=self.checkpoint_engine)
self.module.save_state_dict(self._curr_ckpt_path,
checkpoint_engine=self.checkpoint_engine,
exclude_frozen_params=exclude_frozen_parameters)
return None

def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None):
29 changes: 21 additions & 8 deletions deepspeed/runtime/pipe/module.py
@@ -20,6 +20,7 @@
from .topology import PipeDataParallelTopology, PipelineParallelGrid
from deepspeed.runtime.state_dict_factory import SDLoaderFactory
from deepspeed.accelerator import get_accelerator
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save


class PipelineError(Exception):
@@ -259,6 +260,20 @@ def _build(self):
for p in self.parameters():
p.ds_pipe_replicated = False

def _get_frozen_parameter_names(self, layer):
""" Get names of frozen parameters in the layer.
Returns:
A list of frozen parameter names
"""
if isinstance(layer, LayerSpec):
l = layer.build()
return [n for n, p in l.named_parameters() if not p.requires_grad]
elif isinstance(layer, nn.Module):
return [n for n, p in layer.named_parameters() if not p.requires_grad]

return []

def _count_layer_params(self):
"""Count the trainable parameters in individual layers.
@@ -544,7 +559,7 @@ def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx):
ckpt_files.sort()
return ckpt_files

def save_state_dict(self, save_dir, checkpoint_engine):
def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=False):
# Processes having the same model parallel rank on different data parallel instances
# have identical layer weights. We can distribute the task of saving the layer weights
# among the data parallel ranks. For example, if a pipeline stage has 9 layers and
@@ -569,14 +584,12 @@ def save_state_dict(self, save_dir, checkpoint_engine):
model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
if not hasattr(layer, 'state_dict'):
continue
# We pass cloned tensors to torch.save() to avoid checkpoint bloat which occurs because torch.save()
# saves the underlying storage rather than the slice of the storage corresponding to individual tensors.
# This is a problem in DeepSpeed because we often allocate tensors using slices of large flattened buffers.
# Tensor cloning helps to avoid this problem because the storage of cloned tensors are closer to the true size.
# It is expected that the garbage collector will reclaim the cloned tensor storage to avoid memory bloat.
# See https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing

orig_state_dict = layer.state_dict()
final_state_dict = type(orig_state_dict)({k: v.clone() for k, v in orig_state_dict.items()})
if exclude_frozen_params:
for n in self._get_frozen_parameter_names(layer):
del orig_state_dict[n]
final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
checkpoint_engine.save(final_state_dict, model_ckpt_path)

def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
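The clone-before-save step above exists because torch.save() serializes whole storages, so a small view into a large flattened buffer would bloat the checkpoint. A standalone sketch of the idea behind clone_tensors_for_torch_save (buffer sizes and names are illustrative):

import torch
from collections import OrderedDict

flat = torch.zeros(1_000_000)        # stand-in for a large flattened parameter buffer
view = flat.narrow(0, 0, 10)         # a 10-element parameter view sharing flat's storage

bloated = OrderedDict(w=view)          # torch.save(bloated, path) would write the full 1M-element storage
compact = OrderedDict(w=view.clone())  # clone() copies the data into a right-sized storage

# clone_tensors_for_torch_save applies this clone-per-tensor idea to every tensor in a state dict.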
50 changes: 49 additions & 1 deletion tests/unit/checkpoint/test_zero_optimizer.py
@@ -5,7 +5,7 @@

import deepspeed
from deepspeed.ops.op_builder import CPUAdamBuilder
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save, get_model_ckpt_name_for_rank
from deepspeed.accelerator import get_accelerator

from unit.common import DistributedTest, DistributedFixture
@@ -472,6 +472,54 @@ def test_load_module_only(self, tmpdir, zero_stage):

checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True)

@pytest.mark.parametrize('zero_stage', [1, 2])
def test_save_exclude_frozen_weights(self, tmpdir, zero_stage):
world_size = 1
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"optimizer": {
"type": 'Adam'
},
"fp16": {
"enabled": True,
"initial_scale_power": 8
},
"zero_optimization": {
"stage": zero_stage,
}
}
hidden_dim = 10

model = SimpleFrozenModel(hidden_dim, empty_grad=False)

ds_engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)

# Validate backwards-compatibility of including frozen parameters in checkpoint
all_ckpt_folder = os.path.join(tmpdir, 'all_params')
ds_engine.save_checkpoint(all_ckpt_folder)
all_params_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(all_ckpt_folder, 'global_step0'), '00')
loaded_all_param_model = torch.load(all_params_ckpt_file)['module']
all_param_names = set([n for n, p in model.named_parameters()])
assert set(loaded_all_param_model.keys()) == all_param_names

# Validate exclusion of frozen parameters
trainable_ckpt_folder = os.path.join(tmpdir, 'no_frozen_params')
ds_engine.save_checkpoint(trainable_ckpt_folder, exclude_frozen_parameters=True)

trainable_ckpt_file = get_model_ckpt_name_for_rank(os.path.join(trainable_ckpt_folder, 'global_step0'), '00')

# Excluding frozen parameters should reduce checkpoint size
assert os.path.getsize(all_params_ckpt_file) > os.path.getsize(trainable_ckpt_file)

loaded_trainable_param_model = torch.load(trainable_ckpt_file)['module']
frozen_param_names = set([n for n, p in model.named_parameters() if not p.requires_grad])
loaded_trainable_param_names = set(loaded_trainable_param_model.keys())
overlap_names = set.intersection(loaded_trainable_param_names, frozen_param_names)
assert len(overlap_names) == 0

trainable_param_names = set([n for n, p in model.named_parameters() if p.requires_grad])
assert loaded_trainable_param_names == trainable_param_names


class TestSaveTensorClone(DistributedTest):
world_size = 1
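Assuming a working DeepSpeed unit-test environment (GPUs and the test dependencies installed), the new case can be run in isolation with, for example:

pytest tests/unit/checkpoint/test_zero_optimizer.py -k "save_exclude_frozen_weights"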
