
Commit a613edd

Hz188botbw authored and committed
solve hang when parallel mode = pp + dp
1 parent 0210bea commit a613edd

3 files changed: +57 -34 lines
colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 5 additions & 2 deletions
@@ -27,6 +27,7 @@
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1068,8 +1069,10 @@ def __init__(
         self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
 
-        self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0])
-
+        self.logger.info(
+            f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]
+        )
+
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
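This hunk only imports get_dist_logger and re-wraps the existing rank-0 log of the configured parallel sizes. For context, a minimal sketch of the logging pattern used here, assuming a colossalai distributed context has already been launched (the message text below is illustrative, not taken from the plugin):

from colossalai.logging import get_dist_logger

logger = get_dist_logger()
# ranks=[0] prints the configured parallel sizes once, from rank 0,
# instead of once per process in the job.
logger.info("pp_size=2 dp_size=2 tp_size=1 sp_size=1", ranks=[0])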

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

Lines changed: 23 additions & 17 deletions
@@ -15,23 +15,26 @@
     HybridParallelModule,
     HybridParallelNaiveOptimizer,
     HybridParallelPlugin,
+    HybridParallelZeroOptimizer,
     get_param_info,
     reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
-from colossalai.zero.low_level import LowLevelZeroOptimizer
 
-class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+
+class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
     def __init__(
         self,
         optimizer: Optimizer,
         model: Module,
         use_pipeline: bool,
         force_overlap_comm: bool,  # force overlap comm
-        dp_process_group: ProcessGroup,  # dp pg for comm
+        dp_process_group: Optional[ProcessGroup],  # the dp pg for comm
+        tp_process_group: Optional[ProcessGroup],  # if using tp
+        pp_process_group: Optional[ProcessGroup],  # if using pp
         moe_dp_group: ProcessGroup,  # moe dp pg for comm
         param_info: OrderedDict,
         initial_scale: int = 2**16,  # grad scaler config
@@ -49,32 +52,28 @@ def __init__(
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         forced_dtype: Optional[torch.dtype] = None,
-    ):
-
+    ):
         WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
         if not force_overlap_comm and (overlap_communication or partition_grad):
-            raise RuntimeError(WARN_STR + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True")
-
+            raise RuntimeError(
+                WARN_STR
+                + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
+            )
+
         if force_overlap_comm:
             overlap_communication = True
             warnings.warn(WARN_STR + " Please make sure of this.")
 
-        self.param_info = param_info
-        self.stage_manager = model.stage_manager
-        self.shared_params = model.shared_params
-        self.dp_pg = dp_process_group
-
-        if use_pipeline:
-            reinitialize_optimizer(optimizer, model)
-
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
         }
 
         super().__init__(
+            model=model,
             optimizer=optimizer,
-            pg_to_param_list=pg_param_list,
+            use_pipeline=use_pipeline,
+            param_info=param_info,
             initial_scale=initial_scale,
             min_scale=min_scale,
             growth_factor=growth_factor,
@@ -89,7 +88,12 @@ def __init__(
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
             cpu_offload=cpu_offload,
+            # dp_process_group=dp_process_group,
+            tp_process_group=tp_process_group,
+            pp_process_group=pp_process_group,
             forced_dtype=forced_dtype,
+            ## moe args
+            pg_to_param_list=pg_param_list,
         )
 
 
@@ -180,7 +184,7 @@ def configure(
                 optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
             )
         else:
-            if not(self.dp_size > 1 or self.moe_dp_size > 1):
+            if not (self.dp_size > 1 or self.moe_dp_size > 1):
                 warnings.warn(
                     "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                     "If you do not intend to use cpu_offload, please consider set zero_stage=0."
@@ -192,6 +196,8 @@ def configure(
                 force_overlap_comm=self.force_overlap_comm,
                 param_info=param_info,
                 dp_process_group=self.dp_group,
+                tp_process_group=self.tp_group,
+                pp_process_group=self.pp_group,
                 moe_dp_group=self.moe_dp_group,
                 verbose=True,
                 clip_grad_norm=self.max_norm,
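Read together with the commit message, the change above means: MoeHybridParallelZeroOptimizer now derives from HybridParallelZeroOptimizer and forwards the tensor- and pipeline-parallel process groups it previously dropped, while still sharding non-MoE parameters over the plain dp group and MoE parameters over the MoE dp group via pg_to_param_list. A hedged sketch of the resulting constructor call; the objects torch_optimizer, hybrid_model, dp_group, tp_group, pp_group, moe_dp_group, and param_info are placeholders standing in for what the plugin's configure() actually supplies:

optim = MoeHybridParallelZeroOptimizer(
    optimizer=torch_optimizer,      # a torch.optim.Optimizer over the wrapped model
    model=hybrid_model,             # the HybridParallelModule being trained
    use_pipeline=True,
    force_overlap_comm=True,
    dp_process_group=dp_group,      # ZeRO shards non-MoE params over this group
    tp_process_group=tp_group,      # newly forwarded to the base class in this commit
    pp_process_group=pp_group,      # newly forwarded to the base class in this commit
    moe_dp_group=moe_dp_group,      # ZeRO shards MoE (expert) params over this group
    param_info=param_info,
)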

tests/test_shardformer/test_model/test_shard_mixtral.py

Lines changed: 29 additions & 15 deletions
@@ -117,37 +117,51 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 1,
             "ep_size": 1,
-            "zero_stage": 2,
+            "zero_stage": 1,
+            "overlap_communication": False,
             "precision": "fp32",
-        },  # [dp(2) + pp(2)] + [moe_dp(4)]
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 2,
-        #     "num_microbatches": 2,
-        #     "ep_size": 1,
-        #     "zero_stage": 1,
-        #     "precision": "fp32",
-        # },  # [dp(2) + pp(2)] + [moe_dp(4)]
+        },  # [dp(4)] + [moe_dp(4)]
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 1,
+            "zero_stage": 1,
+            "overlap_communication": False,
+            "precision": "fp32",
+        },  # [dp(2) + pp(2)] + [moe_pp(2)]
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 1,
+            "zero_stage": 1,
+            "overlap_communication": False,
+            "precision": "fp32",
+        },  # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
         # {
         #     "tp_size": 1,
         #     "pp_size": 2,
         #     "num_microbatches": 2,
-        #     "ep_size": 4,
+        #     "ep_size": 2,
         #     "zero_stage": 1,
+        #     "overlap_communication": False,
         #     "precision": "fp32",
         # },  # [dp(2) + pp(2)] + [ep(4))]
         # {
         #     "tp_size": 1,
         #     "pp_size": 1,
         #     "ep_size": 2,
         #     "zero_stage": 0,
+        #     "overlap_communication": False,
         #     "precision": "fp32",
         # },  # [dp(4)] + [ep(2) + moe_tp(2)]
         # {
-        #     "tp_size": 1,
-        #     "pp_size": 1,
-        #     "ep_size": 4,
-        #     "zero_stage": 0,
+        #     "tp_size": 1,
+        #     "pp_size": 1,
+        #     "ep_size": 4,
+        #     "overlap_communication": False,
+        #     "zero_stage": 0,
         #     "precision": "fp32"
         # },  # full dp for non-moe and full ep for moe
     ],
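The bracketed comments describe the intended process layout on the 4-GPU test world. A rough sketch of the arithmetic they imply, under the assumptions that dp = world_size // (tp * pp) and, with experts not tensor-parallel, the MoE data-parallel span is world_size // (pp * ep); the helper below is illustrative only and not part of the test file:

def describe_layout(cfg, world_size=4):
    # Non-MoE parameters: ranks split across tp, pp, and the remaining dp dimension.
    dp = world_size // (cfg["tp_size"] * cfg["pp_size"])
    # MoE parameters (assumed not tensor-parallel): whatever remains after pp and ep.
    moe_dp = world_size // (cfg["pp_size"] * cfg["ep_size"])
    return f"[dp({dp}) + pp({cfg['pp_size']}) + tp({cfg['tp_size']})] + [moe_dp({moe_dp})]"

# The newly enabled pp + dp case: tp=1, pp=2, ep=1 on 4 GPUs -> dp(2), moe_dp(2).
print(describe_layout({"tp_size": 1, "pp_size": 2, "ep_size": 1}))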
