Cherry pick/fast_safe_open #8458

Merged: 4 commits, merged May 20, 2024

51 changes: 29 additions & 22 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -30,6 +30,7 @@
from paddlenlp.transformers.model_utils import (
PretrainedModel,
_load_state_dict_into_model,
faster_set_state_dict,
get_parameter_dtype,
load_state_dict,
unwrap_model,
@@ -65,9 +66,10 @@
from paddlenlp.utils.nested import nested_copy, nested_copy_place

if is_safetensors_available():
from safetensors import safe_open
# from safetensors import safe_open
from safetensors.numpy import save_file as safe_save_file

from paddlenlp.utils.safetensors import fast_safe_open as safe_open

FP32_MASTER = "fp32_master_0"
optimizer_scalar_name = [
@@ -91,6 +93,11 @@
async_save_queue = []


DEST_PLACE = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
DEST_PLACE = paddle.CUDAPinnedPlace()


class UnifiedCheckpointOption(ExplicitEnum):
"""
"- skip_save_model_weight: do not save model weights when the masters weight exist\n"
@@ -196,7 +203,6 @@ def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str,
Returns:
None
"""

if paddle.distributed.get_world_size() <= 1:
load_single_card_checkpoint(args, model, resume_from_checkpoint)
return
@@ -222,7 +228,6 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa
pretrained_model_name_or_path=resume_from_checkpoint,
index_filename=os.path.join(resume_from_checkpoint, index_filename),
)

loaded_keys = sharded_metadata["all_checkpoint_keys"]

model_state_dict = get_expected_state_dict(model)
@@ -266,7 +271,9 @@ def _remove_unused_keys(
else:
tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
state_dict = load_state_dict(shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys)
state_dict = load_state_dict(
shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
)

if not pre_tensor_parallel_split:
# Since we load all keys but we only need one of pipeline stages
@@ -279,11 +286,12 @@ def _remove_unused_keys(
None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
)

error_msgs += _load_state_dict_into_model(model, state_dict, "")
# error_msgs += _load_state_dict_into_model(model, state_dict, "")
error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False)

# force memory release
del state_dict
gc.collect()
# gc.collect()

if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
@@ -337,6 +345,7 @@ def unified_checkpoint_into_shards(
tp_actions = model_to_save.get_tensor_parallel_convert_actions(
model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True
)
logger.info("Unified model tensor parallel weights in shards")
state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

# build index json file
@@ -490,6 +499,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
# This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]

if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

@@ -537,10 +547,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys)

# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
state_dict = load_state_dict(shard_file, tp_actions, expected_keys)
state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
else:
# for pipeline model, we don't need to use tp_actions
state_dict = load_state_dict(shard_file, None, expected_keys)
state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")

returned_state_dict.update(state_dict)
# force memory release
@@ -553,7 +563,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
state_dict_master_weight = load_resolved_archive_file(
resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True
)

# rename optimizer param
for key in list(state_dict_optim.keys()):
key_name = key.split("/")
@@ -562,13 +571,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])
returned_optim_state_dict[key_name] = state_dict_optim[key]
returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
returned_optim_state_dict[key_name].name = key_name

if has_master_weights:
for key in list(state_dict_master_weight.keys()):
static_name = struct2static_name_mappings[key]
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight[key]
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

returned_optim_state_dict = nested_copy_place(
@@ -640,6 +649,7 @@ def unified_optimizer_into_shards(
tp_actions = model.get_tensor_parallel_convert_actions(
model.config, model_keys, is_split=False, ignore_error=True
)
logger.info("Unified optimizer tensor parallel in shards")
optim_state_dict = merge_tensor_parallel_for_optimizer(
optim_state_dict,
tp_actions,
@@ -648,6 +658,7 @@ def unified_optimizer_into_shards(
paddle.device.cuda.empty_cache()

if master_weights is not None:
logger.info("Unified master weight tensor parallel in shards")
master_weights = merge_tensor_parallel_for_optimizer(
master_weights,
tp_actions,
@@ -703,7 +714,6 @@ def unified_optimizer_into_shards(
def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False)
index_filename = os.path.join(resume_from_checkpoint, index_filename)

# Find index json file and distribute this file in global group.
if distributed_isfile(index_filename):
distributed_file(index_filename)
@@ -1605,7 +1615,9 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):
tp_group = hcg.get_model_parallel_group()
pp_group = hcg.get_pipe_parallel_group()

logger.info("Unified checkpoint generating sharded_index json files.")
logger.info(
f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}."
)

if tp_group.nranks > 1:
dist.all_gather_object(index_file_list, index_file, tp_group)
@@ -1714,8 +1726,6 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):


def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
logger.info("Unified checkpoint merge tensor parallel in shards")

hcg = fleet.get_hybrid_communicate_group()
tp_group = hcg.get_model_parallel_group()
tp_rank = tp_group.rank
@@ -1741,7 +1751,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
action = tp_actions.pop(key)
tensor = action(ret) if is_dst else None
else:
tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

if is_dst:
state_dict_to_save[key] = tensor
@@ -1754,8 +1764,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):


def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
logger.info("Unified optimizer tensor parallel in shards")

# Core function for UC
hcg = fleet.get_hybrid_communicate_group()
tp_group = hcg.get_model_parallel_group()
tp_rank = tp_group.rank
@@ -1773,15 +1782,13 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
if model_key in tp_actions:
# for example: beta1, beta2
if tensor.numel().item() == 1:
tensor = (
tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
) # Need broadcast when loaded
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None # Need broadcast when loaded
else:
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
action = tp_actions[model_key]
tensor = action(ret) if is_dst else None
else:
tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

if is_dst:
state_dict_to_save[filter_keys[i]] = tensor
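Taken together, the unified_checkpoint.py changes route shard loading through PaddleNLP's `fast_safe_open`, stage merged tensor-parallel weights in pinned host memory when CUDA is available, place loaded weights directly on the expected device, and set them into the model with `faster_set_state_dict`. The sketch below is a minimal illustration of that loading path, not code from this PR; `model`, `shard_file`, and `expected_keys` are hypothetical placeholders.

```python
import paddle
from paddlenlp.transformers.model_utils import faster_set_state_dict, load_state_dict

# Staging place used when merging tensor-parallel shards: pinned host memory on
# CUDA builds, plain CPU memory otherwise (mirrors DEST_PLACE in the diff above).
DEST_PLACE = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
    DEST_PLACE = paddle.CUDAPinnedPlace()


def load_shard_into_model(model, shard_file, expected_keys):
    """Hypothetical helper showing how the new arguments fit together."""
    # device="expected" moves each weight onto the current expected place
    # (e.g. the GPU) instead of leaving it as a CPU tensor.
    state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")
    # faster_set_state_dict shares the loaded storage with the model parameters;
    # strict_dtype=False allows a cast when floating-point dtypes differ.
    error_msgs = faster_set_state_dict(model, state_dict, strict_dtype=False)
    del state_dict  # the state dict is consumed by the call above
    return error_msgs
```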
1 change: 1 addition & 0 deletions paddlenlp/trainer/trainer.py
@@ -2422,6 +2422,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
self.runtime_timer.stop()
return

logger.info("Loading optimizer and scheduler...")
if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
self.runtime_timer.stop()
return
16 changes: 13 additions & 3 deletions paddlenlp/transformers/conversion_utils.py
@@ -285,8 +285,12 @@ def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2):

if isinstance(weight_list[0], np.ndarray):
return np.concatenate([reorder[i] for i in index], axis=axis)
else:
tensor = paddle.concat([reorder[i] for i in index], axis=axis)

return paddle.concat([reorder[i] for i in index], axis=axis)._copy_to(paddle.CPUPlace(), False)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor


def naive_fuse_split_tp(
Expand Down Expand Up @@ -361,12 +365,18 @@ def normal_fuse_merge_tp(weight_list, is_column=True):
if isinstance(weight_list[0], np.ndarray):
return np.concatenate(weight_list, axis=-1)
else:
return paddle.concat(weight_list, axis=-1)._copy_to(paddle.CPUPlace(), False)
tensor = paddle.concat(weight_list, axis=-1)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor
else:
if isinstance(weight_list[0], np.ndarray):
return np.concatenate(weight_list, axis=0)
else:
return paddle.concat(weight_list, axis=0)._copy_to(paddle.CPUPlace(), False)
tensor = paddle.concat(weight_list, axis=0)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor


def normal_fuse_split_tp(weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True):
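The conversion_utils.py hunks change the merge helpers so that a concatenated tensor is copied to pinned memory only when it actually lives on the GPU, rather than always being copied to `paddle.CPUPlace()`. A hedged sketch of that pattern, with `offload_merged_tensor` as an illustrative name rather than a function from this PR:

```python
import numpy as np
import paddle


def offload_merged_tensor(tensor):
    """Illustrative sketch: move merged GPU tensors to pinned host memory."""
    if isinstance(tensor, np.ndarray):
        # numpy results are returned unchanged, as in naive_fuse_merge_tp.
        return tensor
    if tensor.place.is_gpu_place():
        # Pinned memory keeps later host<->device copies cheap and non-blocking.
        tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
    return tensor
```

Tensors already on the CPU stay where they are, which avoids an extra copy on non-GPU builds.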
51 changes: 35 additions & 16 deletions paddlenlp/transformers/model_utils.py
@@ -109,10 +109,13 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):

if is_safetensors_available():

from safetensors import safe_open
from safetensors.numpy import load_file as safe_load_file
# from safetensors import safe_open
# from safetensors.numpy import load_file as safe_load_file
from safetensors.numpy import save_file as safe_save_file

from paddlenlp.utils.safetensors import fast_load_file as safe_load_file
from paddlenlp.utils.safetensors import fast_safe_open as safe_open


def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
"""
@@ -313,7 +316,7 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:


def load_state_dict(
checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None
checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu"
):
"""
Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
@@ -346,11 +349,16 @@ def load_state_dict(
weight = tensor_parallel_split_mapping[key](py_safe_slice_)
else:
weight = py_safe_slice_[:]
if device == "expected":
with device_guard():
weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(paddle.framework._current_expected_place(), False)
state_dict[key] = weight

for k in list(state_dict.keys()):
with device_guard():
state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)
if device == "cpu":
for k in list(state_dict.keys()):
with device_guard():
state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)

return state_dict

@@ -672,8 +680,10 @@ def load_sharded_checkpoint(model, folder, variant=None, strict=True, prefer_saf
return missing_keys, unexpected_keys


def faster_set_state_dict(model, state_dict):
def faster_set_state_dict(model, state_dict, strict_dtype=True):
# the state_dict will be destroyed.
unused_keys = set(state_dict.keys())
unset_keys = set(model.state_dict().keys())
with paddle.no_grad():
for k, v in model.state_dict().items():
if k in state_dict:
Expand All @@ -683,8 +693,10 @@ def faster_set_state_dict(model, state_dict):
f"faster_set_state_dict need state dict with paddle.Tensor, but got {type(v_new)}"
)
# 2. cast param / Tensor to dtype
#
if v.dtype != v_new.dtype:
raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")
if strict_dtype or (not v.is_floating_point() or not v_new.is_floating_point()):
raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")
# check shape
if list(v.shape) != list(v_new.shape):
raise ValueError(f"for key: {k}, expect shape {v.shape}, but got {v_new.shape}")
@@ -700,9 +712,22 @@ def faster_set_state_dict(model, state_dict):
else:
new_t = v_new

if not strict_dtype and v.dtype != new_t.dtype:
new_t = new_t.astype(v.dtype)

# 4. share Tensor to origin param / Tensor
src_tensor = new_t.value().get_tensor()
dst_tensor._share_data_with(src_tensor)
unset_keys.remove(k)
unused_keys.remove(k)

error_msgs = []
# if len(unset_keys) > 0:
# error_msgs.append(f"Those weight of model is not initialized: {list(unset_keys)}")
if len(unused_keys) > 0:
error_msgs.append(f"Those state dict keys are not using in model: {list(unused_keys)}")

return error_msgs


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
@@ -734,22 +759,16 @@ def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
def is_0d_or_1d(tensor):
return len(tensor.shape) == 0 or list(tensor.shape) == [1]

expected_place = paddle.framework._current_expected_place()
for key, value in model_to_load.state_dict().items():
if key in state_dict:
if key in list(state_dict.keys()):
if isinstance(state_dict[key], np.ndarray):
raise ValueError(
"convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, plase convert numpy.ndarray to paddle.Tensor"
)
# confirm parameter cast is executed on the same device as model
# TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
if state_dict[key].is_floating_point() and state_dict[key].dtype != value.dtype:
value_pop = state_dict.pop(key)
value_new_place = (
value_pop if value_pop.place == expected_place else value_pop._copy_to(expected_place, False)
)
state_dict[key] = paddle.cast(value_new_place, value.dtype)._copy_to(value_pop.place, False)
del value_new_place
state_dict[key] = paddle.cast(state_dict.pop(key), value.dtype)
# unified 0d and 1d tensor
if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
if list(value.shape) != list(state_dict[key].shape):
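model_utils.py now aliases PaddleNLP's own `fast_safe_open` and `fast_load_file` over the upstream safetensors helpers, so every `safe_open`/`safe_load_file` call in this module goes through the faster readers. A small usage sketch, assuming the replacements mirror the upstream numpy interface (`keys()`, `get_slice()`, slicing with `[:]`, and the `framework` argument); the shard filename is illustrative.

```python
from paddlenlp.utils.safetensors import fast_load_file as safe_load_file
from paddlenlp.utils.safetensors import fast_safe_open as safe_open

shard = "model-00001-of-00002.safetensors"  # illustrative path

# Stream individual keys without materializing the whole file, as load_state_dict does.
# The framework argument is assumed to match the upstream safe_open signature.
with safe_open(shard, framework="np") as f:
    for key in f.keys():
        weight = f.get_slice(key)[:]  # lazy slice, read only what is needed

# Or load an entire shard at once as a dict of numpy arrays.
full_state = safe_load_file(shard)
```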