Cherry-Pick fast_safe_open (#8458)
* [Performance] Optimize unified checkpoint save/load speed. (#8204)

* opt unified checkpoint save/load speed.
ZHUI authored May 20, 2024
1 parent fc860a3 commit 08898bf
Showing 10 changed files with 490 additions and 41 deletions.
51 changes: 29 additions & 22 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -30,6 +30,7 @@
from paddlenlp.transformers.model_utils import (
PretrainedModel,
_load_state_dict_into_model,
faster_set_state_dict,
get_parameter_dtype,
load_state_dict,
unwrap_model,
@@ -65,9 +66,10 @@
from paddlenlp.utils.nested import nested_copy, nested_copy_place

if is_safetensors_available():
from safetensors import safe_open
# from safetensors import safe_open
from safetensors.numpy import save_file as safe_save_file

from paddlenlp.utils.safetensors import fast_safe_open as safe_open
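Note the alias: fast_safe_open is imported under the name safe_open, so every existing call site below keeps working. A minimal usage sketch, assuming the interface mirrors safetensors.safe_open (which the alias relies on); the file name and keys are illustrative:

from paddlenlp.utils.safetensors import fast_safe_open

with fast_safe_open("model-00001-of-00002.safetensors", framework="np") as f:
    for key in f.keys():
        lazy = f.get_slice(key)   # lazy handle; tensor-parallel actions can slice it first
        weight = lazy[:]          # [:] materializes the full numpy array; partial indexing reads less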

FP32_MASTER = "fp32_master_0"
optimizer_scalar_name = [
@@ -91,6 +93,11 @@
async_save_queue = []


DEST_PLACE = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
DEST_PLACE = paddle.CUDAPinnedPlace()
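A minimal sketch of what this staging place buys (assumes a CUDA build of Paddle; the tensor is illustrative):

import paddle

dest = paddle.CUDAPinnedPlace() if paddle.device.is_compiled_with_cuda() else paddle.CPUPlace()
tensor = paddle.randn([1024, 1024])     # lives on the GPU in a CUDA build
staged = tensor._copy_to(dest, False)   # same call shape as the merge helpers below
print(staged.place)                     # CUDAPinnedPlace() when CUDA is available

Copying into pinned (page-locked) host memory avoids an extra pass through pageable memory, which is presumably why the merge helpers below now target DEST_PLACE rather than paddle.CPUPlace().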


class UnifiedCheckpointOption(ExplicitEnum):
"""
"- skip_save_model_weight: do not save model weights when the masters weight exist\n"
@@ -196,7 +203,6 @@ def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str,
Returns:
None
"""

if paddle.distributed.get_world_size() <= 1:
load_single_card_checkpoint(args, model, resume_from_checkpoint)
return
@@ -222,7 +228,6 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa
pretrained_model_name_or_path=resume_from_checkpoint,
index_filename=os.path.join(resume_from_checkpoint, index_filename),
)

loaded_keys = sharded_metadata["all_checkpoint_keys"]

model_state_dict = get_expected_state_dict(model)
@@ -266,7 +271,9 @@ def _remove_unused_keys(
else:
tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
state_dict = load_state_dict(shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys)
state_dict = load_state_dict(
shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
)

if not pre_tensor_parallel_split:
# Since we load all keys but we only need one of pipeline stages
@@ -279,11 +286,12 @@ def _remove_unused_keys(
None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
)

error_msgs += _load_state_dict_into_model(model, state_dict, "")
# error_msgs += _load_state_dict_into_model(model, state_dict, "")
error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False)

# force memory release
del state_dict
gc.collect()
# gc.collect()

if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
@@ -337,6 +345,7 @@ def unified_checkpoint_into_shards(
tp_actions = model_to_save.get_tensor_parallel_convert_actions(
model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True
)
logger.info("Unified model tensor parallel weights in shards")
state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

# build index json file
@@ -490,6 +499,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
# This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]

if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

@@ -537,10 +547,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys)

# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
state_dict = load_state_dict(shard_file, tp_actions, expected_keys)
state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
else:
# for pipeline model, we don't need to use tp_actions
state_dict = load_state_dict(shard_file, None, expected_keys)
state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")

returned_state_dict.update(state_dict)
# force memory release
@@ -553,7 +563,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
state_dict_master_weight = load_resolved_archive_file(
resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True
)

# rename optimizer param
for key in list(state_dict_optim.keys()):
key_name = key.split("/")
@@ -562,13 +571,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])
returned_optim_state_dict[key_name] = state_dict_optim[key]
returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
returned_optim_state_dict[key_name].name = key_name

if has_master_weights:
for key in list(state_dict_master_weight.keys()):
static_name = struct2static_name_mappings[key]
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight[key]
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

returned_optim_state_dict = nested_copy_place(
@@ -640,6 +649,7 @@ def unified_optimizer_into_shards(
tp_actions = model.get_tensor_parallel_convert_actions(
model.config, model_keys, is_split=False, ignore_error=True
)
logger.info("Unified optimizer tensor parallel in shards")
optim_state_dict = merge_tensor_parallel_for_optimizer(
optim_state_dict,
tp_actions,
@@ -648,6 +658,7 @@ def unified_optimizer_into_shards(
paddle.device.cuda.empty_cache()

if master_weights is not None:
logger.info("Unified master weight tensor parallel in shards")
master_weights = merge_tensor_parallel_for_optimizer(
master_weights,
tp_actions,
@@ -703,7 +714,6 @@ def unified_optimizer_into_shards(
def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False)
index_filename = os.path.join(resume_from_checkpoint, index_filename)

# Find index json file and distribute this file in global group.
if distributed_isfile(index_filename):
distributed_file(index_filename)
@@ -1605,7 +1615,9 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):
tp_group = hcg.get_model_parallel_group()
pp_group = hcg.get_pipe_parallel_group()

logger.info("Unified checkpoint generating sharded_index json files.")
logger.info(
f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}."
)

if tp_group.nranks > 1:
dist.all_gather_object(index_file_list, index_file, tp_group)
@@ -1714,8 +1726,6 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):


def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
logger.info("Unified checkpoint merge tensor parallel in shards")

hcg = fleet.get_hybrid_communicate_group()
tp_group = hcg.get_model_parallel_group()
tp_rank = tp_group.rank
@@ -1741,7 +1751,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
action = tp_actions.pop(key)
tensor = action(ret) if is_dst else None
else:
tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

if is_dst:
state_dict_to_save[key] = tensor
@@ -1754,8 +1764,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):


def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
logger.info("Unified optimizer tensor parallel in shards")

# Core function for UC
hcg = fleet.get_hybrid_communicate_group()
tp_group = hcg.get_model_parallel_group()
tp_rank = tp_group.rank
@@ -1773,15 +1782,13 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
if model_key in tp_actions:
# for example: beta1, beta2
if tensor.numel().item() == 1:
tensor = (
tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
) # Need broadcast when loaded
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None # Need broadcast when loaded
else:
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
action = tp_actions[model_key]
tensor = action(ret) if is_dst else None
else:
tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

if is_dst:
state_dict_to_save[filter_keys[i]] = tensor
1 change: 1 addition & 0 deletions paddlenlp/trainer/trainer.py
@@ -2419,6 +2419,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
self.runtime_timer.stop()
return

logger.info("Loading optimizer and scheduler...")
if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
self.runtime_timer.stop()
return
16 changes: 13 additions & 3 deletions paddlenlp/transformers/conversion_utils.py
@@ -285,8 +285,12 @@ def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2):

if isinstance(weight_list[0], np.ndarray):
return np.concatenate([reorder[i] for i in index], axis=axis)
else:
tensor = paddle.concat([reorder[i] for i in index], axis=axis)

return paddle.concat([reorder[i] for i in index], axis=axis)._copy_to(paddle.CPUPlace(), False)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor


def naive_fuse_split_tp(
@@ -361,12 +365,18 @@ def normal_fuse_merge_tp(weight_list, is_column=True):
if isinstance(weight_list[0], np.ndarray):
return np.concatenate(weight_list, axis=-1)
else:
return paddle.concat(weight_list, axis=-1)._copy_to(paddle.CPUPlace(), False)
tensor = paddle.concat(weight_list, axis=-1)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor
else:
if isinstance(weight_list[0], np.ndarray):
return np.concatenate(weight_list, axis=0)
else:
return paddle.concat(weight_list, axis=0)._copy_to(paddle.CPUPlace(), False)
tensor = paddle.concat(weight_list, axis=0)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor
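The same guard now appears in every merge branch above; as a hedged refactoring sketch (the helper name is not part of the commit), it amounts to:

import paddle

def offload_if_on_gpu(tensor):
    # numpy results pass through untouched; GPU results are staged in pinned
    # host memory so the checkpoint writer can consume them without another sync.
    if isinstance(tensor, paddle.Tensor) and tensor.place.is_gpu_place():
        tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
    return tensor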


def normal_fuse_split_tp(weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True):
51 changes: 35 additions & 16 deletions paddlenlp/transformers/model_utils.py
@@ -109,10 +109,13 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):

if is_safetensors_available():

from safetensors import safe_open
from safetensors.numpy import load_file as safe_load_file
# from safetensors import safe_open
# from safetensors.numpy import load_file as safe_load_file
from safetensors.numpy import save_file as safe_save_file

from paddlenlp.utils.safetensors import fast_load_file as safe_load_file
from paddlenlp.utils.safetensors import fast_safe_open as safe_open


def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
"""
@@ -313,7 +316,7 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:


def load_state_dict(
checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None
checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu"
):
"""
Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
@@ -346,11 +349,16 @@ def load_state_dict(
weight = tensor_parallel_split_mapping[key](py_safe_slice_)
else:
weight = py_safe_slice_[:]
if device == "expected":
with device_guard():
weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(paddle.framework._current_expected_place(), False)
state_dict[key] = weight

for k in list(state_dict.keys()):
with device_guard():
state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)
if device == "cpu":
for k in list(state_dict.keys()):
with device_guard():
state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)

return state_dict
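In effect, device="expected" moves each weight to the trainer's device as soon as it is read, instead of first assembling the whole state dict on the CPU. A minimal sketch of that branch (weight is a numpy array read from the shard; device_guard() is the repo helper used above and is assumed to be in scope):

import paddle

def to_expected_place(weight):
    with device_guard():                           # keep the zero-copy wrap on the CPU
        t = paddle.Tensor(weight, zero_copy=True)  # wraps the numpy buffer without copying
    # one explicit transfer to the currently expected place (typically the training GPU)
    return t._copy_to(paddle.framework._current_expected_place(), False)

With device="cpu" the previous behaviour is preserved: tensors stay as zero-copy CPU wrappers and the caller decides when, or whether, to move them.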

@@ -672,8 +680,10 @@ def load_sharded_checkpoint(model, folder, variant=None, strict=True, prefer_saf
return missing_keys, unexpected_keys


def faster_set_state_dict(model, state_dict):
def faster_set_state_dict(model, state_dict, strict_dtype=True):
# the state_dict will be destroied.
unused_keys = set(state_dict.keys())
unset_keys = set(model.state_dict().keys())
with paddle.no_grad():
for k, v in model.state_dict().items():
if k in state_dict:
Expand All @@ -683,8 +693,10 @@ def faster_set_state_dict(model, state_dict):
f"faster_set_state_dict need state dict with paddle.Tensor, but got {type(v_new)}"
)
# 2. cast param / Tensor to dtype
#
if v.dtype != v_new.dtype:
raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")
if strict_dtype or (not v.is_floating_point() or not v_new.is_floating_point()):
raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")
# check shape
if list(v.shape) != list(v_new.shape):
raise ValueError(f"for key: {k}, expect shape {v.shape}, but got {v_new.shape}")
Expand All @@ -700,9 +712,22 @@ def faster_set_state_dict(model, state_dict):
else:
new_t = v_new

if not strict_dtype and v.dtype != new_t.dtype:
new_t = new_t.astype(v.dtype)

# 4. share Tensor to origin param / Tensor
src_tensor = new_t.value().get_tensor()
dst_tensor._share_data_with(src_tensor)
unset_keys.remove(k)
unused_keys.remove(k)

error_msgs = []
# if len(unset_keys) > 0:
# error_msgs.append(f"Those weight of model is not initialized: {list(unset_keys)}")
if len(unused_keys) > 0:
error_msgs.append(f"Those state dict keys are not using in model: {list(unused_keys)}")

return error_msgs
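The gain in faster_set_state_dict comes from the sharing step above: instead of copying element data into every parameter, it points the parameter's underlying DenseTensor at the loaded buffer. A hedged sketch of that core step (shape and dtype checks omitted; names illustrative):

import paddle

def share_into(param, incoming):
    # param: an existing model parameter; incoming: a loaded tensor with the same
    # shape and dtype. Afterwards the parameter reuses incoming's storage directly.
    with paddle.no_grad():
        dst = param.value().get_tensor()
        src = incoming.value().get_tensor()
        dst._share_data_with(src)

The new strict_dtype=False flag complements this: mismatched floating-point dtypes are cast instead of raising, which the unified checkpoint loader above relies on.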


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
Expand Down Expand Up @@ -734,22 +759,16 @@ def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
def is_0d_or_1d(tensor):
return len(tensor.shape) == 0 or list(tensor.shape) == [1]

expected_place = paddle.framework._current_expected_place()
for key, value in model_to_load.state_dict().items():
if key in state_dict:
if key in list(state_dict.keys()):
if isinstance(state_dict[key], np.ndarray):
raise ValueError(
"convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, plase convert numpy.ndarray to paddle.Tensor"
)
# confirm parameter cast is executed on the same device as model
# TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
if state_dict[key].is_floating_point() and state_dict[key].dtype != value.dtype:
value_pop = state_dict.pop(key)
value_new_place = (
value_pop if value_pop.place == expected_place else value_pop._copy_to(expected_place, False)
)
state_dict[key] = paddle.cast(value_new_place, value.dtype)._copy_to(value_pop.place, False)
del value_new_place
state_dict[key] = paddle.cast(state_dict.pop(key), value.dtype)
# unified 0d and 1d tensor
if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
if list(value.shape) != list(state_dict[key].shape):
(Diffs for the remaining 6 changed files are not shown here.)