@@ -15,23 +15,26 @@
     HybridParallelModule,
     HybridParallelNaiveOptimizer,
     HybridParallelPlugin,
+    HybridParallelZeroOptimizer,
     get_param_info,
     reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
-from colossalai.zero.low_level import LowLevelZeroOptimizer
 
-class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+
+class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
     def __init__(
         self,
         optimizer: Optimizer,
         model: Module,
         use_pipeline: bool,
         force_overlap_comm: bool,  # force overlap comm
-        dp_process_group: ProcessGroup,  # dp pg for comm
+        dp_process_group: Optional[ProcessGroup],  # the dp pg for comm
+        tp_process_group: Optional[ProcessGroup],  # if using tp
+        pp_process_group: Optional[ProcessGroup],  # if using pp
         moe_dp_group: ProcessGroup,  # moe dp pg for comm
         param_info: OrderedDict,
         initial_scale: int = 2**16,  # grad scaler config
@@ -49,32 +52,28 @@ def __init__(
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         forced_dtype: Optional[torch.dtype] = None,
-    ):
-
+    ):
         WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
         if not force_overlap_comm and (overlap_communication or partition_grad):
-            raise RuntimeError(WARN_STR + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True")
-
+            raise RuntimeError(
+                WARN_STR
+                + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
+            )
+
         if force_overlap_comm:
             overlap_communication = True
             warnings.warn(WARN_STR + " Please make sure of this.")
 
-        self.param_info = param_info
-        self.stage_manager = model.stage_manager
-        self.shared_params = model.shared_params
-        self.dp_pg = dp_process_group
-
-        if use_pipeline:
-            reinitialize_optimizer(optimizer, model)
-
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
         }
 
         super().__init__(
+            model=model,
             optimizer=optimizer,
-            pg_to_param_list=pg_param_list,
+            use_pipeline=use_pipeline,
+            param_info=param_info,
             initial_scale=initial_scale,
             min_scale=min_scale,
             growth_factor=growth_factor,
@@ -89,7 +88,12 @@ def __init__(
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
             cpu_offload=cpu_offload,
+            # dp_process_group=dp_process_group,
+            tp_process_group=tp_process_group,
+            pp_process_group=pp_process_group,
             forced_dtype=forced_dtype,
+            ## moe args
+            pg_to_param_list=pg_param_list,
         )
 
 
@@ -180,7 +184,7 @@ def configure(
                 optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
             )
         else:
-            if not (self.dp_size > 1 or self.moe_dp_size > 1):
+            if not (self.dp_size > 1 or self.moe_dp_size > 1):
                 warnings.warn(
                     "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                     "If you do not intend to use cpu_offload, please consider set zero_stage=0."
@@ -192,6 +196,8 @@ def configure(
                 force_overlap_comm=self.force_overlap_comm,
                 param_info=param_info,
                 dp_process_group=self.dp_group,
+                tp_process_group=self.tp_group,
+                pp_process_group=self.pp_group,
                 moe_dp_group=self.moe_dp_group,
                 verbose=True,
                 clip_grad_norm=self.max_norm,
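
For context, a minimal usage sketch of how this reworked optimizer is reached in practice: the plugin's configure() (second half of the diff) builds MoeHybridParallelZeroOptimizer for the user and now also forwards the tp/pp process groups to the HybridParallelZeroOptimizer base. The plugin keyword names shown below (tp_size, pp_size, ep_size, zero_stage) and the import path are assumptions based on the surrounding plugin code, not something this diff confirms.

# Hedged sketch: the user never constructs MoeHybridParallelZeroOptimizer directly;
# booster.boost() routes through the plugin's configure() shown in the diff above.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin  # assumed import path

# Initialize the distributed environment first (exact launch signature varies by release).
colossalai.launch_from_torch()

plugin = MoeHybridParallelPlugin(
    tp_size=1,     # tensor-parallel degree (assumed kwarg)
    pp_size=1,     # pipeline-parallel degree (assumed kwarg)
    ep_size=2,     # expert-parallel degree (assumed kwarg)
    zero_stage=1,  # ZeRO stage handled by MoeHybridParallelZeroOptimizer
)
booster = Booster(plugin=plugin)

model = nn.Linear(16, 16)  # stand-in for a real MoE model
optimizer = torch.optim.Adam(model.parameters())
# After boosting, `optimizer` is the wrapped ZeRO optimizer with non-MoE params
# sharded over the dp group and MoE params sharded over the moe dp group.
model, optimizer, *_ = booster.boost(model, optimizer)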