From d1d7bc1eb0c23401bc7f4fb1716d4fa53e26011a Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Wed, 14 Aug 2024 21:01:48 +0800
Subject: [PATCH] update

---
 paddlenlp/trainer/trainer.py      |  9 ++-------
 paddlenlp/trainer/utils/helper.py | 26 +++++++++++++++++---------
 2 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 9ecfb137d688..25eb100e0c1c 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -89,7 +89,6 @@
 from ..transformers.model_utils import (
     PretrainedModel,
     _add_variant,
-    _load_state_dict_into_model,
     load_sharded_checkpoint,
     unwrap_model,
 )
@@ -1164,9 +1163,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                         safe_serialization=True,
                     )
                     if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
-                        state_dict = broadcast_dataset_rank0_model(self.model.state_dict())
-                        if self.args.dataset_rank > 0:
-                            _load_state_dict_into_model(self.model, state_dict, "")
+                        broadcast_dataset_rank0_model(self.model)
                 else:
                     weight_name = PADDLE_WEIGHTS_NAME
                     best_model_path = os.path.join(
@@ -1210,9 +1207,7 @@ def _load_best_model_from_peft_checkpoint(self):
                 safe_serialization=True,
             )
             if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
-                state_dict = broadcast_dataset_rank0_model(self.model.get_trainable_state_dict())
-                if self.args.dataset_rank > 0:
-                    _load_state_dict_into_model(self.model, state_dict, "")
+                broadcast_dataset_rank0_model(self.model)
             return
 
         convert_tp = False
diff --git a/paddlenlp/trainer/utils/helper.py b/paddlenlp/trainer/utils/helper.py
index 7f0e87d0f9e2..419ff5e62dca 100644
--- a/paddlenlp/trainer/utils/helper.py
+++ b/paddlenlp/trainer/utils/helper.py
@@ -23,9 +23,9 @@
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from paddle.distributed.parallel import sync_params_buffers
 
 from paddlenlp.utils.log import logger
-from paddlenlp.utils.nested import nested_broadcast_tensor_with_empty  # noqa: F401
 from paddlenlp.utils.nested import (
     nested_broadcast_tensor,
     nested_empty_tensor,
@@ -311,19 +311,27 @@ def _broadcast_moe_optimizer_state(state_dict):
     return state_dict
 
 
-def broadcast_dataset_rank0_model(state_dict):
+def broadcast_dataset_rank0_model(model):
     if paddle.distributed.get_world_size() <= 1:
-        return state_dict
+        return
 
     logger.info("Start broadcast model in sharding group or data parallel group.")
     hcg = fleet.get_hybrid_communicate_group()
     sharding_group = hcg.get_sharding_parallel_group()
     dp_group = hcg.get_data_parallel_group()
-    if sharding_group.nranks > 1:
-        for k in state_dict.keys():
-            dist.broadcast(state_dict[k], src=hcg.get_sharding_parallel_group_src_rank(), group=sharding_group)
+    sync_params_buffers(
+        model,
+        sharding_group,
+        hcg.get_sharding_parallel_group_src_rank(),
+        is_model_parallel=False,
+        fuse_params=False,
+    )
     if dp_group.nranks > 1:
-        for k in state_dict.keys():
-            dist.broadcast(state_dict[k], src=hcg.get_data_parallel_group_src_rank(), group=dp_group)
-    return state_dict
+        sync_params_buffers(
+            model,
+            dp_group,
+            hcg.get_data_parallel_group_src_rank(),
+            is_model_parallel=False,
+            fuse_params=False,
+        )
 