Cherry pick some changes from incubate branch #8862

Merged
21 changes: 17 additions & 4 deletions paddlenlp/trainer/trainer.py
@@ -195,7 +195,17 @@
g_cpu_optimizer_state_dict = {}


def _save_func(obj, path, saved_signal_path, protocol):
def _save_func(obj, name_mapping, path, saved_signal_path, protocol):
if isinstance(obj, dict):
for k, v in obj.items():
if k == "master_weights" and isinstance(v, dict):
for kk, vv in v.items():
if isinstance(vv, paddle.Tensor):
vv.name = name_mapping["master_weights"][kk]
else:
if k in name_mapping and isinstance(v, paddle.Tensor):
v.name = name_mapping[k]

paddle.save(obj, path, protocol)
# dump saved_signal
with open(saved_signal_path, mode="w+") as f:
@@ -228,17 +238,18 @@ def clear_async_save_task_queue():
def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol=4):
global g_cpu_optimizer_state_dict
g_cpu_optimizer_state_dict.clear()
name_mapping = {"master_weights": {}}
for k, v in optimizer_state_dict.items():
if k == "master_weights":
g_cpu_optimizer_state_dict[k] = {}
for kk, vv in v.items():
tensor_name = vv.name
g_cpu_optimizer_state_dict[k][kk] = vv.pin_memory()
g_cpu_optimizer_state_dict[k][kk].name = tensor_name
name_mapping[k][kk] = vv.name
elif k == "LR_Scheduler":
g_cpu_optimizer_state_dict[k] = copy.deepcopy(v)
else:
g_cpu_optimizer_state_dict[k] = v.pin_memory()
name_mapping[k] = v.name
paddle.device.synchronize()
clear_async_save_task_queue()

@@ -248,7 +259,9 @@ def async_save_optimizer(optimizer_state_dict, path, saved_signal_path, protocol
def start_process():
nonlocal attempt
try:
p = ctx.Process(target=_save_func, args=(g_cpu_optimizer_state_dict, path, saved_signal_path, protocol))
p = ctx.Process(
target=_save_func, args=(g_cpu_optimizer_state_dict, name_mapping, path, saved_signal_path, protocol)
)
p.start()
return p
except Exception as e:
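The trainer.py change threads a name_mapping into the background save process because the pinned-memory copies handed to the child process can lose their original tensor names; _save_func then restores the names before paddle.save. Below is a minimal, framework-free sketch of that pattern, using a hypothetical FakeTensor class in place of paddle.Tensor; it illustrates the idea, not the exact PaddleNLP implementation.

import copy

class FakeTensor:
    """Hypothetical stand-in for paddle.Tensor; only carries a name and some data."""
    def __init__(self, name, data):
        self.name = name
        self.data = data

    def pin_memory(self):
        # Simulate a copy whose framework-assigned name differs from the original.
        return FakeTensor("anonymous", copy.deepcopy(self.data))


def build_cpu_copy_with_mapping(state_dict):
    # Mirrors the pattern in async_save_optimizer: copy tensors and record their names.
    cpu_state, name_mapping = {}, {"master_weights": {}}
    for k, v in state_dict.items():
        if k == "master_weights":
            cpu_state[k] = {kk: vv.pin_memory() for kk, vv in v.items()}
            name_mapping[k] = {kk: vv.name for kk, vv in v.items()}
        else:
            cpu_state[k] = v.pin_memory()
            name_mapping[k] = v.name
    return cpu_state, name_mapping


def restore_names(cpu_state, name_mapping):
    # Mirrors _save_func: put the recorded names back before serialization.
    for k, v in cpu_state.items():
        if k == "master_weights":
            for kk, vv in v.items():
                vv.name = name_mapping[k][kk]
        else:
            v.name = name_mapping[k]


if __name__ == "__main__":
    sd = {
        "linear_0.w_0_moment1_0": FakeTensor("linear_0.w_0_moment1_0", [0.1]),
        "master_weights": {"linear_0.w_0": FakeTensor("linear_0.w_0_fp32_master_0", [0.1])},
    }
    cpu_sd, mapping = build_cpu_copy_with_mapping(sd)
    restore_names(cpu_sd, mapping)
    assert cpu_sd["master_weights"]["linear_0.w_0"].name == "linear_0.w_0_fp32_master_0"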
14 changes: 11 additions & 3 deletions paddlenlp/trainer/training_args.py
@@ -1171,15 +1171,23 @@ def split_parallel_config(parallel_config):
# sync_param_name = [""] matches any parameter name.
# If sync_param, sync_grad and sync_moment are not set, the default value in Paddle is :
# sync_param = True, sync_grad = False, sync_moment = False, sync_param_name = ["embedding", "layer_norm", ".b_"].

if sync_param or sync_grad or sync_moment:
logger.info("setting sync_param_name")
strategy.sync_param_name = [""]

if sync_param:
logger.info("setting sync_param")
strategy.hybrid_configs["mp_configs"].sync_param = True
strategy.hybrid_configs["mp_configs"].sync_param_name = [""]

if sync_grad:
logger.info("setting sync_grad")
strategy.hybrid_configs["mp_configs"].sync_grad = True
strategy.hybrid_configs["mp_configs"].sync_grad_name = [""]

if sync_moment:
logger.info("setting sync_moment")
strategy.hybrid_configs["mp_configs"].sync_moment = True
strategy.hybrid_configs["mp_configs"].sync_moment_name = [""]

except:
warnings.warn(
"The enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add are not supported "
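The training_args.py change makes each sync option explicit: whenever sync_param, sync_grad or sync_moment is requested, the corresponding flag is turned on and its *_name filter is set to [""], which matches every parameter name instead of relying on Paddle's default name list. A minimal sketch of the resulting configuration, using a plain dictionary as a stand-in for strategy.hybrid_configs["mp_configs"] (the dict structure is a simplifying assumption, not the fleet API):

# Plain-dict stand-in for strategy.hybrid_configs["mp_configs"]; a simplification.
mp_configs = {}

# Example inputs parsed from the tensor-parallel config string.
sync_param, sync_grad, sync_moment = True, False, True

if sync_param:
    mp_configs["sync_param"] = True
    mp_configs["sync_param_name"] = [""]   # [""] matches any parameter name
if sync_grad:
    mp_configs["sync_grad"] = True
    mp_configs["sync_grad_name"] = [""]
if sync_moment:
    mp_configs["sync_moment"] = True
    mp_configs["sync_moment_name"] = [""]

print(mp_configs)
# {'sync_param': True, 'sync_param_name': [''], 'sync_moment': True, 'sync_moment_name': ['']}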
61 changes: 59 additions & 2 deletions paddlenlp/trainer/utils/reshard/sharding_v2.py
@@ -14,11 +14,14 @@

import numpy as np
import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
HybridParallelOptimizer,
)
from paddle.distributed.fleet.model import PipelineParallel

from paddlenlp.utils.log import logger

from ....transformers.model_utils import unwrap_optimizer

try:
@@ -29,6 +32,9 @@
DygraphShardingOptimizerV2 = None


from paddle.distributed.communication.reduce import ReduceOp


def shard(node_model_state, model, optimizer, hcg):
assert DygraphShardingOptimizerV2 is not None
group = hcg.get_sharding_parallel_group()
@@ -137,7 +143,7 @@ def slice_tensor(tensor, begin, end):
return tensor[begin:end]


def collect_split_info(optimizer, model):
def collect_split_info(optimizer, model, only_return_lengths=False):
split_infos = {}

def gather_infos(comm_buffer):
@@ -146,7 +152,13 @@ def gather_infos(comm_buffer):
padded_size = v._padded_size
buffer_size = v._param_buffer._numel()
has_slice_grad = v._slice_grad is not None
split_infos[k] = (index, padded_size, buffer_size, has_slice_grad)
if only_return_lengths:
if v._param_begin < v._param_end:
split_infos[k] = v._param_end - v._param_begin
else:
split_infos[k] = None
else:
split_infos[k] = (index, padded_size, buffer_size, has_slice_grad)

if isinstance(model, PipelineParallel) and model._sharding_comm_overlap > 0:
optimizer = unwrap_optimizer(optimizer, HybridParallelOptimizer)
@@ -167,6 +179,51 @@ def gather_infos(comm_buffer):
return split_infos


def is_matched_optimizer_state_dict(opt_state_dict, optimizer, model, hcg=None, need_allgather=True):
split_infos = collect_split_info(optimizer, model, only_return_lengths=True)
master_weights = opt_state_dict.get("master_weights", None)

def get_matched_length(name):
if master_weights and name in master_weights:
tensor = master_weights[name]
else:
moment_name = name + "_moment1_0"
if moment_name not in opt_state_dict:
return None

tensor = opt_state_dict[moment_name]
if isinstance(tensor, (list, tuple)):
assert len(tensor) == 2, tensor
assert isinstance(tensor[0], str), tensor[0]
tensor = tensor[1]
shape = tensor.shape
assert len(shape) == 1, shape
length = shape[0]
return length

is_matched = 1
for k, length in split_infos.items():
matched_length = get_matched_length(k)
if length != matched_length:
is_matched = 0
break

if need_allgather:
if hcg is None:
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_sharding_parallel_group()
if group is not None and group.nranks > 1:
x = paddle.to_tensor([is_matched], dtype=paddle.int32)
paddle.distributed.stream.all_reduce(x, op=ReduceOp.MIN, group=group, sync_op=True, use_calc_stream=True)
global_is_matched = int(x.numpy()[0])
else:
global_is_matched = is_matched

global_is_matched = bool(global_is_matched)
logger.info(f"Sharding reshard checkpoint: local_match = {is_matched}, global_match = {global_is_matched}")
return global_is_matched


def is_bata(name):
if "_beta1_pow_acc_" in name:
return True
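is_matched_optimizer_state_dict decides whether a loaded shard can be reused as-is: for every locally owned parameter slice it compares the expected slice length (from collect_split_info with only_return_lengths=True) against the length stored in the checkpoint, preferring the master weight and falling back to the *_moment1_0 tensor, and then all-reduces the per-rank verdict with ReduceOp.MIN so every rank agrees. A minimal sketch of the length check, using plain lists in place of Paddle tensors and omitting the collective; the helper names below are illustrative, not the PR's API.

def matched_length(opt_state, master_weights, name):
    """Return the 1-D length recorded for `name` in a loaded shard, or None if absent."""
    if master_weights and name in master_weights:
        tensor = master_weights[name]
    else:
        tensor = opt_state.get(name + "_moment1_0")
        if tensor is None:
            return None
    return len(tensor)  # stands in for tensor.shape[0]


def local_shard_matches(expected_lengths, opt_state):
    """expected_lengths: {param_name: local slice length or None}, as from collect_split_info."""
    master_weights = opt_state.get("master_weights", {})
    for name, length in expected_lengths.items():
        if length != matched_length(opt_state, master_weights, name):
            return False
    # The real function then all-reduces this 0/1 verdict with ReduceOp.MIN over the
    # sharding group so every rank takes the same reshard/no-reshard path.
    return True


expected = {"linear_0.w_0": 4096, "linear_0.b_0": None}
shard = {"master_weights": {"linear_0.w_0": [0.0] * 4096}}
print(local_shard_matches(expected, shard))  # True: weight length matches, bias slice absent on both sides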
50 changes: 41 additions & 9 deletions paddlenlp/trainer/utils/sharding_io.py
@@ -40,7 +40,7 @@
from paddlenlp.utils.log import logger

from . import reshard as reshard_util
from .reshard import SHARDING_STRATEGY_V1, pp_reshard
from .reshard import SHARDING_STRATEGY_V1, SHARDING_STRATEGY_V2, pp_reshard

# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
@@ -204,10 +204,21 @@ def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimize
path = os.path.join(checkpoint, optimizer_name)
logger.info(f"load optimizer state from {path}")
if os.path.isfile(path):
return paddlenlp_load(path, map_location="cpu")
return self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu"))
logger.info(f"{path} not exists")
return None

def _modify_ckpt_for_compatibility(self, ckpt):
master_weights = ckpt.get("master_weights", None)
if master_weights:
for k, v in master_weights.items():
assert isinstance(v, paddle.Tensor), v
if not v.name.startswith(k):
new_name = k + "_fp32_master_0"
logger.info(f"Modify master weights {v.name} -> {new_name}")
v.name = new_name
return ckpt

def _need_reshard(self, checkpoint):
if self._need_reshard_pp(checkpoint):
return True
@@ -253,10 +264,6 @@ def _need_reshard_pp(self, checkpoint):
def load_optimizer_state_with_reshard(self, checkpoint, base_opt_name, model_wrapped):
"""load state_dict of multiple shard from_checkpoint, Only load model state dict."""

if not self._need_reshard(checkpoint):
logger.info("do not need reshard")
return self._load_optimizer_state_of_one_shard(checkpoint, base_opt_name, self.args.optimizer_name_suffix)
logger.info("reshard optimizer state")
parallel_config = self._load_distributed_strategy(checkpoint)
sharding_meta = self._load_sharding_meta(checkpoint, 0)
pp_degree = parallel_config["pp_degree"]
@@ -276,16 +283,41 @@ def load_model_slices():
cur_sharding_degree = self.args.sharding_parallel_degree
cur_sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer)

if not self._need_reshard(checkpoint):
one_shard_opt_state_dict = self._load_optimizer_state_of_one_shard(
checkpoint, base_opt_name, self.args.optimizer_name_suffix
)

if sharding_strategy == SHARDING_STRATEGY_V2 and cur_sharding_strategy == SHARDING_STRATEGY_V2:
is_matched = reshard_util.sharding_v2.is_matched_optimizer_state_dict(
one_shard_opt_state_dict, self.optimizer, model_wrapped
)
else:
is_matched = True

if is_matched:
logger.info("do not need reshard")
return one_shard_opt_state_dict
else:
one_shard_opt_state_dict = None

logger.info("reshard optimizer state")

def load_model_slices():
model_state = reshard_util.NodeModelState()
for j in range(self.args.pipeline_parallel_rank, pp_degree, cur_pp_degree):
cur_sharding_meta = self._load_sharding_meta(checkpoint, j)
assert "structure_name_mapping" in cur_sharding_meta
structure_name_map = cur_sharding_meta["structure_name_mapping"]
for i in range(self.args.sharding_parallel_rank, sharding_degree, cur_sharding_degree):
tmp = self._load_optimizer_state_of_one_shard(
checkpoint, base_opt_name, self.args.sharded_name_suffix(i, j)
)
sharded_name_suffix = self.args.sharded_name_suffix(i, j)
if one_shard_opt_state_dict is None:
tmp = self._load_optimizer_state_of_one_shard(checkpoint, base_opt_name, sharded_name_suffix)
else:
assert (
self.args.optimizer_name_suffix == sharded_name_suffix
), f"{self.args.optimizer_name_suffix} vs {sharded_name_suffix}"
tmp = one_shard_opt_state_dict
node_model_state_tmp = reshard_util.NodeModelState()
node_model_state_tmp.add_opts(tmp)
node_model_state_tmp.pack_keys(structure_name_map)
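In sharding_io.py, _modify_ckpt_for_compatibility renames master-weight tensors whose stored names no longer start with their structural key, restoring the "<param>_fp32_master_0" form that the rest of the loading path expects, and the reshard path now reuses the already loaded shard instead of reading it a second time. A small sketch of the rename rule; the Named class is a hypothetical stand-in for paddle.Tensor, used only for illustration.

class Named:
    """Stand-in for paddle.Tensor: only the .name attribute matters here."""
    def __init__(self, name):
        self.name = name


def fix_master_weight_names(ckpt):
    master_weights = ckpt.get("master_weights") or {}
    for key, tensor in master_weights.items():
        if not tensor.name.startswith(key):
            # Old checkpoints may carry stale internal names; rebuild the
            # expected "<param>_fp32_master_0" form so later lookups succeed.
            tensor.name = key + "_fp32_master_0"
    return ckpt


ckpt = {"master_weights": {"linear_0.w_0": Named("tmp_12")}}
fix_master_weight_names(ckpt)
print(ckpt["master_weights"]["linear_0.w_0"].name)  # linear_0.w_0_fp32_master_0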