diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index aa71ddd33e59..af5c9b9f1d12 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -195,7 +195,17 @@
 g_cpu_optimizer_state_dict = {}
 
 
-def _save_func(obj, path, saved_signal_path, protocol):
+def _save_func(obj, name_mapping, path, saved_signal_path, protocol):
+    if isinstance(obj, dict):
+        for k, v in obj.items():
+            if k == "master_weights" and isinstance(v, dict):
+                for kk, vv in v.items():
+                    if isinstance(vv, paddle.Tensor):
+                        vv.name = name_mapping["master_weights"][kk]
+            else:
+                if k in name_mapping and isinstance(v, paddle.Tensor):
+                    v.name = name_mapping[k]
+
     paddle.save(obj, path, protocol)
     # dump savd_siganl
     with open(saved_signal_path, mode="w+") as f:
@@ -228,17 +238,18 @@ def clear_async_save_task_queue():
 def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4):
     global g_cpu_optimizer_state_dict
     g_cpu_optimizer_state_dict.clear()
+    name_mapping = {"master_weights": {}}
     for k, v in optimizer_state_dict.items():
         if k == "master_weights":
             g_cpu_optimizer_state_dict[k] = {}
             for kk, vv in v.items():
-                tensor_name = vv.name
                 g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
-                g_cpu_optimizer_state_dict[k][kk].name = tensor_name
+                name_mapping[k][kk] = vv.name
         elif k == "LR_Scheduler":
             g_cpu_optimizer_state_dict[k] = copy.deepcopy(v)
         else:
             g_cpu_optimizer_state_dict[k] = v.pin_memory()
+            name_mapping[k] = v.name
     paddle.device.synchronize()
 
     clear_async_save_task_queue()
@@ -248,7 +259,9 @@ def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol
     def start_process():
         nonlocal attempt
         try:
-            p = ctx.Process(target=_save_func, args=(g_cpu_optimizer_state_dict, path, saved_signal_path, protocol))
+            p = ctx.Process(
+                target=_save_func, args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol)
+            )
             p.start()
             return p
         except Exception as e:
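The rename pass added to `_save_func` restores tensor names inside the background save process: the old code copied each `vv.name` back onto the pinned-memory copy in the training process, while the new code only records the names in `name_mapping` and re-applies them right before `paddle.save`. Below is a minimal, self-contained sketch of that restore step; `FakeTensor`, `restore_names`, and the sample keys are illustrative stand-ins (the real code checks `isinstance(v, paddle.Tensor)`):

```python
# Sketch of the name-restoration logic that _save_func applies before paddle.save.
class FakeTensor:  # stand-in for paddle.Tensor; only the .name attribute matters here
    def __init__(self, name):
        self.name = name


def restore_names(obj, name_mapping):
    """Re-apply the original optimizer-state names recorded by async_save_optimizer."""
    for k, v in obj.items():
        if k == "master_weights" and isinstance(v, dict):
            for kk, vv in v.items():
                if isinstance(vv, FakeTensor):
                    vv.name = name_mapping["master_weights"][kk]
        elif k in name_mapping and isinstance(v, FakeTensor):
            v.name = name_mapping[k]


state = {
    "linear_0.w_0_moment1_0": FakeTensor("pinned_tmp_0"),
    "master_weights": {"linear_0.w_0": FakeTensor("pinned_tmp_1")},
}
mapping = {
    "linear_0.w_0_moment1_0": "linear_0.w_0_moment1_0",
    "master_weights": {"linear_0.w_0": "linear_0.w_0_fp32_master_0"},
}
restore_names(state, mapping)
assert state["master_weights"]["linear_0.w_0"].name == "linear_0.w_0_fp32_master_0"
```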
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index e8c06568a6e7..b6619151eff6 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1171,15 +1171,23 @@ def split_parallel_config(parallel_config):
                     # sync_param_name = [""] matches any parameter name.
                     # If sync_param, sync_grad and sync_moment are not set, the default value in Paddle is :
                     # sync_param = True, sync_grad = False, sync_moment = False, sync_param_name = ["embedding", "layer_norm", ".b_"].
+
+                    if sync_param or sync_grad or sync_moment:
+                        logger.info("setting sync_param_name")
+                        strategy.sync_param_name = [""]
+
                     if sync_param:
+                        logger.info("setting sync_param")
                         strategy.hybrid_configs["mp_configs"].sync_param = True
-                        strategy.hybrid_configs["mp_configs"].sync_param_name = [""]
+
                     if sync_grad:
+                        logger.info("setting sync_grad")
                         strategy.hybrid_configs["mp_configs"].sync_grad = True
-                        strategy.hybrid_configs["mp_configs"].sync_grad_name = [""]
+
                     if sync_moment:
+                        logger.info("setting sync_moment")
                         strategy.hybrid_configs["mp_configs"].sync_moment = True
-                        strategy.hybrid_configs["mp_configs"].sync_moment_name = [""]
+
             except:
                 warnings.warn(
                     "The enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add are not supported "
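With this change the three `mp_configs` switches are set independently, each one is logged, and the match-all `sync_param_name = [""]` is applied once when any of them is enabled instead of being written per flag. For reference, a hedged sketch of how such flags could be derived from a `tensor_parallel_config`-style option string; the helper name and option spellings are assumptions for illustration, not taken from the diff:

```python
def parse_sync_flags(parallel_config: str):
    """Illustrative parser: split a comma/space separated config string and
    report which of the three sync switches are requested."""
    opts = set(parallel_config.replace(",", " ").split())
    return "sync_param" in opts, "sync_grad" in opts, "sync_moment" in opts


sync_param, sync_grad, sync_moment = parse_sync_flags("enable_mp_async_allreduce sync_param sync_moment")
assert (sync_param, sync_grad, sync_moment) == (True, False, True)
```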
diff --git a/paddlenlp/trainer/utils/reshard/sharding_v2.py b/paddlenlp/trainer/utils/reshard/sharding_v2.py
index c7ae3df4fd2c..d5df4666ab09 100644
--- a/paddlenlp/trainer/utils/reshard/sharding_v2.py
+++ b/paddlenlp/trainer/utils/reshard/sharding_v2.py
@@ -14,11 +14,14 @@
 
 import numpy as np
 import paddle
+import paddle.distributed.fleet as fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
     HybridParallelOptimizer,
 )
 from paddle.distributed.fleet.model import PipelineParallel
 
+from paddlenlp.utils.log import logger
+
 from ....transformers.model_utils import unwrap_optimizer
 
 try:
@@ -29,6 +32,9 @@
     DygraphShardingOptimizerV2 = None
 
 
+from paddle.distributed.communication.reduce import ReduceOp
+
+
 def shard(node_model_state, model, optimizer, hcg):
     assert DygraphShardingOptimizerV2 is not None
     group = hcg.get_sharding_parallel_group()
@@ -137,7 +143,7 @@ def slice_tensor(tensor, begin, end):
     return tensor[begin:end]
 
 
-def collect_split_info(optimizer, model):
+def collect_split_info(optimizer, model, only_return_lengths=False):
     split_infos = {}
 
     def gather_infos(comm_buffer):
@@ -146,7 +152,13 @@ def gather_infos(comm_buffer):
             padded_size = v._padded_size
             buffer_size = v._param_buffer._numel()
             has_slice_grad = v._slice_grad is not None
-            split_infos[k] = (index, padded_size, buffer_size, has_slice_grad)
+            if only_return_lengths:
+                if v._param_begin < v._param_end:
+                    split_infos[k] = v._param_end - v._param_begin
+                else:
+                    split_infos[k] = None
+            else:
+                split_infos[k] = (index, padded_size, buffer_size, has_slice_grad)
 
     if isinstance(model, PipelineParallel) and model._sharding_comm_overlap > 0:
         optimizer = unwrap_optimizer(optimizer, HybridParallelOptimizer)
@@ -167,6 +179,51 @@ def gather_infos(comm_buffer):
     return split_infos
 
 
+def is_matched_optimizer_state_dict(opt_state_dict, optimizer, model, hcg=None, need_allgather=True):
+    split_infos = collect_split_info(optimizer, model, only_return_lengths=True)
+    master_weights = opt_state_dict.get("master_weights", None)
+
+    def get_matched_length(name):
+        if master_weights and name in master_weights:
+            tensor = master_weights[name]
+        else:
+            moment_name = name + "_moment1_0"
+            if moment_name not in opt_state_dict:
+                return None
+
+            tensor = opt_state_dict[moment_name]
+        if isinstance(tensor, (list, tuple)):
+            assert len(tensor) == 2, tensor
+            assert isinstance(tensor[0], str), tensor[0]
+            tensor = tensor[1]
+        shape = tensor.shape
+        assert len(shape) == 1, shape
+        length = shape[0]
+        return length
+
+    is_matched = 1
+    for k, length in split_infos.items():
+        matched_length = get_matched_length(k)
+        if length != matched_length:
+            is_matched = 0
+            break
+
+    if need_allgather:
+        if hcg is None:
+            hcg = fleet.get_hybrid_communicate_group()
+        group = hcg.get_sharding_parallel_group()
+        if group is not None and group.nranks > 1:
+            x = paddle.to_tensor([is_matched], dtype=paddle.int32)
+            paddle.distributed.stream.all_reduce(x, op=ReduceOp.MIN, group=group, sync_op=True, use_calc_stream=True)
+            global_is_matched = int(x.numpy()[0])
+        else:
+            global_is_matched = is_matched
+
+    global_is_matched = True if global_is_matched else False
+    logger.info(f"Sharding reshard checkpoint: local_match = {is_matched} , global_match = {global_is_matched}")
+    return global_is_matched
+
+
 def is_bata(name):
     if "_beta1_pow_acc_" in name:
         return True
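`is_matched_optimizer_state_dict` reduces to a simple rule: for every parameter slice this rank owns, the checkpoint must hold a 1-D tensor of exactly the local slice length, found either in `master_weights` or under the `<name>_moment1_0` key; a single mismatch marks the rank unmatched, and the `ReduceOp.MIN` all-reduce makes the verdict unanimous across the sharding group. A minimal sketch of the local comparison, using plain lists in place of tensors (`local_lengths_match` and the sample names are illustrative, not part of the patch):

```python
def local_lengths_match(local_slice_lengths, opt_state_dict):
    """local_slice_lengths: {param_name: expected 1-D length, or None if this rank owns no slice}
    opt_state_dict: checkpoint dict, optionally containing a "master_weights" sub-dict."""
    master_weights = opt_state_dict.get("master_weights") or {}
    for name, expected in local_slice_lengths.items():
        if name in master_weights:
            tensor = master_weights[name]
        else:
            tensor = opt_state_dict.get(name + "_moment1_0")
        actual = len(tensor) if tensor is not None else None
        if actual != expected:
            return False  # one mismatch is enough to force a reshard
    return True


ckpt = {"master_weights": {"linear_0.w_0": [0.0] * 128}}
assert local_lengths_match({"linear_0.w_0": 128}, ckpt)
assert not local_lengths_match({"linear_0.w_0": 64}, ckpt)
```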
diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py
index 2d3d34c82d28..59ad5e5e578e 100644
--- a/paddlenlp/trainer/utils/sharding_io.py
+++ b/paddlenlp/trainer/utils/sharding_io.py
@@ -40,7 +40,7 @@
 from paddlenlp.utils.log import logger
 
 from . import reshard as reshard_util
-from .reshard import SHARDING_STRATEGY_V1, pp_reshard
+from .reshard import SHARDING_STRATEGY_V1, SHARDING_STRATEGY_V2, pp_reshard
 
 # Name of the files used for checkpointing
 TRAINING_ARGS_NAME = "training_args.bin"
@@ -204,10 +204,21 @@ def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimize
         path = os.path.join(checkpoint, optimizer_name)
         logger.info(f"load optimizer state from {path}")
         if os.path.isfile(path):
-            return paddlenlp_load(path, map_location="cpu")
+            return self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu"))
         logger.info(f"{path} not exists")
         return None
 
+    def _modify_ckpt_for_compatibility(self, ckpt):
+        master_weights = ckpt.get("master_weights", None)
+        if master_weights:
+            for k, v in master_weights.items():
+                assert isinstance(v, paddle.Tensor), v
+                if not v.name.startswith(k):
+                    new_name = k + "_fp32_master_0"
+                    logger.info(f"Modify master weights {v.name} -> {new_name}")
+                    v.name = new_name
+        return ckpt
+
     def _need_reshard(self, checkpoint):
         if self._need_reshard_pp(checkpoint):
             return True
@@ -253,10 +264,6 @@ def _need_reshard_pp(self, checkpoint):
 
     def load_optimizer_state_with_reshard(self, checkpoint, base_opt_name, model_wrapped):
         """load state_dict of multiple shard from_checkpoint, Only load model state dict."""
-        if not self._need_reshard(checkpoint):
-            logger.info("do not need reshard")
-            return self._load_optimizer_state_of_one_shard(checkpoint, base_opt_name, self.args.optimizer_name_suffix)
-        logger.info("reshard optimizer state")
         parallel_config = self._load_distributed_strategy(checkpoint)
         sharding_meta = self._load_sharding_meta(checkpoint, 0)
         pp_degree = parallel_config["pp_degree"]
@@ -276,6 +283,26 @@ def load_optimizer_state_with_reshard(self, checkpoint, base_opt_name, model_wra
         cur_sharding_degree = self.args.sharding_parallel_degree
         cur_sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer)
 
+        if not self._need_reshard(checkpoint):
+            one_shard_opt_state_dict = self._load_optimizer_state_of_one_shard(
+                checkpoint, base_opt_name, self.args.optimizer_name_suffix
+            )
+
+            if sharding_strategy == SHARDING_STRATEGY_V2 and cur_sharding_strategy == SHARDING_STRATEGY_V2:
+                is_matched = reshard_util.sharding_v2.is_matched_optimizer_state_dict(
+                    one_shard_opt_state_dict, self.optimizer, model_wrapped
+                )
+            else:
+                is_matched = True
+
+            if is_matched:
+                logger.info("do not need reshard")
+                return one_shard_opt_state_dict
+        else:
+            one_shard_opt_state_dict = None
+
+        logger.info("reshard optimizer state")
+
         def load_model_slices():
             model_state = reshard_util.NodeModelState()
             for j in range(self.args.pipeline_parallel_rank, pp_degree, cur_pp_degree):
@@ -283,9 +310,14 @@ def load_model_slices():
                 assert "structure_name_mapping" in cur_sharding_meta
                 structure_name_map = cur_sharding_meta["structure_name_mapping"]
                 for i in range(self.args.sharding_parallel_rank, sharding_degree, cur_sharding_degree):
-                    tmp = self._load_optimizer_state_of_one_shard(
-                        checkpoint, base_opt_name, self.args.sharded_name_suffix(i, j)
-                    )
+                    sharded_name_suffix = self.args.sharded_name_suffix(i, j)
+                    if one_shard_opt_state_dict is None:
+                        tmp = self._load_optimizer_state_of_one_shard(checkpoint, base_opt_name, sharded_name_suffix)
+                    else:
+                        assert (
+                            self.args.optimizer_name_suffix == sharded_name_suffix
+                        ), f"{self.args.optimizer_name_suffix} vs {sharded_name_suffix}"
+                        tmp = one_shard_opt_state_dict
                     node_model_state_tmp = reshard_util.NodeModelState()
                     node_model_state_tmp.add_opts(tmp)
                     node_model_state_tmp.pack_keys(structure_name_map)
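The load path now reuses the shard it has already read: if the layout matches it is returned directly, otherwise it is handed to the reshard loop in place of a second disk read, and `_modify_ckpt_for_compatibility` first renames master weights from older checkpoints whose tensor names no longer start with the structure key. A minimal sketch of that rename on plain objects, so it runs without Paddle (`SimpleTensor` is an illustrative stand-in for `paddle.Tensor`):

```python
class SimpleTensor:  # stand-in for paddle.Tensor; only .name is needed for the sketch
    def __init__(self, name):
        self.name = name


def modify_ckpt_for_compatibility(ckpt):
    """Rename master weights whose tensor name does not start with the parameter key,
    mirroring the method added above (new name = key + "_fp32_master_0")."""
    master_weights = ckpt.get("master_weights", None)
    if master_weights:
        for k, v in master_weights.items():
            if not v.name.startswith(k):
                v.name = k + "_fp32_master_0"
    return ckpt


ckpt = {"master_weights": {"linear_0.w_0": SimpleTensor("tmp_5")}}
modify_ckpt_for_compatibility(ckpt)
assert ckpt["master_weights"]["linear_0.w_0"].name == "linear_0.w_0_fp32_master_0"
```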