[fea] moe support
bo-ke committed May 31, 2024
1 parent 0cd8fe7 commit 8894d32
Showing 6 changed files with 192 additions and 43 deletions.
4 changes: 4 additions & 0 deletions docs/trainer.md
@@ -719,4 +719,8 @@ Trainer 是一个简单,但功能完整的 Paddle训练和评估模块,并
Whether use flatten_param_grads method in optimizer,
only used on NPU devices.(default:False)
--use_expert_parallel
Whether to enable MoE (Mixture of Experts) expert parallel training.
(default: False)
```
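For context, a minimal sketch of turning the new option on from Python rather than the command line. This is not part of the commit; it only relies on the `use_expert_parallel` field added to `TrainingArguments` below, and the `output_dir` value is a placeholder:

```python
# Minimal sketch (not from this commit): enabling expert parallelism
# programmatically via the new TrainingArguments field.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",   # placeholder path
    use_expert_parallel=True,     # flag added by this commit (default: False)
)
print(args.use_expert_parallel)   # True
```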
100 changes: 69 additions & 31 deletions paddlenlp/trainer/trainer.py
@@ -143,6 +143,7 @@
from .utils import reshard as reshard_util
from .utils.helper import ( # nested_truncate,
broadcast_dp_optimizer,
broadcast_moe_optimizer,
distributed_concat,
distributed_file,
distributed_isfile,
@@ -565,7 +566,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
)
self.model.set_state_dict(state_dict)
else:
if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):
weights_file = os.path.join(
resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
@@ -930,22 +931,17 @@ def _inner_training_loop(
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
self.timers and self.timers("forward-backward").start()

dp_enabled = (
self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
)
forbidden_no_sync = False
# stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API
# hybrid_parallel (tp or pp or sharding stage 1) should not no_sync
if self.args.use_hybrid_parallel:
forbidden_no_sync = True

availiable_no_sync = dp_enabled and not forbidden_no_sync

availiable_no_sync = hasattr(model, "no_sync")

is_no_sync = (
((step_control + 1) % args.gradient_accumulation_steps != 0)
and availiable_no_sync
and args._no_sync_in_gradient_accumulation
) or (args.recompute and availiable_no_sync)
(
((step_control + 1) % args.gradient_accumulation_steps != 0)
and args._no_sync_in_gradient_accumulation
)
or args.recompute
or args.use_expert_parallel
) and availiable_no_sync
# sharding
# stage1. the same as ddp
# stage2. manualy collect gradient on dp group
@@ -965,6 +961,14 @@

tr_loss += tr_loss_step

def fused_allreduce_gradients_no_sync(paramlist, hcg):
paramlist = list(paramlist)
nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
if moelist and not self.args.use_expert_parallel:
logger.warning("found `no sync` param when `use_expert_parallel=False`")
fused_allreduce_gradients(nonmoe_list, hcg)

if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
@@ -983,12 +987,12 @@

# Case 1: Use recompute and dp / sharding stage1,
# manualy collect gradient for dp.
if args.recompute and availiable_no_sync:
fused_allreduce_gradients(list(model.parameters()), None)
if (args.recompute or args.use_expert_parallel) and availiable_no_sync:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Case 2: hack dp with master_grad
if dp_master_grad and not (args.recompute and availiable_no_sync):
fused_allreduce_gradients(list(model.parameters()), None)
elif dp_master_grad:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Pipeline parallel mode, handle gradient reduce here to overlap
pipeline_parallel_config = (
@@ -1007,8 +1011,7 @@
self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)

fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
self.timers and self.timers("all-reduce").stop()
self.timers and self.timers("optimizer-step").start()

@@ -1028,6 +1031,8 @@
)
optimizer_was_run = True
if self.do_grad_scaling:
if args.pipeline_parallel_degree > 1:
assert not self.args.use_expert_parallel, "pipeline moe not work under fp16"
scale_before = paddle.assign(self.scaler._scale)
self.scaler.step(self.optimizer)
self.scaler.update()
@@ -2042,7 +2047,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,

model.train()
inputs = self._prepare_inputs(inputs)

with self.autocast_smart_context_manager():
loss = self.compute_loss(model, inputs)

@@ -2053,7 +2057,6 @@
self.scaler.scale(loss).backward()
else:
loss.backward()

return loss.detach()

def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2146,26 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
# For ckpt integrity
paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))

def _filter_moe_no_sync_optimizer_params(self):
"""
filter optimizer params which should not sync
"""
state_dict = self.model.state_dict()
optimzier_state_dict = self.optimizer.state_dict()
filter_optimzier_state_dict = OrderedDict()
param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else []
filter_optimzier_state_dict["master_weights"] = OrderedDict()
for k, v in state_dict.items():
if getattr(v, "no_sync", False):
if v.name in param_names_in_master_weights:
filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][
v.name
]
for op_k, op_v in optimzier_state_dict.items():
if op_k.startswith(v.name):
filter_optimzier_state_dict[op_k] = op_v
return filter_optimzier_state_dict

def _save_checkpoint(self, model, metrics=None):
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
self.runtime_timer.start("checkpoint saving time")
@@ -2165,7 +2188,7 @@ def _save_checkpoint(self, model, metrics=None):
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)

if self.args.use_hybrid_parallel:
if self.dp_group.rank <= 0:
if self.dp_group.rank <= 0 or self.args.use_expert_parallel:
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving optimizer files.")
if self.args.unified_checkpoint:
@@ -2177,12 +2200,18 @@
safe_serialization=True,
)
else:
self._save_ckpt_func(
self.optimizer.state_dict(),
os.path.join(output_dir, optimizer_name),
)
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
)
else:
self._save_ckpt_func(
self.optimizer.state_dict(),
os.path.join(output_dir, optimizer_name),
)

if self.args.should_save:
if self.args.should_save or self.args.use_expert_parallel:
if not self.args.use_hybrid_parallel:
logger.info("Saving optimizer files.")
if self.args.unified_checkpoint:
@@ -2194,7 +2223,12 @@
safe_serialization=True,
)
else:
self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
if self.dp_group.rank > 0:
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
)
else:
self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

# FIXME: maybe only save one copy
paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -2452,7 +2486,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")

if not use_unified_checkpoint:
if self.args.data_parallel_rank == 0:
if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
path = os.path.join(checkpoint, optimizer_name)
if os.path.isfile(path):
Expand All @@ -2476,7 +2510,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
# broadcast optimizer state in dp group
if self.args.local_rank != -1:
dist.barrier()
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
if self.args.use_expert_parallel:
opt_state_dict = broadcast_moe_optimizer(opt_state_dict)
else:
if not self.args.should_load_sharding_stage1_model:
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)

if opt_state_dict is not None:
# Load in optimizer and scheduler states
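The gradient-sync changes above hinge on expert parameters carrying a `no_sync` attribute, which is assumed to be set by the MoE/expert layers (not shown in this diff). A standalone sketch of the split performed by `fused_allreduce_gradients_no_sync`, using stand-in objects instead of real paddle parameters:

```python
# Illustrative only: how fused_allreduce_gradients_no_sync partitions parameters.
# Param is a stand-in; in the trainer, `no_sync` lives on paddle parameters.
from dataclasses import dataclass

@dataclass
class Param:
    name: str
    no_sync: bool = False   # True => expert-local, excluded from dp all-reduce

params = [
    Param("shared_linear.weight"),
    Param("experts.0.weight", no_sync=True),
    Param("experts.1.weight", no_sync=True),
]

nonmoe_list = [p for p in params if not getattr(p, "no_sync", False)]  # all-reduced over dp
moe_list = [p for p in params if getattr(p, "no_sync", False)]         # kept rank-local

print([p.name for p in nonmoe_list])  # ['shared_linear.weight']
print([p.name for p in moe_list])     # ['experts.0.weight', 'experts.1.weight']
```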
29 changes: 26 additions & 3 deletions paddlenlp/trainer/training_args.py
@@ -803,6 +803,10 @@ class TrainingArguments:
default=False,
metadata={"help": "whether to run distributed training in auto parallel mode"},
)
use_expert_parallel: Optional[bool] = field(
default=False,
metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
@@ -1149,6 +1153,8 @@ def is_segment_parallel_supported():
order = ["dp", "sharding", "pp", "sep", "mp"]
else:
order = ["dp", "sharding", "pp", "mp"]
if self.use_expert_parallel:
order = order[1:-1] + ["dp", "mp"]

if is_segment_parallel_supported():
hybrid_configs = {
@@ -1640,8 +1646,12 @@ def optimizer_name_suffix(self):
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.sharding_parallel_degree > 1:
name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
if self.use_expert_parallel:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
return "_".join(name)
else:
if self.use_expert_parallel:
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
return None

@property
@@ -1652,12 +1662,16 @@ def weight_name_suffix(self):
name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
if self.pipeline_parallel_degree > 1:
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.use_expert_parallel:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
return "_".join(name)

else:
if self.use_expert_parallel:
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
return None

def sharded_name_suffix(self, shard_id=None, pp_id=None):
def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
if self.use_hybrid_parallel:
name = []
if self.tensor_parallel_degree > 1:
Expand All @@ -1672,8 +1686,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
shard_id = self.sharding_parallel_rank
assert isinstance(shard_id, int)
name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
if self.use_expert_parallel:
if moe_id is None:
moe_id = self.data_parallel_rank
assert isinstance(moe_id, int)
name.append(self._format_name("moe", moe_id, self.data_parallel_degree))
return "_".join(name)
else:
if self.use_expert_parallel:
if moe_id is None:
moe_id = self.data_parallel_rank
return self._format_name("moe", moe_id, self.data_parallel_degree)
return None

@property
@@ -1766,9 +1789,9 @@ def should_save_model_state(self):
return True
elif self.use_hybrid_parallel:
# save on dataset rank 0
return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0
return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
else:
return self.process_index == 0
return self.process_index == 0 or self.use_expert_parallel

@property
def _no_sync_in_gradient_accumulation(self):
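A quick check of the hybrid-order manipulation above: when `use_expert_parallel` is set, the data-parallel axis is moved from the front of the order to just before `mp`. The starting lists are copied from the code above; this snippet is illustration only:

```python
# Reordering applied in __post_init__ when use_expert_parallel is True.
for order in (["dp", "sharding", "pp", "sep", "mp"], ["dp", "sharding", "pp", "mp"]):
    print(order[1:-1] + ["dp", "mp"])
# ['sharding', 'pp', 'sep', 'dp', 'mp']
# ['sharding', 'pp', 'dp', 'mp']
```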
56 changes: 56 additions & 0 deletions paddlenlp/trainer/utils/helper.py
@@ -226,3 +226,59 @@ def broadcast_dp_optimizer(state_dict):
state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)

return state_dict


def broadcast_moe_optimizer(state_dict):

try:
hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
src_rank = hcg.get_data_parallel_group_src_rank()
data_parallel_rank = hcg.get_data_parallel_rank()
# Don't broadcast optimizer for dp rank is 1.
if dp_group.nranks <= 1:
return state_dict
except:
dp_group = None
src_rank = 0
data_parallel_rank = 0

def _broadcast_moe_optimizer_state(state_dict):
# boardcast_keys
base_state_dict = {"master_weights": {}}
buf = [
{i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]},
{i: j.shape for i, j in state_dict["master_weights"].items()},
{"LR_Scheduler": state_dict.get("LR_Scheduler", {})},
]

dist.broadcast_object_list(buf, src=src_rank, group=dp_group)
# logger.info(f"moe-optimizer-gather-keys{buf}")
for k, s in buf[0].items():
v = state_dict.get(k, paddle.zeros(s, "float32")).cuda()
v.name = k
# k = k.replace("_fp32_master_0", "")
dist.broadcast(v, src=src_rank, group=dp_group)
logger.info(f"broadcast moe optimizer {k} from {src_rank}")
base_state_dict[k] = v.cpu()
for k, s in buf[1].items():
v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).cuda()
v.name = k
dist.broadcast(v, src=src_rank, group=dp_group)
logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}")
base_state_dict["master_weights"][k] = v.cpu()
base_state_dict.update(buf[2])
return base_state_dict

base_state_dict = _broadcast_moe_optimizer_state(state_dict)
if data_parallel_rank > 0:
master_weight = state_dict.pop("master_weights", {})
base_state_dict.update(state_dict)
if master_weight:
if "master_weights" in base_state_dict:
base_state_dict["master_weights"].update(master_weight)
else:
base_state_dict["master_weights"] = master_weight
state_dict = base_state_dict
del base_state_dict
return state_dict
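To make the `data_parallel_rank > 0` branch above concrete, a toy version of the final merge: the broadcast (dense) optimizer state is the base, and locally held expert (`no_sync`) state is layered on top. All keys and values below are made up for illustration, and the sketch assumes the broadcast state already contains a `master_weights` entry, as `_broadcast_moe_optimizer_state` guarantees:

```python
# Toy merge mirroring the data_parallel_rank > 0 path of broadcast_moe_optimizer.
base_state_dict = {"master_weights": {"dense_w": 1.0}, "dense_w_moment1_0": 0.1}     # from broadcast
local_state_dict = {"master_weights": {"expert_w": 2.0}, "expert_w_moment1_0": 0.2}  # rank-local MoE state

master_weight = local_state_dict.pop("master_weights", {})
base_state_dict.update(local_state_dict)                 # expert optimizer states stay local
base_state_dict["master_weights"].update(master_weight)  # same for master weights

print(sorted(base_state_dict))
# ['dense_w_moment1_0', 'expert_w_moment1_0', 'master_weights']
```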