xpu shard load support
FeixLiu committed Aug 5, 2024
commit fda78db (1 parent: 5c8b09c)
Showing 1 changed file with 6 additions and 2 deletions.
paddlenlp/trainer/utils/sharding_io.py
@@ -38,6 +38,7 @@
 )
 from paddlenlp.transformers.utils import paddlenlp_load
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.tools import get_env_device

 from . import reshard as reshard_util
 from .reshard import SHARDING_STRATEGY_V1, SHARDING_STRATEGY_V2, pp_reshard
@@ -457,7 +458,7 @@ def _recover_params_from_master_weights(self, state_dict, opt_state_dict=None):
         # cast to before
         for (k, v) in tmp.items():
             name = v.name
-            master_weights[k] = paddle.cast(v.cuda(), paddle.bfloat16).cpu()
+            master_weights[k] = paddle.cast(v.to(get_env_device()), paddle.bfloat16).cpu()
             master_weights[k].name = name

         structure_name_map = {k: v.name for (k, v) in self.model.state_dict().items()}
@@ -488,7 +489,10 @@ def filter_func(name):
         for key, param in model_state_dict.items():
             if param.name in master_weights:
                 assert param.shape == master_weights[param.name].shape
-                paddle.assign(paddle.cast(master_weights[param.name].cuda(), paddle.bfloat16), model_state_dict[key])
+                paddle.assign(
+                    paddle.cast(master_weights[param.name].to(get_env_device()), paddle.bfloat16),
+                    model_state_dict[key],
+                )
             elif key in state_dict:
                 logger.info(f"key: {key} is in state_dict, but not in master_weights")
                 paddle.assign(state_dict[key], model_state_dict[key])
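
For reference, a minimal sketch of the device-agnostic pattern this commit adopts: instead of hardcoding .cuda(), the tensor is moved to whatever device get_env_device() reports at runtime (e.g. "gpu" or "xpu"), cast to bfloat16, and copied back to host memory. The helper name cast_master_weight below is illustrative, not part of the patch:

    import paddle
    from paddlenlp.utils.tools import get_env_device

    def cast_master_weight(v):
        # Mirror the pattern used in _recover_params_from_master_weights:
        # get_env_device() resolves the current runtime device ("gpu", "xpu", ...),
        # so the same code path works on both CUDA and XPU machines.
        name = v.name
        out = paddle.cast(v.to(get_env_device()), paddle.bfloat16).cpu()
        out.name = name  # preserve the parameter name, as the original code does
        return out

Compared with v.cuda(), this keeps sharded checkpoint loading working on non-CUDA accelerators without any extra branching.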
