[fea] moe support
bo-ke committed May 28, 2024
1 parent 0cd8fe7 commit 2851da8
Showing 5 changed files with 215 additions and 22 deletions.
51 changes: 40 additions & 11 deletions paddlenlp/trainer/trainer.py
@@ -143,6 +143,7 @@
from .utils import reshard as reshard_util
from .utils.helper import ( # nested_truncate,
broadcast_dp_optimizer,
broadcast_moe_optimizer,
distributed_concat,
distributed_file,
distributed_isfile,
@@ -945,7 +946,8 @@ def _inner_training_loop(
((step_control + 1) % args.gradient_accumulation_steps != 0)
and availiable_no_sync
and args._no_sync_in_gradient_accumulation
) or (args.recompute and availiable_no_sync)
) or (args.recompute and availiable_no_sync
) or (args.use_moe and availiable_no_sync)
# sharding
# stage1. the same as ddp
# stage2. manualy collect gradient on dp group
@@ -965,6 +967,14 @@ def _inner_training_loop(

tr_loss += tr_loss_step

def fused_allreduce_gradients_no_sync(paramlist, hcg):
paramlist = list(paramlist)
nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
if moelist and not self.args.use_moe:
logger.warning("found `no sync` param when `use_moe=False`")
fused_allreduce_gradients(nonmoe_list, hcg)

if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
@@ -983,12 +993,12 @@ def _inner_training_loop(

# Case 1: Use recompute and dp / sharding stage1,
# manualy collect gradient for dp.
if args.recompute and availiable_no_sync:
fused_allreduce_gradients(list(model.parameters()), None)
if (args.recompute or args.use_moe) and availiable_no_sync:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Case 2: hack dp with master_grad
if dp_master_grad and not (args.recompute and availiable_no_sync):
fused_allreduce_gradients(list(model.parameters()), None)
elif dp_master_grad:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Pipeline parallel mode, handle gradient reduce here to overlap
pipeline_parallel_config = (
@@ -1007,8 +1017,7 @@ def _inner_training_loop(
self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)

fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)

self.timers and self.timers("all-reduce").stop()
self.timers and self.timers("optimizer-step").start()

@@ -1028,7 +1037,9 @@ def _inner_training_loop(
)
optimizer_was_run = True
if self.do_grad_scaling:
scale_before = paddle.assign(self.scaler._scale)
if args.pipeline_parallel_degree > 1:
assert not self.args.use_moe, "pipeline MoE does not work under fp16"
scale_before = self.scaler._scale.numpy()

self.scaler.step(self.optimizer)
self.scaler.update()
scale_after = self.scaler._scale
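
Taking `scale_before` as a host-side copy (`.numpy()`) of the dynamic loss scale before `scaler.step` keeps the later comparison with `scale_after` from aliasing the live scale tensor. A hedged sketch of the skip-step check that comparison supports (the exact rule is an assumption; the diff only shows the copy):

import numpy as np

def optimizer_step_ran(scale_before: np.ndarray, scale_after: np.ndarray) -> bool:
    # With AMP dynamic loss scaling, an inf/nan gradient makes the scaler shrink
    # the scale and skip the optimizer step, so a reduced scale afterwards
    # signals a skipped step.
    return bool(np.all(scale_after >= scale_before))
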
@@ -2042,7 +2053,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,

model.train()
inputs = self._prepare_inputs(inputs)

self.timers and self.timers(f"forward-acc-{self._cur_acc_step}").start()

with self.autocast_smart_context_manager():
loss = self.compute_loss(model, inputs)

Expand All @@ -2053,7 +2064,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
self.scaler.scale(loss).backward()
else:
loss.backward()

self.timers and self.timers(f"backward-acc-{self._cur_acc_step}").stop()

return loss.detach()

def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2154,19 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
# For ckpt integrity
paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))

def _save_moe_weights(
self,
output_dir: Optional[str] = None,
merge_tensor_parallel: Optional[bool] = False,):
# save moe optimizer and model state # TODO: saved redundantly on every rank by default

self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")
paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name))
with open(saved_signal_path, mode="w+") as f:
f.write("1")

def _save_checkpoint(self, model, metrics=None):
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
self.runtime_timer.start("checkpoint saving time")
@@ -2245,6 +2269,8 @@ def _save_checkpoint(self, model, metrics=None):
os.makedirs(output_dir, exist_ok=True)
paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

if self.args.use_moe and self.args.data_parallel_rank > 0:
self._save_moe_weights(output_dir)

# Maybe delete some older checkpoints.
# For hybrid parallel training, the checkpoint files maybe on different node.
need_to_rotate_checkpoints = False
@@ -2476,7 +2502,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):
# broadcast optimizer state in dp group
if self.args.local_rank != -1:
dist.barrier()
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
if not self.args.use_moe:
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)

else:
opt_state_dict = broadcast_moe_optimizer(opt_state_dict)

if opt_state_dict is not None:
# Load in optimizer and scheduler states
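
With `use_moe` enabled, `_save_checkpoint` now makes every data-parallel rank with rank > 0 call `_save_moe_weights`, which writes that rank's expert-sharded model and optimizer state plus a `saved_signal_{rank}` marker file. A sketch of how a driver script might wait for those markers before treating a checkpoint as complete; the set of ranks that write markers and the polling approach are assumptions, not guaranteed by this diff:

import os
import time

def wait_for_moe_checkpoint(output_dir: str, ranks: list, timeout_s: float = 600.0) -> None:
    # Poll for the per-rank saved_signal_{rank} files written in _save_moe_weights.
    expected = [os.path.join(output_dir, f"saved_signal_{r}") for r in ranks]
    deadline = time.time() + timeout_s
    while not all(os.path.exists(p) for p in expected):
        if time.time() > deadline:
            missing = [p for p in expected if not os.path.exists(p)]
            raise TimeoutError(f"checkpoint incomplete, missing: {missing}")
        time.sleep(1.0)
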
25 changes: 24 additions & 1 deletion paddlenlp/trainer/training_args.py
@@ -803,6 +803,10 @@ class TrainingArguments:
default=False,
metadata={"help": "whether to run distributed training in auto parallel mode"},
)
use_moe: Optional[bool] = field(
default=False,
metadata={"help": "开启moe训练"},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
@@ -1149,6 +1153,8 @@ def is_segment_parallel_supported():
order = ["dp", "sharding", "pp", "sep", "mp"]
else:
order = ["dp", "sharding", "pp", "mp"]
if self.use_moe:
order = order[1: -1] + ["dp", "mp"]

if is_segment_parallel_supported():
hybrid_configs = {
@@ -1640,8 +1646,12 @@ def optimizer_name_suffix(self):
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.sharding_parallel_degree > 1:
name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
if self.use_moe:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))

return "_".join(name)
else:
if self.use_moe:
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)

return None

@property
@@ -1652,12 +1662,16 @@ def weight_name_suffix(self):
name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
if self.pipeline_parallel_degree > 1:
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.use_moe:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))

return "_".join(name)

else:
if self.use_moe:

return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
return None

def sharded_name_suffix(self, shard_id=None, pp_id=None):
def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):

if self.use_hybrid_parallel:
name = []
if self.tensor_parallel_degree > 1:
Expand All @@ -1672,8 +1686,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
shard_id = self.sharding_parallel_rank
assert isinstance(shard_id, int)
name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
if self.use_moe:
if moe_id is None:
moe_id = self.data_parallel_rank
assert isinstance(moe_id, int)
name.append(self._format_name("moe", moe_id, self.data_parallel_degree))

return "_".join(name)
else:
if self.use_moe:
if moe_id is None:

moe_id = self.data_parallel_rank
return self._format_name("moe", moe_id, self.data_parallel_degree)

return None

@property
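
The `use_moe` branches above append a `moe` component, keyed by the data-parallel rank, to the optimizer and weight file-name suffixes, so each data-parallel rank saves distinct expert-sharded files instead of sharing one copy. A sketch of how the suffix composes, assuming `_format_name` produces zero-padded components such as `tp00` (the exact format is not shown in this diff):

def format_name(prefix: str, rank: int, degree: int) -> str:
    # Assumed format: zero-padded rank, e.g. "moe03"; stands in for _format_name.
    assert 0 <= rank < degree
    return f"{prefix}{rank:0>2d}"

def weight_name_suffix(tp_rank=0, tp_degree=1, pp_rank=0, pp_degree=1,
                       dp_rank=0, dp_degree=1, use_moe=False):
    name = []
    if tp_degree > 1:
        name.append(format_name("tp", tp_rank, tp_degree))
    if pp_degree > 1:
        name.append(format_name("pp", pp_rank, pp_degree))
    if use_moe:
        name.append(format_name("moe", dp_rank, dp_degree))
    return "_".join(name) if name else None

# e.g. weight_name_suffix(tp_rank=1, tp_degree=2, dp_rank=3, dp_degree=4, use_moe=True)
# -> "tp01_moe03" under the assumed format.
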
107 changes: 106 additions & 1 deletion paddlenlp/trainer/utils/helper.py
@@ -18,7 +18,7 @@

import os
from typing import Any, Optional

import copy
import numpy as np
import paddle
import paddle.distributed as dist
@@ -226,3 +226,108 @@ def broadcast_dp_optimizer(state_dict):
state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)

return state_dict

# def broadcast_moe_optimizer(state_dict):
# if paddle.distributed.get_world_size() <= 1:
# return state_dict

# logger.info("Start broadcast optimizer in MoE(data) parallel group.")
# try:
# hcg = fleet.get_hybrid_communicate_group()
# dp_group = hcg.get_data_parallel_group()
# src_rank = hcg.get_data_parallel_group_src_rank()
# process_rank = paddle.distributed.get_rank()
# # Don't broadcast optimizer for dp rank is 1.
# if dp_group.nranks <= 1:
# return state_dict
# except:
# dp_group = None
# src_rank = 0
# process_rank = paddle.distributed.get_rank()

# if process_rank == src_rank:
# if state_dict is None:
# logger.warning(
# f"Your local rank {paddle.distributed.get_rank()} must have a state_dict. dp_rank:{process_rank}, src_rank:{src_rank}"
# )
# fake_state_dict = [nested_reduce_tensor(state_dict)]
# else:
# fake_state_dict = [None]

# paddle.distributed.broadcast_object_list(
# fake_state_dict,
# src=src_rank,
# group=dp_group,
# )
# fake_state_dict = fake_state_dict[0]
# if process_rank != src_rank:
# sync_state_dict = nested_empty_tensor(fake_state_dict)
# else:
# sync_state_dict = state_dict
# logger.info(f"SYNC-state-dict--{sync_state_dict.keys()}")
# sync_state_dict = nested_broadcast_tensor(sync_state_dict, src=src_rank, group=dp_group)
# if process_rank != src_rank:
# master_weights = state_dict.pop('master_weights', {})
# sync_state_dict['master_weights'].update(master_weights)
# sync_state_dict.update(state_dict)
# state_dict = sync_state_dict
# logger.info("broadcast_moe_optimizer done")
# return state_dict


def broadcast_moe_optimizer(state_dict):

try:
hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
src_rank = hcg.get_data_parallel_group_src_rank()
process_rank = paddle.distributed.get_rank()
data_parallel_rank = hcg.get_data_parallel_rank()

# Don't broadcast optimizer for dp rank is 1.
if dp_group.nranks <= 1:
return state_dict
except:
dp_group = None
src_rank = 0
data_parallel_rank = 0
process_rank = paddle.distributed.get_rank()

def _broadcast_moe_optimizer_state(state_dict):

# keys to broadcast
base_state_dict = {"master_weights": {}}
buf = [

{i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]},
{i: j.shape for i, j in state_dict["master_weights"].items()},
{"LR_Scheduler": state_dict.get("LR_Scheduler", {})},
]

dist.broadcast_object_list(buf, src=src_rank, group=dp_group)

# logger.info(f"moe-optimizer-gather-keys{buf}")
for k, s in buf[0].items():
v = state_dict.get(k, paddle.zeros(s, "float32")).cuda()
v.name = k

# k = k.replace("_fp32_master_0", "")
dist.broadcast(v, src=src_rank, group=dp_group)
logger.info(f"broadcast moe optimizer {k} from {src_rank}")
base_state_dict[k] = v.cpu()
for k, s in buf[1].items():
v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).cuda()
v.name = k
dist.broadcast(v, src=src_rank, group=dp_group)
logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}")
base_state_dict["master_weights"][k] = v.cpu()
base_state_dict.update(buf[2])
return base_state_dict

base_state_dict = _broadcast_moe_optimizer_state(state_dict)
if data_parallel_rank > 0:
master_weight = state_dict.pop("master_weights", {})
base_state_dict.update(state_dict)
if master_weight:
if "master_weights" in base_state_dict:
base_state_dict["master_weights"].update(master_weight)

else:
base_state_dict["master_weights"] = master_weight
state_dict = base_state_dict
del base_state_dict
return state_dict
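
`broadcast_moe_optimizer` works in two phases: the data-parallel source rank broadcasts the shapes and then the tensors of its shared (non-expert) optimizer entries, and ranks with `data_parallel_rank > 0` overlay their locally held expert state onto that broadcast base. A single-process sketch of the final merge on plain dictionaries (toy data, no communication), assuming rank-local expert entries should always win:

import copy

def merge_moe_optimizer_state(base_state: dict, local_state: dict) -> dict:
    # base_state: broadcast shared entries; local_state: this rank's expert entries.
    merged = copy.deepcopy(base_state)
    local_state = dict(local_state)  # avoid mutating the caller's dict
    local_master = local_state.pop("master_weights", {})
    merged.update(local_state)  # rank-local expert entries take precedence
    merged.setdefault("master_weights", {}).update(local_master)
    return merged

base = {"shared_w_moment1_0": 1.0, "master_weights": {"shared_w": 1.0}, "LR_Scheduler": {}}
local = {"expert_w_moment1_0": 2.0, "master_weights": {"expert_w": 2.0}}
assert merge_moe_optimizer_state(base, local)["master_weights"] == {"shared_w": 1.0, "expert_w": 2.0}

Overlaying the local state last is what keeps the rank-local expert moments from being overwritten by the broadcast shared state.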

30 changes: 29 additions & 1 deletion paddlenlp/trainer/utils/reshard/common.py
@@ -266,6 +266,16 @@ def _opt_name_to_tname(tensor_names, opt_names):
all_names.extend(opt_names)
all_names.sort()
pre_t_name = ""
suffix = [

"_fp32_master_0_beta1_pow_acc_0",
"_fp32_master_0_beta2_pow_acc_0",
"_fp32_master_0_moment1_0",
"_fp32_master_0_moment2_0",
"_beta1_pow_acc_0",
"_beta2_pow_acc_0",
"_moment1_0",
"_moment2_0",
]
opt_to_t = {}
for n in all_names:
if n in tensor_names:
@@ -274,6 +284,24 @@ def _opt_name_to_tname(tensor_names, opt_names):
else:
assert pre_t_name
opt_to_t[n] = pre_t_name

for t in opt_names:
_find = False
for s in suffix:
if t.endswith(s):
logger.info(f"{t}-{t[:-len(s)]}--{t[:-len(s)] in tensor_names}")
opt_to_t[t] = t[:-len(s)]
_find = True
break
assert _find

# opt_to_t = {}
# for n in all_names:
# if n in tensor_names:
# # we get a param
# pre_t_name = n
# else:
# assert pre_t_name
# opt_to_t[n] = pre_t_name
return opt_to_t

if structure_name_mapping is not None:
@@ -291,7 +319,7 @@ def _opt_name_to_tname(tensor_names, opt_names):
(self._model_weights, model_weights_tmp) = (model_weights_tmp, self._model_weights)
for k in list(model_weights_tmp.keys()):
t_name = structure_name_mapping[k]
self._model_weights[(k, t_name)] = model_weights_tmp[k].cpu()
self._model_weights[(k, t_name)] = paddle.to_tensor(model_weights_tmp[k]).cpu()

del model_weights_tmp[k]

# opt
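
The reshard change above maps optimizer-state names back to parameter names by stripping a fixed list of Adam/AdamW state suffixes instead of relying on sorted-name adjacency, which breaks when expert parameters exist only on some ranks. A toy illustration of that mapping (the suffix list is copied from the diff; the parameter names are made up):

SUFFIXES = [
    "_fp32_master_0_beta1_pow_acc_0",
    "_fp32_master_0_beta2_pow_acc_0",
    "_fp32_master_0_moment1_0",
    "_fp32_master_0_moment2_0",
    "_beta1_pow_acc_0",
    "_beta2_pow_acc_0",
    "_moment1_0",
    "_moment2_0",
]

def opt_name_to_tname(opt_name: str) -> str:
    # Strip the first matching optimizer-state suffix to recover the tensor name.
    for s in SUFFIXES:
        if opt_name.endswith(s):
            return opt_name[: -len(s)]
    raise ValueError(f"unrecognized optimizer state name: {opt_name}")

assert opt_name_to_tname("linear_0.w_0_fp32_master_0_moment1_0") == "linear_0.w_0"
assert opt_name_to_tname("expert_3.w_0_beta2_pow_acc_0") == "expert_3.w_0"

Checking the longer `_fp32_master_0_*` suffixes before the bare `_moment*`/`_beta*` ones matters; the diff orders its suffix list the same way, otherwise the shorter suffix would match first and leave a dangling `_fp32_master_0` on the recovered name.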