Skip to content

Commit

Permalink
[fea] moe support
Browse files Browse the repository at this point in the history
  • Loading branch information
bo-ke committed May 30, 2024
1 parent 0cd8fe7 commit 67981e5
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 32 deletions.
4 changes: 4 additions & 0 deletions docs/trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -719,4 +719,8 @@ Trainer 是一个简单,但功能完整的 Paddle训练和评估模块,并
Whether use flatten_param_grads method in optimizer,
only used on NPU devices.(default:False)
--use_expert_parallel
Whether to enable MoE (Mixture of Experts) expert parallel training.
(default: False)
```
64 changes: 42 additions & 22 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
from .utils import reshard as reshard_util
from .utils.helper import ( # nested_truncate,
broadcast_dp_optimizer,
broadcast_moe_optimizer,
distributed_concat,
distributed_file,
distributed_isfile,
Expand Down Expand Up @@ -930,22 +931,17 @@ def _inner_training_loop(
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
self.timers and self.timers("forward-backward").start()

dp_enabled = (
self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
)
forbidden_no_sync = False
# stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API
# hybrid_parallel (tp or pp or sharding stage 1) should not no_sync
if self.args.use_hybrid_parallel:
forbidden_no_sync = True

availiable_no_sync = dp_enabled and not forbidden_no_sync

availiable_no_sync = hasattr(model, "no_sync")

Check warning on line 936 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L936

Added line #L936 was not covered by tests
is_no_sync = (
((step_control + 1) % args.gradient_accumulation_steps != 0)
and availiable_no_sync
and args._no_sync_in_gradient_accumulation
) or (args.recompute and availiable_no_sync)
(
((step_control + 1) % args.gradient_accumulation_steps != 0)
and args._no_sync_in_gradient_accumulation
)
or args.recompute
or args.use_expert_parallel
) and availiable_no_sync
# sharding
# stage1. the same as ddp
# stage2. manualy collect gradient on dp group
Expand All @@ -965,6 +961,14 @@ def _inner_training_loop(

tr_loss += tr_loss_step

def fused_allreduce_gradients_no_sync(paramlist, hcg):
paramlist = list(paramlist)
nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
if moelist and not self.args.use_expert_parallel:
logger.warning("found `no sync` param when `use_expert_parallel=False`")
fused_allreduce_gradients(nonmoe_list, hcg)

Check warning on line 970 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L964-L970

Added lines #L964 - L970 were not covered by tests

if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
Expand All @@ -983,12 +987,12 @@ def _inner_training_loop(

# Case 1: Use recompute and dp / sharding stage1,
# manualy collect gradient for dp.
if args.recompute and availiable_no_sync:
fused_allreduce_gradients(list(model.parameters()), None)
if (args.recompute or args.use_expert_parallel) and availiable_no_sync:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

Check warning on line 991 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L990-L991

Added lines #L990 - L991 were not covered by tests

# Case 2: hack dp with master_grad
if dp_master_grad and not (args.recompute and availiable_no_sync):
fused_allreduce_gradients(list(model.parameters()), None)
elif dp_master_grad:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

Check warning on line 995 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L994-L995

Added lines #L994 - L995 were not covered by tests

# Pipeline parallel mode, handle gradient reduce here to overlap
pipeline_parallel_config = (
Expand All @@ -1007,8 +1011,7 @@ def _inner_training_loop(
self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)

fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)

Check warning on line 1014 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1014

Added line #L1014 was not covered by tests
self.timers and self.timers("all-reduce").stop()
self.timers and self.timers("optimizer-step").start()

Expand All @@ -1028,6 +1031,8 @@ def _inner_training_loop(
)
optimizer_was_run = True
if self.do_grad_scaling:
if args.pipeline_parallel_degree > 1:
assert not self.args.use_expert_parallel, "pipline moe not work under fp16"

Check warning on line 1035 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1034-L1035

Added lines #L1034 - L1035 were not covered by tests
scale_before = paddle.assign(self.scaler._scale)
self.scaler.step(self.optimizer)
self.scaler.update()
Expand Down Expand Up @@ -2042,7 +2047,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,

model.train()
inputs = self._prepare_inputs(inputs)

with self.autocast_smart_context_manager():
loss = self.compute_loss(model, inputs)

Expand All @@ -2053,7 +2057,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
self.scaler.scale(loss).backward()
else:
loss.backward()

return loss.detach()

def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
Expand Down Expand Up @@ -2143,6 +2146,17 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
# For ckpt integrity
paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))

def _save_moe_weights(
self,
output_dir: Optional[str] = None,
merge_tensor_parallel: Optional[bool] = False,
):
# save moe optimizer and model state # TODO 默认为冗余存储

self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name))

Check warning on line 2158 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2156-L2158

Added lines #L2156 - L2158 were not covered by tests

def _save_checkpoint(self, model, metrics=None):
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
self.runtime_timer.start("checkpoint saving time")
Expand Down Expand Up @@ -2245,6 +2259,8 @@ def _save_checkpoint(self, model, metrics=None):
os.makedirs(output_dir, exist_ok=True)
paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

if self.args.use_expert_parallel and self.args.data_parallel_rank > 0:
self._save_moe_weights(output_dir)

Check warning on line 2263 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2262-L2263

Added lines #L2262 - L2263 were not covered by tests
# Maybe delete some older checkpoints.
# For hybrid parallel training, the checkpoint files maybe on different node.
need_to_rotate_checkpoints = False
Expand Down Expand Up @@ -2476,7 +2492,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
# broadcast optimizer state in dp group
if self.args.local_rank != -1:
dist.barrier()
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
if self.args.use_expert_parallel:
opt_state_dict = broadcast_moe_optimizer(opt_state_dict)

Check warning on line 2496 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2495-L2496

Added lines #L2495 - L2496 were not covered by tests
else:
if not self.args.should_load_sharding_stage1_model:
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)

Check warning on line 2499 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2498-L2499

Added lines #L2498 - L2499 were not covered by tests

if opt_state_dict is not None:
# Load in optimizer and scheduler states
Expand Down
25 changes: 24 additions & 1 deletion paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,10 @@ class TrainingArguments:
default=False,
metadata={"help": "whether to run distributed training in auto parallel mode"},
)
use_expert_parallel: Optional[bool] = field(
default=False,
metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
Expand Down Expand Up @@ -1149,6 +1153,8 @@ def is_segment_parallel_supported():
order = ["dp", "sharding", "pp", "sep", "mp"]
else:
order = ["dp", "sharding", "pp", "mp"]
if self.use_expert_parallel:
order = order[1:-1] + ["dp", "mp"]

Check warning on line 1157 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1156-L1157

Added lines #L1156 - L1157 were not covered by tests

if is_segment_parallel_supported():
hybrid_configs = {
Expand Down Expand Up @@ -1640,8 +1646,12 @@ def optimizer_name_suffix(self):
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.sharding_parallel_degree > 1:
name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
if self.use_expert_parallel:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))

Check warning on line 1650 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1649-L1650

Added lines #L1649 - L1650 were not covered by tests
return "_".join(name)
else:
if self.use_expert_parallel:
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)

Check warning on line 1654 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1653-L1654

Added lines #L1653 - L1654 were not covered by tests
return None

@property
Expand All @@ -1652,12 +1662,16 @@ def weight_name_suffix(self):
name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
if self.pipeline_parallel_degree > 1:
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.use_expert_parallel:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))

Check warning on line 1666 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1665-L1666

Added lines #L1665 - L1666 were not covered by tests
return "_".join(name)

else:
if self.use_expert_parallel:

Check warning on line 1670 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1670

Added line #L1670 was not covered by tests
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
return None

def sharded_name_suffix(self, shard_id=None, pp_id=None):
def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):

Check warning on line 1674 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1674

Added line #L1674 was not covered by tests
if self.use_hybrid_parallel:
name = []
if self.tensor_parallel_degree > 1:
Expand All @@ -1672,8 +1686,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
shard_id = self.sharding_parallel_rank
assert isinstance(shard_id, int)
name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
if self.use_expert_parallel:
if moe_id is None:
moe_id = self.data_parallel_rank
assert isinstance(moe_id, int)
name.append(self._format_name("moe", moe_id, self.data_parallel_degree))

Check warning on line 1693 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1689-L1693

Added lines #L1689 - L1693 were not covered by tests
return "_".join(name)
else:
if self.use_expert_parallel:
if moe_id is None:

Check warning on line 1697 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1696-L1697

Added lines #L1696 - L1697 were not covered by tests
moe_id = self.data_parallel_rank
return self._format_name("moe", moe_id, self.data_parallel_degree)

Check warning on line 1699 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1699

Added line #L1699 was not covered by tests
return None

@property
Expand Down
56 changes: 56 additions & 0 deletions paddlenlp/trainer/utils/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,59 @@ def broadcast_dp_optimizer(state_dict):
state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)

return state_dict


def broadcast_moe_optimizer(state_dict):

try:
hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
src_rank = hcg.get_data_parallel_group_src_rank()
data_parallel_rank = hcg.get_data_parallel_rank()

Check warning on line 237 in paddlenlp/trainer/utils/helper.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/helper.py#L233-L237

Added lines #L233 - L237 were not covered by tests
# Don't broadcast optimizer for dp rank is 1.
if dp_group.nranks <= 1:
return state_dict
except:
dp_group = None
src_rank = 0
data_parallel_rank = 0

Check warning on line 244 in paddlenlp/trainer/utils/helper.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/helper.py#L239-L244

Added lines #L239 - L244 were not covered by tests

def _broadcast_moe_optimizer_state(state_dict):

Check warning on line 246 in paddlenlp/trainer/utils/helper.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/helper.py#L246

Added line #L246 was not covered by tests
# boardcast_keys
base_state_dict = {"master_weights": {}}
buf = [

Check warning on line 249 in paddlenlp/trainer/utils/helper.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/helper.py#L248-L249

Added lines #L248 - L249 were not covered by tests
{i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]},
{i: j.shape for i, j in state_dict["master_weights"].items()},
{"LR_Scheduler": state_dict.get("LR_Scheduler", {})},
]

dist.broadcast_object_list(buf, src=src_rank, group=dp_group)

Check warning on line 255 in paddlenlp/trainer/utils/helper.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/helper.py#L255

Added line #L255 was not covered by tests
# logger.info(f"moe-optimizer-gather-keys{buf}")
for k, s in buf[0].items():
v = state_dict.get(k, paddle.zeros(s, "float32")).cuda()
v.name = k

Check warning on line 259 in paddlenlp/trainer/utils/helper.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/helper.py#L257-L259

Added lines #L257 - L259 were not covered by tests
# k = k.replace("_fp32_master_0", "")
dist.broadcast(v, src=src_rank, group=dp_group)
logger.info(f"broadcast moe optimizer {k} from {src_rank}")
base_state_dict[k] = v.cpu()
for k, s in buf[1].items():
v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).cuda()
v.name = k
dist.broadcast(v, src=src_rank, group=dp_group)
logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}")
base_state_dict["master_weights"][k] = v.cpu()
base_state_dict.update(buf[2])
return base_state_dict

Check warning on line 271 in paddlenlp/trainer/utils/helper.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/helper.py#L261-L271

Added lines #L261 - L271 were not covered by tests

base_state_dict = _broadcast_moe_optimizer_state(state_dict)
if data_parallel_rank > 0:
master_weight = state_dict.pop("master_weights", {})
base_state_dict.update(state_dict)
if master_weight:
if "master_weights" in base_state_dict:
base_state_dict["master_weights"].update(master_weight)

Check warning on line 279 in paddlenlp/trainer/utils/helper.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/helper.py#L273-L279

Added lines #L273 - L279 were not covered by tests
else:
base_state_dict["master_weights"] = master_weight
state_dict = base_state_dict
del base_state_dict
return state_dict

Check warning on line 284 in paddlenlp/trainer/utils/helper.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/helper.py#L281-L284

Added lines #L281 - L284 were not covered by tests
22 changes: 21 additions & 1 deletion paddlenlp/trainer/utils/reshard/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,16 @@ def _opt_name_to_tname(tensor_names, opt_names):
all_names.extend(opt_names)
all_names.sort()
pre_t_name = ""
suffix = [

Check warning on line 269 in paddlenlp/trainer/utils/reshard/common.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/reshard/common.py#L269

Added line #L269 was not covered by tests
"_fp32_master_0_beta1_pow_acc_0",
"_fp32_master_0_beta2_pow_acc_0",
"_fp32_master_0_moment1_0",
"_fp32_master_0_moment2_0",
"_beta1_pow_acc_0",
"_beta2_pow_acc_0",
"_moment1_0",
"_moment2_0",
]
opt_to_t = {}
for n in all_names:
if n in tensor_names:
Expand All @@ -274,6 +284,16 @@ def _opt_name_to_tname(tensor_names, opt_names):
else:
assert pre_t_name
opt_to_t[n] = pre_t_name

for t in opt_names:
_find = False
for s in suffix:
if t.endswith(s):
logger.info(f"{t}-{t[:-len(s)]}--{t[:-len(s)] in tensor_names}")
opt_to_t[t] = t[: -len(s)]
_find = True
break
assert _find

Check warning on line 296 in paddlenlp/trainer/utils/reshard/common.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/reshard/common.py#L288-L296

Added lines #L288 - L296 were not covered by tests
return opt_to_t

if structure_name_mapping is not None:
Expand All @@ -291,7 +311,7 @@ def _opt_name_to_tname(tensor_names, opt_names):
(self._model_weights, model_weights_tmp) = (model_weights_tmp, self._model_weights)
for k in list(model_weights_tmp.keys()):
t_name = structure_name_mapping[k]
self._model_weights[(k, t_name)] = model_weights_tmp[k].cpu()
self._model_weights[(k, t_name)] = paddle.to_tensor(model_weights_tmp[k]).cpu()

Check warning on line 314 in paddlenlp/trainer/utils/reshard/common.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/reshard/common.py#L314

Added line #L314 was not covered by tests
del model_weights_tmp[k]

# opt
Expand Down
24 changes: 16 additions & 8 deletions paddlenlp/trainer/utils/sharding_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,14 @@ def filter_sharded_params(state_dict, optimizer, sharding_group):
if reshard_util.get_sharding_strategy(optimizer) == reshard_util.SHARDING_STRATEGY_V1:
optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizer)
for (k, v) in state_dict.items():
assert v.name in optimizer._param2rank
sharded_rank = optimizer._param2rank[v.name]
if sharded_rank != sharding_rank:
continue
filtered_state_dict[k] = v
if v.name in optimizer._param2rank:
sharded_rank = optimizer._param2rank[v.name]
if sharded_rank != sharding_rank:
continue
filtered_state_dict[k] = v

Check warning on line 74 in paddlenlp/trainer/utils/sharding_io.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/sharding_io.py#L70-L74

Added lines #L70 - L74 were not covered by tests
else:
if sharding_rank == 0:
filtered_state_dict[k] = v

Check warning on line 77 in paddlenlp/trainer/utils/sharding_io.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/sharding_io.py#L76-L77

Added lines #L76 - L77 were not covered by tests
else:
optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizerV2)
parameters = optimizer._parameter_list
Expand Down Expand Up @@ -352,7 +355,7 @@ def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=
)
logger.info(
"param_names_in_master_weights len:{}, bf16 state_dict len:{}, :{}".format(
len(param_names_in_master_weights), len(state_dict), state_dict
len(param_names_in_master_weights), len(state_dict), state_dict.keys()
)
)
return state_dict, config_to_save, weight_name_suffix
Expand Down Expand Up @@ -444,12 +447,17 @@ def filter_func(name):

master_weights = reshard_util.all_gather_state_dict(master_weights, filter_func, self.sharding_group)
model_state_dict = self.model.state_dict()
logger.info(f"state-dict-keys: {state_dict.keys()}, nums: {len(state_dict.keys())}")

Check warning on line 450 in paddlenlp/trainer/utils/sharding_io.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/sharding_io.py#L450

Added line #L450 was not covered by tests
logger.info("before recover, model_state_dict number: {}".format(len(model_state_dict)))
for key, param in model_state_dict.items():
if param.name in master_weights:
assert param.shape == master_weights[param.name].shape
paddle.assign(master_weights[param.name].cuda(), model_state_dict[key])

paddle.assign(paddle.cast(master_weights[param.name].cuda(), paddle.bfloat16), model_state_dict[key])
elif key in state_dict:
logger.info(f"key: {key} is in state_dict, but not in master_weights")
paddle.assign(state_dict[key], model_state_dict[key])

Check warning on line 458 in paddlenlp/trainer/utils/sharding_io.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/sharding_io.py#L455-L458

Added lines #L455 - L458 were not covered by tests
else:
logger.info(f"key: {key} is not in state_dict and master_weights")

Check warning on line 460 in paddlenlp/trainer/utils/sharding_io.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/sharding_io.py#L460

Added line #L460 was not covered by tests
logger.info("after recover, casted model_state_dict number: {}".format(len(model_state_dict)))
state_dict.update(model_state_dict)
return state_dict
Expand Down

0 comments on commit 67981e5

Please sign in to comment.