
Commit

update
DesmonDay committed Aug 14, 2024
1 parent 2097916 commit d1d7bc1
Showing 2 changed files with 19 additions and 16 deletions.
9 changes: 2 additions & 7 deletions paddlenlp/trainer/trainer.py
@@ -89,7 +89,6 @@
 from ..transformers.model_utils import (
     PretrainedModel,
     _add_variant,
-    _load_state_dict_into_model,
     load_sharded_checkpoint,
     unwrap_model,
 )
@@ -1164,9 +1163,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
         safe_serialization=True,
     )
     if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
-        state_dict = broadcast_dataset_rank0_model(self.model.state_dict())
-        if self.args.dataset_rank > 0:
-            _load_state_dict_into_model(self.model, state_dict, "")
+        broadcast_dataset_rank0_model(self.model)
 else:
     weight_name = PADDLE_WEIGHTS_NAME
     best_model_path = os.path.join(
@@ -1210,9 +1207,7 @@ def _load_best_model_from_peft_checkpoint(self):
         safe_serialization=True,
     )
     if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
-        state_dict = broadcast_dataset_rank0_model(self.model.get_trainable_state_dict())
-        if self.args.dataset_rank > 0:
-            _load_state_dict_into_model(self.model, state_dict, "")
+        broadcast_dataset_rank0_model(self.model)
     return

 convert_tp = False
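The trainer.py change is the same in both hunks: the old code broadcast a copy of rank 0's state dict and then reloaded it on every non-zero dataset rank via _load_state_dict_into_model, while the new code hands the model object to the helper, which synchronizes its parameters in place, so the dataset_rank branch and the import are no longer needed. A minimal before/after sketch of the pattern, taken from the hunks above with the surrounding trainer context omitted:

    # Before: broadcast a copy of the state dict, then reload it on non-source dataset ranks.
    if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
        state_dict = broadcast_dataset_rank0_model(self.model.state_dict())
        if self.args.dataset_rank > 0:
            _load_state_dict_into_model(self.model, state_dict, "")

    # After: the helper syncs the model's parameters in place, so no explicit reload step is needed.
    if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
        broadcast_dataset_rank0_model(self.model)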
26 changes: 17 additions & 9 deletions paddlenlp/trainer/utils/helper.py
@@ -23,9 +23,9 @@
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from paddle.distributed.parallel import sync_params_buffers

 from paddlenlp.utils.log import logger
-from paddlenlp.utils.nested import nested_broadcast_tensor_with_empty  # noqa: F401
 from paddlenlp.utils.nested import (
     nested_broadcast_tensor,
     nested_empty_tensor,
@@ -311,19 +311,27 @@ def _broadcast_moe_optimizer_state(state_dict):
     return state_dict


-def broadcast_dataset_rank0_model(state_dict):
+def broadcast_dataset_rank0_model(model):
     if paddle.distributed.get_world_size() <= 1:
-        return state_dict
+        return

     logger.info("Start broadcast model in sharding group or data parallel group.")
     hcg = fleet.get_hybrid_communicate_group()
     sharding_group = hcg.get_sharding_parallel_group()
     dp_group = hcg.get_data_parallel_group()

     if sharding_group.nranks > 1:
-        for k in state_dict.keys():
-            dist.broadcast(state_dict[k], src=hcg.get_sharding_parallel_group_src_rank(), group=sharding_group)
+        sync_params_buffers(
+            model,
+            sharding_group,
+            hcg.get_sharding_parallel_group_src_rank(),
+            is_model_parallel=False,
+            fuse_params=False,
+        )
     if dp_group.nranks > 1:
-        for k in state_dict.keys():
-            dist.broadcast(state_dict[k], src=hcg.get_data_parallel_group_src_rank(), group=dp_group)
-    return state_dict
+        sync_params_buffers(
+            model,
+            dp_group,
+            hcg.get_data_parallel_group_src_rank(),
+            is_model_parallel=False,
+            fuse_params=False,
+        )

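Putting the helper.py hunks together, the rewritten helper reads roughly as follows (a reconstruction from the diff above with indentation restored; sync_params_buffers from paddle.distributed.parallel broadcasts the model's parameters and buffers in place from the given source rank of a communication group, which is why the function now takes the model itself and no longer builds or returns a state dict):

    def broadcast_dataset_rank0_model(model):
        # Nothing to broadcast in a single-process run.
        if paddle.distributed.get_world_size() <= 1:
            return

        logger.info("Start broadcast model in sharding group or data parallel group.")
        hcg = fleet.get_hybrid_communicate_group()
        sharding_group = hcg.get_sharding_parallel_group()
        dp_group = hcg.get_data_parallel_group()

        # Sync the model in place from the source rank of each group that has more than one rank.
        if sharding_group.nranks > 1:
            sync_params_buffers(
                model,
                sharding_group,
                hcg.get_sharding_parallel_group_src_rank(),
                is_model_parallel=False,
                fuse_params=False,
            )
        if dp_group.nranks > 1:
            sync_params_buffers(
                model,
                dp_group,
                hcg.get_data_parallel_group_src_rank(),
                is_model_parallel=False,
                fuse_params=False,
            )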