[fea] moe support #8498

Merged 1 commit on May 31, 2024
4 changes: 4 additions & 0 deletions docs/trainer.md
@@ -719,4 +719,8 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
Whether use flatten_param_grads method in optimizer,
only used on NPU devices.(default:False)

--use_expert_parallel
Whether to enable MoE (Mixture of Experts) expert parallel training.
(default: False)

```
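
For readers skimming the docs change, here is a minimal sketch of enabling the new flag from Python; `use_expert_parallel` is the option added by this PR, while `output_dir` and anything else shown are illustrative placeholders.

```python
from paddlenlp.trainer import TrainingArguments

# Sketch: only use_expert_parallel comes from this PR; the rest is a placeholder.
args = TrainingArguments(
    output_dir="./checkpoints",
    use_expert_parallel=True,  # route MoE expert parameters through expert-parallel handling
)
print(args.use_expert_parallel)  # True
```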
100 changes: 69 additions & 31 deletions paddlenlp/trainer/trainer.py
@@ -143,6 +143,7 @@
from .utils import reshard as reshard_util
from .utils.helper import ( # nested_truncate,
broadcast_dp_optimizer,
broadcast_moe_optimizer,
distributed_concat,
distributed_file,
distributed_isfile,
@@ -565,7 +566,7 @@
)
self.model.set_state_dict(state_dict)
else:
if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):
weights_file = os.path.join(
resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
@@ -930,22 +931,17 @@
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
self.timers and self.timers("forward-backward").start()

dp_enabled = (
self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
)
forbidden_no_sync = False
# stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API
# hybrid_parallel (tp or pp or sharding stage 1) should not no_sync
if self.args.use_hybrid_parallel:
forbidden_no_sync = True

availiable_no_sync = dp_enabled and not forbidden_no_sync

availiable_no_sync = hasattr(model, "no_sync")
is_no_sync = (
((step_control + 1) % args.gradient_accumulation_steps != 0)
and availiable_no_sync
and args._no_sync_in_gradient_accumulation
) or (args.recompute and availiable_no_sync)
(
((step_control + 1) % args.gradient_accumulation_steps != 0)
and args._no_sync_in_gradient_accumulation
)
or args.recompute
or args.use_expert_parallel
) and availiable_no_sync
# sharding
# stage1. the same as ddp
# stage2. manualy collect gradient on dp group
@@ -965,6 +961,14 @@

tr_loss += tr_loss_step

def fused_allreduce_gradients_no_sync(paramlist, hcg):
paramlist = list(paramlist)
nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
if moelist and not self.args.use_expert_parallel:
logger.warning("found `no sync` param when `use_expert_parallel=False`")
fused_allreduce_gradients(nonmoe_list, hcg)
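
The helper above only all-reduces parameters that are not tagged `no_sync`. How expert parameters get that tag is outside this diff; the sketch below only illustrates the assumed convention, with the attribute name matching what the helper checks via `getattr(p, "no_sync", False)`.

```python
import paddle

# Toy stand-in for an expert layer, just to demonstrate the tagging convention.
# In a real MoE model the expert-parallel implementation decides which params to tag.
expert_fc = paddle.nn.Linear(8, 8)
for param in expert_fc.parameters():
    param.no_sync = True  # excluded from the dp-group allreduce above

print([getattr(p, "no_sync", False) for p in expert_fc.parameters()])  # [True, True]
```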

if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
@@ -983,12 +987,12 @@

# Case 1: Use recompute and dp / sharding stage1,
# manualy collect gradient for dp.
if args.recompute and availiable_no_sync:
fused_allreduce_gradients(list(model.parameters()), None)
if (args.recompute or args.use_expert_parallel) and availiable_no_sync:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Case 2: hack dp with master_grad
if dp_master_grad and not (args.recompute and availiable_no_sync):
fused_allreduce_gradients(list(model.parameters()), None)
elif dp_master_grad:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Pipeline parallel mode, handle gradient reduce here to overlap
pipeline_parallel_config = (
@@ -1007,8 +1011,7 @@
self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)

fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)

self.timers and self.timers("all-reduce").stop()
self.timers and self.timers("optimizer-step").start()

@@ -1028,6 +1031,8 @@
)
optimizer_was_run = True
if self.do_grad_scaling:
if args.pipeline_parallel_degree > 1:
assert not self.args.use_expert_parallel, "pipeline moe not work under fp16"

scale_before = paddle.assign(self.scaler._scale)
self.scaler.step(self.optimizer)
self.scaler.update()
@@ -2042,7 +2047,6 @@

model.train()
inputs = self._prepare_inputs(inputs)

with self.autocast_smart_context_manager():
loss = self.compute_loss(model, inputs)

@@ -2053,7 +2057,6 @@
self.scaler.scale(loss).backward()
else:
loss.backward()

return loss.detach()

def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2146,26 @@
# For ckpt integrity
paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))

def _filter_moe_no_sync_optimizer_params(self):
"""
filter optimizer params which should not sync
"""
state_dict = self.model.state_dict()
optimzier_state_dict = self.optimizer.state_dict()
filter_optimzier_state_dict = OrderedDict()
param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else []
filter_optimzier_state_dict["master_weights"] = OrderedDict()
for k, v in state_dict.items():
if getattr(v, "no_sync", False):
if v.name in param_names_in_master_weights:
filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][

Check warning on line 2161 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2153-L2161

Added lines #L2153 - L2161 were not covered by tests
v.name
]
for op_k, op_v in optimzier_state_dict.items():
if op_k.startswith(v.name):
filter_optimzier_state_dict[op_k] = op_v
return filter_optimzier_state_dict
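
To make the filtering rule above concrete, here is a toy, framework-free reproduction of the same idea; the parameter names and the `_moment1_0` suffix are illustrative, not taken from this PR.

```python
from collections import OrderedDict

def filter_no_sync(opt_state, no_sync_param_names):
    # Keep only optimizer entries belonging to parameters tagged `no_sync`,
    # mirroring the shape of _filter_moe_no_sync_optimizer_params above.
    out = OrderedDict(master_weights=OrderedDict())
    for name in no_sync_param_names:
        if name in opt_state.get("master_weights", {}):
            out["master_weights"][name] = opt_state["master_weights"][name]
        for key, value in opt_state.items():
            if isinstance(key, str) and key.startswith(name):
                out[key] = value
    return out

opt_state = {
    "master_weights": {"expert_0.w_0": [1.0]},
    "expert_0.w_0_moment1_0": [0.1],
    "shared_0.w_0_moment1_0": [0.2],
}
print(filter_no_sync(opt_state, ["expert_0.w_0"]))
# Keeps the expert entries, drops shared_0.w_0_moment1_0 (that one is synced normally).
```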

def _save_checkpoint(self, model, metrics=None):
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
self.runtime_timer.start("checkpoint saving time")
@@ -2165,7 +2188,7 @@
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)

if self.args.use_hybrid_parallel:
if self.dp_group.rank <= 0:
if self.dp_group.rank <= 0 or self.args.use_expert_parallel:
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving optimizer files.")
if self.args.unified_checkpoint:
@@ -2177,12 +2200,18 @@
safe_serialization=True,
)
else:
self._save_ckpt_func(
self.optimizer.state_dict(),
os.path.join(output_dir, optimizer_name),
)
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
)
else:
self._save_ckpt_func(
self.optimizer.state_dict(),
os.path.join(output_dir, optimizer_name),
)

if self.args.should_save:
if self.args.should_save or self.args.use_expert_parallel:
if not self.args.use_hybrid_parallel:
logger.info("Saving optimizer files.")
if self.args.unified_checkpoint:
@@ -2194,7 +2223,12 @@
safe_serialization=True,
)
else:
self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
if self.dp_group.rank > 0:
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
)
else:
self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

# FIXME: maybe only save one copy
paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -2452,7 +2486,7 @@
logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")

if not use_unified_checkpoint:
if self.args.data_parallel_rank == 0:
if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
path = os.path.join(checkpoint, optimizer_name)
if os.path.isfile(path):
Expand All @@ -2476,7 +2510,11 @@
# broadcast optimizer state in dp group
if self.args.local_rank != -1:
dist.barrier()
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
if self.args.use_expert_parallel:
opt_state_dict = broadcast_moe_optimizer(opt_state_dict)
else:
if not self.args.should_load_sharding_stage1_model:
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)

if opt_state_dict is not None:
# Load in optimizer and scheduler states
29 changes: 26 additions & 3 deletions paddlenlp/trainer/training_args.py
@@ -803,6 +803,10 @@
default=False,
metadata={"help": "whether to run distributed training in auto parallel mode"},
)
use_expert_parallel: Optional[bool] = field(
default=False,
metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
@@ -1149,6 +1153,8 @@
order = ["dp", "sharding", "pp", "sep", "mp"]
else:
order = ["dp", "sharding", "pp", "mp"]
if self.use_expert_parallel:
order = order[1:-1] + ["dp", "mp"]
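
A worked example of the slicing above: when expert parallel is enabled, "dp" is moved to an inner position of the hybrid topology while "mp" stays last, which is presumably what the MoE communication groups rely on.

```python
# With sep supported:
order = ["dp", "sharding", "pp", "sep", "mp"]
order = order[1:-1] + ["dp", "mp"]
print(order)  # ['sharding', 'pp', 'sep', 'dp', 'mp']

# Without sep:
order = ["dp", "sharding", "pp", "mp"]
order = order[1:-1] + ["dp", "mp"]
print(order)  # ['sharding', 'pp', 'dp', 'mp']
```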

if is_segment_parallel_supported():
hybrid_configs = {
@@ -1640,8 +1646,12 @@
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.sharding_parallel_degree > 1:
name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
if self.use_expert_parallel:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))

Check warning on line 1650 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1649-L1650

Added lines #L1649 - L1650 were not covered by tests
return "_".join(name)
else:
if self.use_expert_parallel:
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)

Check warning on line 1654 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1653-L1654

Added lines #L1653 - L1654 were not covered by tests
return None
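
For intuition, the suffix these properties build looks roughly like the sketch below. The two-digit padding is an assumption standing in for `_format_name`, and the ranks are made up.

```python
def fmt(prefix, rank):
    # Hypothetical stand-in for TrainingArguments._format_name (padding width assumed).
    return f"{prefix}{rank:02d}"

# e.g. tensor-parallel rank 1, pipeline rank 0, dp/"moe" rank 3:
print("_".join([fmt("tp", 1), fmt("pp", 0), fmt("moe", 3)]))  # tp01_pp00_moe03
# The new "moe" segment gives each dp rank distinct checkpoint file names, presumably so
# every rank can save the optimizer state of its local experts without clobbering others.
```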

@property
@@ -1652,12 +1662,16 @@
name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
if self.pipeline_parallel_degree > 1:
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.use_expert_parallel:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))

Check warning on line 1666 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1665-L1666

Added lines #L1665 - L1666 were not covered by tests
return "_".join(name)

else:
if self.use_expert_parallel:
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
return None

def sharded_name_suffix(self, shard_id=None, pp_id=None):
def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
if self.use_hybrid_parallel:
name = []
if self.tensor_parallel_degree > 1:
@@ -1672,8 +1686,17 @@
shard_id = self.sharding_parallel_rank
assert isinstance(shard_id, int)
name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
if self.use_expert_parallel:
if moe_id is None:
moe_id = self.data_parallel_rank
assert isinstance(moe_id, int)
name.append(self._format_name("moe", moe_id, self.data_parallel_degree))

Check warning on line 1693 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1689-L1693

Added lines #L1689 - L1693 were not covered by tests
return "_".join(name)
else:
if self.use_expert_parallel:
if moe_id is None:
moe_id = self.data_parallel_rank
return self._format_name("moe", moe_id, self.data_parallel_degree)

Check warning on line 1699 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1699

Added line #L1699 was not covered by tests
return None

@property
@@ -1766,9 +1789,9 @@
return True
elif self.use_hybrid_parallel:
# save on dataset rank 0
return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0
return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
else:
return self.process_index == 0
return self.process_index == 0 or self.use_expert_parallel

@property
def _no_sync_in_gradient_accumulation(self):
56 changes: 56 additions & 0 deletions paddlenlp/trainer/utils/helper.py
@@ -226,3 +226,59 @@
state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)

return state_dict


def broadcast_moe_optimizer(state_dict):

try:
hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
src_rank = hcg.get_data_parallel_group_src_rank()
data_parallel_rank = hcg.get_data_parallel_rank()
# Don't broadcast the optimizer state when the dp group has only one rank.
if dp_group.nranks <= 1:
return state_dict
except:
dp_group = None
src_rank = 0
data_parallel_rank = 0

def _broadcast_moe_optimizer_state(state_dict):
# broadcast keys
base_state_dict = {"master_weights": {}}
buf = [
{i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]},
{i: j.shape for i, j in state_dict["master_weights"].items()},
{"LR_Scheduler": state_dict.get("LR_Scheduler", {})},
]

dist.broadcast_object_list(buf, src=src_rank, group=dp_group)
# logger.info(f"moe-optimizer-gather-keys{buf}")
for k, s in buf[0].items():
v = state_dict.get(k, paddle.zeros(s, "float32")).cuda()
v.name = k
# k = k.replace("_fp32_master_0", "")
dist.broadcast(v, src=src_rank, group=dp_group)
logger.info(f"broadcast moe optimizer {k} from {src_rank}")
base_state_dict[k] = v.cpu()
for k, s in buf[1].items():
v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).cuda()
v.name = k
dist.broadcast(v, src=src_rank, group=dp_group)
logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}")
base_state_dict["master_weights"][k] = v.cpu()
base_state_dict.update(buf[2])
return base_state_dict

base_state_dict = _broadcast_moe_optimizer_state(state_dict)
if data_parallel_rank > 0:
master_weight = state_dict.pop("master_weights", {})
base_state_dict.update(state_dict)
if master_weight:
if "master_weights" in base_state_dict:
base_state_dict["master_weights"].update(master_weight)

Check warning on line 279 in paddlenlp/trainer/utils/helper.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/utils/helper.py#L273-L279

Added lines #L273 - L279 were not covered by tests
else:
base_state_dict["master_weights"] = master_weight
state_dict = base_state_dict
del base_state_dict
return state_dict
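
The tail of `broadcast_moe_optimizer` merges two sources of state on non-source dp ranks: the dense (shared) state just broadcast from the dp source rank, and the locally loaded MoE-only entries. A single-process illustration of that merge, with made-up keys:

```python
# base: dense state received from the dp source rank; local: this rank's MoE-only entries.
base_state_dict = {"shared_w_moment1_0": 0.1, "master_weights": {"shared_w": 1.0}}
local_moe_state = {"expert_w_moment1_0": 0.2, "master_weights": {"expert_w": 2.0}}

master_weight = local_moe_state.pop("master_weights", {})
base_state_dict.update(local_moe_state)
base_state_dict.setdefault("master_weights", {}).update(master_weight)

print(base_state_dict)
# {'shared_w_moment1_0': 0.1,
#  'master_weights': {'shared_w': 1.0, 'expert_w': 2.0},
#  'expert_w_moment1_0': 0.2}
```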
