Skip to content

Commit

Permalink
save_fp16_model consolidated for zero3 (#893)
Browse files Browse the repository at this point in the history
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
stas00 and tjruwase authored Mar 27, 2021
1 parent 7531c6b commit 39013dd
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 1 deletion.
98 changes: 98 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ def zero_prefetch_bucket_size(self):
def zero_param_persistence_threshold(self):
return self._config.zero_config.param_persistence_threshold

def zero_gather_fp16_weights_on_model_save(self):
return self._config.zero_config.gather_fp16_weights_on_model_save

def fp16_enabled(self):
return self._config.fp16_enabled

Expand Down Expand Up @@ -1714,3 +1717,98 @@ def _save_zero_checkpoint(self, save_path, tag):
torch.save(zero_sd, zero_checkpoint_name)
self._copy_recovery_script(save_path)
logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))

def _zero3_consolidated_fp16_state_dict(self):
"""
Get a full non-partitioned state_dict with fp16 weights on cpu.
This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:
1. consolidates the weights from different partitions on gpu0
2. works on one layer at a time to require as little gpu0 memory as possible, by
moving the already consolidated weights to cpu
3. takes care to keep the shared params shared when gradually copying the params to cpu
Returns:
a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks
"""
import deepspeed

if not self.zero_optimization_partition_weights():
raise ValueError("this function requires ZeRO-3 mode")

state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None
shared_weights = {}

def get_layer_state_dict(module, prefix=""):
# gather one layer at a time to be memory-efficient
with deepspeed.zero.GatheredParameters(list(
module.parameters(recurse=False))):
if torch.distributed.get_rank() == 0:
for name, param in module.named_parameters(recurse=False):
if param is None:
continue
key = prefix + name
# for shared weights we want to make sure not to unshare them when copying to cpu
data_ptr_id = param.storage().data_ptr()
if data_ptr_id in shared_weights:
# shared weights
# print(f"`{key}` is shared with `{shared_weights[data_ptr_id]}`")
state_dict[key] = state_dict[shared_weights[data_ptr_id]]
else:
state_dict[key] = param.detach().cpu()
shared_weights[data_ptr_id] = key
#print(f"param {name} {param.shape}")
#print(f"param {key} {param.shape} {state_dict[key].storage().data_ptr()}")

# now buffers - not sure if need to take care of potentially shared weights here
for name, buf in module.named_buffers(recurse=False):
if buf is not None and name not in module._non_persistent_buffers_set:
state_dict[prefix + name] = buf.detach().cpu()

for name, child in module.named_children():
if child is not None:
get_layer_state_dict(child, prefix + name + ".")

see_memory_usage("before get_layer_state_dict", force=False)
get_layer_state_dict(self.module, prefix="")
see_memory_usage("after get_layer_state_dict", force=False)

return state_dict

def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
r"""Save fp16 model weights
This method saves the fp16 model weights at the desired destination.
Arguments:
save_dir: Required. Directory for saving the model
save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``
Important: all processes must call this method and not just the process with rank 0. It is
because the processes need to work in sync to gather the weights. This method will hang
waiting to synchronize with other processes if it's called just for the process with rank 0.
"""

path = os.path.join(save_dir, save_filename)

if self.zero_optimization_partition_weights():
if self.zero_gather_fp16_weights_on_model_save():
# consolidation is expensive in time and memory and therefore isn't a default
state_dict = self._zero3_consolidated_fp16_state_dict()
else:
# the model will be bogus if not consolidated so don't confuse the user by saving it
logger.info(
f"Did not save the model {path} because `stage3_gather_fp16_weights_on_model_save` is False"
)
return
else:
state_dict = self.module.state_dict()

if torch.distributed.get_rank() == 0:
os.makedirs(save_dir, exist_ok=True)
logger.info(f"Saving model weights to {path}")
torch.save(state_dict, path)
6 changes: 6 additions & 0 deletions deepspeed/runtime/zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, param_dict):
self.param_persistence_threshold = None
self.max_live_parameters = None
self.max_reuse_distance = None
self.gather_fp16_weights_on_model_save = None

#Stage3 Specific Parameters
self.prefetch_bucket_size = None
Expand Down Expand Up @@ -150,3 +151,8 @@ def _initialize(self, zero_config_dict):
zero_config_dict,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT)

self.gather_fp16_weights_on_model_save = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT)
8 changes: 7 additions & 1 deletion deepspeed/runtime/zero/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD = 'stage3_param_persistence_threshold'
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT = 100000

# gathers params for saving a model - inefficient but is required in certain situations
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_fp16_weights_on_model_save'
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False

ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE:
ZERO_OPTIMIZATION_STAGE_DEFAULT,
Expand Down Expand Up @@ -133,5 +137,7 @@
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE:
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD:
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE:
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT
}

0 comments on commit 39013dd

Please sign in to comment.