Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save_fp16_model consolidated for zero3 #893

Merged
merged 2 commits into from
Mar 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ def zero_prefetch_bucket_size(self):
def zero_param_persistence_threshold(self):
    """Return ``param_persistence_threshold`` from the ZeRO section of the engine config."""
    zero_cfg = self._config.zero_config
    return zero_cfg.param_persistence_threshold

def zero_gather_fp16_weights_on_model_save(self):
    """Return ``gather_fp16_weights_on_model_save`` from the ZeRO section of the engine config."""
    zero_cfg = self._config.zero_config
    return zero_cfg.gather_fp16_weights_on_model_save

def fp16_enabled(self):
    """Return the ``fp16_enabled`` value stored on the engine config."""
    engine_cfg = self._config
    return engine_cfg.fp16_enabled

Expand Down Expand Up @@ -1714,3 +1717,98 @@ def _save_zero_checkpoint(self, save_path, tag):
torch.save(zero_sd, zero_checkpoint_name)
self._copy_recovery_script(save_path)
logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))

def _zero3_consolidated_fp16_state_dict(self):
    """
    Get a full non-partitioned state_dict with fp16 weights on cpu.

    This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:

    1. consolidates the weights from different partitions on gpu0
    2. works on one layer at a time to require as little gpu0 memory as possible, by
       moving the already consolidated weights to cpu
    3. takes care to keep the shared params shared when gradually copying the params to cpu

    Every rank walks the module tree and enters the gather context for each layer;
    only rank 0 copies the gathered tensors into the result.

    Returns:
        a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks

    Raises:
        ValueError: if the engine is not running with ZeRO-3 weight partitioning.
    """
    import deepspeed

    if not self.zero_optimization_partition_weights():
        raise ValueError("this function requires ZeRO-3 mode")

    # the rank does not change during the traversal, so query it once
    on_rank0 = torch.distributed.get_rank() == 0
    state_dict = OrderedDict() if on_rank0 else None
    # parameter identity -> first state_dict key under which it was stored
    shared_params = {}

    def param_identity(param):
        # Prefer the ZeRO parameter id: the storage address of a gathered param is
        # only meaningful while its layer is gathered, and may be handed out again
        # to a later layer once this one is re-partitioned, which would make two
        # unrelated params look shared. Shared params carry the same ``ds_id``.
        ds_id = getattr(param, "ds_id", None)
        if ds_id is not None:
            return ("ds_id", ds_id)
        # Params without a ZeRO id are not re-partitioned; fall back to the
        # storage address, kept in a separate key namespace from the ids above.
        return ("ptr", param.storage().data_ptr())

    def get_layer_state_dict(module, prefix=""):
        # gather one layer at a time to be memory-efficient
        with deepspeed.zero.GatheredParameters(list(
                module.parameters(recurse=False))):
            if on_rank0:
                for name, param in module.named_parameters(recurse=False):
                    if param is None:
                        continue
                    key = prefix + name
                    identity = param_identity(param)
                    if identity in shared_params:
                        # shared weight: reuse the cpu tensor that was already
                        # copied so the sharing survives in the state_dict
                        state_dict[key] = state_dict[shared_params[identity]]
                    else:
                        state_dict[key] = param.detach().cpu()
                        shared_params[identity] = key

                # now buffers - not sure if need to take care of potentially shared weights here
                for name, buf in module.named_buffers(recurse=False):
                    if buf is not None and name not in module._non_persistent_buffers_set:
                        state_dict[prefix + name] = buf.detach().cpu()

        # recurse only after this layer's gather context has been left, so that
        # at most one layer is held in gathered form at any time
        for name, child in module.named_children():
            if child is not None:
                get_layer_state_dict(child, prefix + name + ".")

    see_memory_usage("before get_layer_state_dict", force=False)
    get_layer_state_dict(self.module, prefix="")
    see_memory_usage("after get_layer_state_dict", force=False)

    return state_dict

def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
    r"""Write the fp16 model weights to ``save_dir/save_filename``.

    Arguments:
        save_dir: Required. Directory for saving the model
        save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``

    Under ZeRO-3 the weights are consolidated first, and only when
    ``stage3_gather_fp16_weights_on_model_save`` is enabled; otherwise nothing is
    written and an info message is logged. Without ZeRO-3 the module's regular
    ``state_dict()`` is saved. Only the process with rank 0 writes the file.

    Important: every process must call this method, not only rank 0. The processes
    have to work in sync to gather the weights, so a call made on rank 0 alone
    will hang waiting for the other processes.

    """
    path = os.path.join(save_dir, save_filename)

    if not self.zero_optimization_partition_weights():
        state_dict = self.module.state_dict()
    elif self.zero_gather_fp16_weights_on_model_save():
        # consolidation is expensive in time and memory and therefore isn't a default
        state_dict = self._zero3_consolidated_fp16_state_dict()
    else:
        # un-consolidated ZeRO-3 weights would be bogus, so skip saving altogether
        logger.info(
            f"Did not save the model {path} because `stage3_gather_fp16_weights_on_model_save` is False"
        )
        return

    if torch.distributed.get_rank() != 0:
        return

    os.makedirs(save_dir, exist_ok=True)
    logger.info(f"Saving model weights to {path}")
    torch.save(state_dict, path)
6 changes: 6 additions & 0 deletions deepspeed/runtime/zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, param_dict):
self.param_persistence_threshold = None
self.max_live_parameters = None
self.max_reuse_distance = None
self.gather_fp16_weights_on_model_save = None

#Stage3 Specific Parameters
self.prefetch_bucket_size = None
Expand Down Expand Up @@ -150,3 +151,8 @@ def _initialize(self, zero_config_dict):
zero_config_dict,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT)

self.gather_fp16_weights_on_model_save = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT)
8 changes: 7 additions & 1 deletion deepspeed/runtime/zero/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@
# Config key and default value for the ZeRO stage-3 parameter persistence threshold.
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD = 'stage3_param_persistence_threshold'
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT = 100000

# Config key and default value for gathering the partitioned fp16 params when
# saving a model. Gathering is inefficient but is required in certain
# situations, so it is disabled by default.
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_fp16_weights_on_model_save'
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False

ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE:
ZERO_OPTIMIZATION_STAGE_DEFAULT,
Expand Down Expand Up @@ -133,5 +137,7 @@
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE:
ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD:
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE:
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT
}