
Commit b18ca2e

xpu use allgather
FeixLiu committed Jul 2, 2024
1 parent a53477c commit b18ca2e
Showing 1 changed file with 6 additions and 2 deletions.
paddlenlp/transformers/conversion_utils.py (8 changes: 6 additions & 2 deletions)
@@ -37,7 +37,7 @@
 from paddle import Tensor
 from paddle.nn import Layer
 
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import CONFIG_NAME, PADDLE_WEIGHTS_NAME, PYTORCH_WEIGHTS_NAME
 from paddlenlp.utils.import_utils import (
     is_package_available,
@@ -46,6 +46,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.serialization import load_torch
+from paddlenlp.utils.tools import get_env_device
 
 if TYPE_CHECKING:
     from paddlenlp.transformers import PretrainedConfig, PretrainedModel
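The two added imports feed the hunk below: get_env_device probes the active backend, and distributed_allgather is the collective used on XPU. A minimal sketch of the probe follows; only the "xpu" return value is grounded in this diff, while "gpu"/"npu"/"cpu" are assumed sibling values that may vary across PaddleNLP versions.

from paddlenlp.utils.tools import get_env_device

# get_env_device() returns a short backend string identifying the device
# paddle is running on; "xpu" is the value this commit branches on.
if get_env_device() == "xpu":
    print("Kunlun XPU backend: merge_tensor_parallel will all-gather shards")
else:
    print("non-XPU backend: merge_tensor_parallel will gather to the dst rank")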
@@ -1269,7 +1270,10 @@ def merge_tensor_parallel(cls, state_dict, config) -> None:
         for key in state_dict.keys():
             tensor = state_dict[key]
             if key in name_action_mappings:
-                ret = distributed_gather(tensor, group=mp_group, offload=True)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
+                else:
+                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                 action = name_action_mappings.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
Codecov / codecov/patch: added lines #L1273-L1274 and #L1276 were not covered by tests.
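Why the branch exists: distributed_gather delivers the shards only to the destination rank, which is all merge_tensor_parallel needs, since non-destination ranks immediately drop the result (tensor = action(ret) if is_dst else None). On XPU the commit switches to distributed_allgather, presumably because the XPU communication stack does not offer a plain gather collective; an all-gather returns the same shard list on every rank, and the extra copies are simply discarded. Below is a minimal sketch of that equivalence against the public paddle.distributed API; the helper name gather_via_allgather and its dst_rank parameter are illustrative, not part of PaddleNLP.

import paddle.distributed as dist

def gather_via_allgather(tensor, group=None, dst_rank=0):
    # Every rank contributes its shard and receives everyone else's ...
    shards = []
    dist.all_gather(shards, tensor, group=group)
    # ... but to reproduce gather semantics only the destination rank keeps
    # the list; all other ranks return None, mirroring the is_dst branch in
    # merge_tensor_parallel.
    return shards if dist.get_rank() == dst_rank else None

The trade-off is peak memory: with all-gather every rank briefly holds all tensor-parallel shards of the weight, whereas a true gather concentrates that cost on the destination rank alone.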

0 comments on commit b18ca2e
