fix compatible with npu. (#8409)
ZHUI authored May 9, 2024
1 parent 2619f17 commit 9146c1e
Showing 1 changed file with 8 additions and 5 deletions.
paddlenlp/trainer/plugins/unified_checkpoint.py

@@ -93,6 +93,11 @@
 async_save_queue = []


+DEST_PLACE = paddle.CPUPlace()
+if paddle.device.is_compiled_with_cuda():
+    DEST_PLACE = paddle.CUDAPinnedPlace()
+
+
 class UnifiedCheckpointOption(ExplicitEnum):
     """
     "- skip_save_model_weight: do not save model weights when the masters weight exist\n"
@@ -1746,7 +1751,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

             if is_dst:
                 state_dict_to_save[key] = tensor
@@ -1777,15 +1782,13 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
             if model_key in tp_actions:
                 # for example: beta1, beta2
                 if tensor.numel().item() == 1:
-                    tensor = (
-                        tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
-                    )  # Need broadcast when loaded
+                    tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
                 else:
                     ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

             if is_dst:
                 state_dict_to_save[filter_keys[i]] = tensor
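The fix boils down to one pattern: pick the host-side destination place once, preferring CUDA pinned memory when Paddle is built with CUDA and falling back to plain CPU memory otherwise (e.g. on NPU builds). Below is a minimal, self-contained sketch of that pattern; the stage_for_save helper is hypothetical, added only to mirror the copy logic in the diff.

import paddle

# Choose the host-side destination place once, at import time.
# CUDAPinnedPlace is only valid on CUDA builds; other builds
# (e.g. NPU) fall back to ordinary CPU memory.
DEST_PLACE = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
    DEST_PLACE = paddle.CUDAPinnedPlace()


def stage_for_save(tensor, is_dst):
    # Mirrors the copy logic in the diff (Tensor._copy_to is the internal
    # API the checkpoint code already uses): only the destination rank
    # keeps a host-side copy; other ranks return None.
    return tensor._copy_to(DEST_PLACE, False) if is_dst else None

Because paddle.device.is_compiled_with_cuda() returns False on non-CUDA builds, merge_tensor_parallel_with_shard and merge_tensor_parallel_for_optimizer no longer assume CUDA pinned memory is available.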
