xpu save support (PaddlePaddle#8853)
FeixLiu authored Aug 1, 2024
1 parent 84d7845 commit 5c8b09c
Showing 5 changed files with 26 additions and 8 deletions.
paddlenlp/generation/utils.py (2 additions, 0 deletions)
@@ -1207,6 +1207,8 @@ def sample(
                 probs = TopPProcess(probs, top_p, min_tokens_to_keep)
             if paddle.device.is_compiled_with_custom_device("gcu"):
                 probs = paddle.cast(probs, "float32")
+            if paddle.device.is_compiled_with_xpu():
+                probs = paddle.cast(probs, "float32")
 
             # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
             next_tokens = paddle.multinomial(probs)
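The new branch mirrors the existing GCU handling: before sampling, the probabilities are upcast to float32 on XPU builds, presumably because the XPU multinomial kernel expects float32 input. A minimal, self-contained sketch of the pattern; the toy logits and shapes below are assumptions for illustration, not repository code:

import paddle

# Toy sampling step: softmax the logits, upcast when running on an XPU build,
# then draw one token id per row with paddle.multinomial.
logits = paddle.randn([2, 8], dtype="float32")
probs = paddle.nn.functional.softmax(logits, axis=-1)

if paddle.device.is_compiled_with_xpu():
    probs = paddle.cast(probs, "float32")  # no-op for float32 probs, needed when probs is fp16/bf16

next_tokens = paddle.multinomial(probs)  # shape [2, 1]: one sampled token id per row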
paddlenlp/peft/lora/lora_model.py (6 additions, 2 deletions)
@@ -41,9 +41,10 @@
     load_state_dict,
 )
 from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix
-from ...utils.distributed import distributed_gather
+from ...utils.distributed import distributed_allgather, distributed_gather
 from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME
 from ...utils.log import logger
+from ...utils.tools import get_env_device
 from .lora_config import LoRAConfig
 
 try:
@@ -285,7 +286,10 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict):
         for key in trainable_state_dict:
             tensor = trainable_state_dict[key]
             if key in trainable_name_action_mappings:
-                ret = distributed_gather(tensor, group=mp_group, offload=True)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
+                else:
+                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                 action = trainable_name_action_mappings[key]
                 if key in self.lora_split_mapping and not self.lora_split_mapping[key] and "_scale" in key and is_dst:
                     ret = paddle.to_tensor(ret)
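This is the recurring pattern of the commit: on XPU builds, distributed_gather (shards delivered only to a destination rank) is swapped for distributed_allgather (shards delivered to every rank), presumably because a point-to-point gather is not available on that backend. A hedged sketch of the two call shapes, meant to run under paddle.distributed.launch; mp_group and the toy shard are illustrative stand-ins for the values used in lora_model.py:

import paddle
import paddle.distributed as dist

from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
from paddlenlp.utils.tools import get_env_device

dist.init_parallel_env()
mp_group = dist.new_group(list(range(dist.get_world_size())))  # illustrative: all ranks
tensor = paddle.arange(4, dtype="float32") + dist.get_rank()    # per-rank shard (toy data)

if get_env_device() == "xpu":
    ret = distributed_allgather(tensor, group=mp_group, offload=True)  # every rank gets the full shard list
else:
    ret = distributed_gather(tensor, group=mp_group, offload=True)     # only the destination rank gets the shard list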
paddlenlp/trainer/plugins/unified_checkpoint.py (10 additions, 3 deletions)
@@ -41,7 +41,7 @@
     get_checkpoint_shard_files,
     is_safetensors_available,
 )
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
     PADDLE_MASTER_WEIGHTS_INDEX_NAME,
@@ -64,6 +64,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import nested_copy, nested_copy_place
+from paddlenlp.utils.tools import get_env_device
 
 if is_safetensors_available():
     # from safetensors import safe_open
@@ -1753,7 +1754,10 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             key = filter_keys[i]
             tensor = state_dict[key]
             if key in tp_actions:
-                ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                else:
+                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
@@ -1790,7 +1794,10 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
                 if tensor.numel().item() == 1:
                     tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
                 else:
-                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                    if get_env_device() == "xpu":
+                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                    else:
+                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
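Downstream of the new branch, nothing changes: the matching entry in tp_actions stitches the gathered shards back into one tensor, and only the destination rank keeps the result (tensor = action(ret) if is_dst else None). A small self-contained illustration of what such an action typically does; the concat axis and shapes are assumptions for the sketch, not taken from the repository:

import paddle

# Pretend these are the per-rank shards returned by distributed_gather/allgather.
shards = [paddle.full([4, 2], float(r)) for r in range(4)]

# A typical column-parallel merge action: concatenate along the split axis.
action = lambda parts: paddle.concat(parts, axis=-1)

merged = action(shards)
print(merged.shape)  # [4, 8]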
paddlenlp/transformers/conversion_utils.py (7 additions, 2 deletions)
@@ -37,7 +37,7 @@
 from paddle import Tensor
 from paddle.nn import Layer
 
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import CONFIG_NAME, PADDLE_WEIGHTS_NAME, PYTORCH_WEIGHTS_NAME
 from paddlenlp.utils.import_utils import (
     is_package_available,
@@ -50,6 +50,8 @@
 if TYPE_CHECKING:
     from paddlenlp.transformers import PretrainedConfig, PretrainedModel
 
+from paddlenlp.utils.tools import get_env_device
+
 from ..utils import device_guard
 
 # the type hinting for pytorch model & layer & tensor
@@ -1269,7 +1271,10 @@ def merge_tensor_parallel(cls, state_dict, config) -> None:
         for key in state_dict.keys():
             tensor = state_dict[key]
             if key in name_action_mappings:
-                ret = distributed_gather(tensor, group=mp_group, offload=True)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
+                else:
+                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                 action = name_action_mappings.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
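Each of these branches keys off get_env_device(), a helper in paddlenlp.utils.tools that reports the backend the current Paddle build targets as a short string (for example "gpu", "xpu", "npu", or "cpu"). A minimal usage sketch:

from paddlenlp.utils.tools import get_env_device

device = get_env_device()
if device == "xpu":
    print("Kunlun XPU build: tensor-parallel merging will use distributed_allgather")
else:
    print(f"{device} build: tensor-parallel merging will use distributed_gather")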
paddlenlp/utils/distributed.py (1 addition, 1 deletion)
@@ -214,7 +214,7 @@ def distributed_allgather(tensor: Any, group=None, offload=False):
             x.reshape_(origin_shape)
 
     else:
-        distributed.all_gather(output_tensors, tensor)
+        distributed.all_gather(output_tensors, tensor, group=group)
 
     return output_tensors

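The final hunk is a small correctness fix in distributed_allgather itself: the non-offload path accepted a group argument but never forwarded it, so the collective always ran on the default global group instead of the caller's subgroup. A hedged repro sketch of the corrected call, meant to run under paddle.distributed.launch with illustrative values:

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group(list(range(dist.get_world_size())))  # illustrative subgroup
tensor = paddle.full([2], float(dist.get_rank()))

output_tensors = []
dist.all_gather(output_tensors, tensor, group=group)  # the caller's group is now honored
print([t.numpy().tolist() for t in output_tensors])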
