@@ -1,6 +1,7 @@
 import warnings
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
+import numpy as np
 
 import torch
 import torch.distributed as dist
@@ -64,6 +65,14 @@ def __init__(
             overlap_communication = True
             warnings.warn(WARN_STR + " Please make sure of this.")
 
+        self.param_info = param_info
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.dp_pg = dp_process_group
+
+        if use_pipeline:
+            reinitialize_optimizer(optimizer, model)
+
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
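Note on the hunk above: `pg_param_list` routes every parameter to one of two ZeRO process groups, keyed by whether `is_moe_tensor` returns true for it. Below is a minimal sketch of that filtering pattern, with a made-up `is_moe` predicate standing in for `is_moe_tensor` and a toy two-layer model (illustrative only, not the plugin's code):

```python
import torch.nn as nn

# toy model: one "dense" layer and one layer we pretend is an MoE expert
model = nn.ModuleDict({"dense": nn.Linear(4, 4), "expert": nn.Linear(4, 4)})
expert_params = set(id(p) for p in model["expert"].parameters())

def is_moe(p) -> bool:
    # stand-in for colossalai's is_moe_tensor
    return id(p) in expert_params

# same split as pg_param_list: non-MoE params vs. MoE params
dense_params = list(filter(lambda p: not is_moe(p), model.parameters()))
moe_params = list(filter(is_moe, model.parameters()))
assert len(dense_params) == 2 and len(moe_params) == 2  # weight + bias each
```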
@@ -116,17 +125,16 @@ def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False,
             raise NotImplementedError
 
         world_size = dist.get_world_size()
-
-        self.moe_dp_size = world_size // (ep_size * moe_tp_size)
+        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
         self.ep_size = ep_size
         self.moe_tp_size = moe_tp_size
 
-        self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
-        self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
+        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
+            raise ValueError(
+                f"world_size={world_size} is not equal to pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
+            )
 
-        self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
-        self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
-        self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
+        self._init_moe_param_comm()
 
         self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
 
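Quick arithmetic check of the sizing logic in this hunk (standalone sketch with illustrative numbers; only the names `pp_size`, `ep_size`, `moe_tp_size`, and `moe_dp_size` mirror the diff):

```python
# e.g. 16 ranks, 2 pipeline stages, 4-way expert parallelism, no MoE tensor parallelism
world_size = 16
pp_size, ep_size, moe_tp_size = 2, 4, 1

# moe_dp_size is whatever is left after pp, ep and moe_tp are carved out
moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size)  # -> 2

# mirrors the ValueError guard: the four sizes must tile the world exactly
assert pp_size * moe_dp_size * ep_size * moe_tp_size == world_size
```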
@@ -136,6 +144,52 @@ def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False,
 
         self.force_overlap_comm = force_overlap_comm
 
+    def _init_moe_param_comm(self):
+        self.moe_dp_group = None
+        self.ep_group = None
+        self.moe_tp_group = None
+
+        # create submesh for ep, moe_dp, moe_tp
+        ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
+            [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
+        )
+
+        global_rank = self.pg_mesh.rank
+        pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+
+        # create groups from submesh
+        for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
+            # ranks in each pp stage are laid out along (dp, tp, sp); reinterpret them as (moe_dp, ep, moe_tp)
+            submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+
+            # hardcoded since there are only 3 axes
+            # moe_dp_group
+            for ep_idx in range(self.ep_size):
+                for moe_tp_idx in range(self.moe_tp_size):
+                    moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
+                    group = dist.new_group(moe_dp_ranks)
+                    if pp_rank == stage_idx and global_rank in moe_dp_ranks:
+                        assert self.moe_dp_group is None
+                        self.moe_dp_group = group
+            # ep_group
+            for moe_dp_idx in range(self.moe_dp_size):
+                for moe_tp_idx in range(self.moe_tp_size):
+                    ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
+                    group = dist.new_group(ep_ranks)
+                    if pp_rank == stage_idx and global_rank in ep_ranks:
+                        assert self.ep_group is None
+                        self.ep_group = group
+            # moe_tp_group
+            for moe_dp_idx in range(self.moe_dp_size):
+                for ep_idx in range(self.ep_size):
+                    moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
+                    group = dist.new_group(moe_tp_ranks)
+                    if pp_rank == stage_idx and global_rank in moe_tp_ranks:
+                        assert self.moe_tp_group is None
+                        self.moe_tp_group = group
+
+        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
+
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
             self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
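To make the group layout built by `_init_moe_param_comm` concrete, the sketch below reproduces only its slicing logic: the ranks of a single pipeline stage are reshaped into a `(moe_dp, ep, moe_tp)` cube, and each group type is a 1-D slice of that cube. Sizes and rank numbers are made up, and `dist.new_group` is replaced by collecting the rank lists, so this runs without any distributed setup:

```python
import numpy as np

# illustrative sizes: 8 ranks in one pp stage, split as moe_dp=2, ep=4, moe_tp=1
moe_dp_size, ep_size, moe_tp_size = 2, 4, 1
stage_ranks = list(range(8))  # global ranks belonging to this pp stage

# same reshape as _init_moe_param_comm: axes are (moe_dp, ep, moe_tp)
submesh = np.array(stage_ranks).reshape(moe_dp_size, ep_size, moe_tp_size)

# moe_dp groups: fix (ep, moe_tp), vary moe_dp
moe_dp_groups = [
    submesh[:, ep_idx, tp_idx].tolist()
    for ep_idx in range(ep_size)
    for tp_idx in range(moe_tp_size)
]
# ep groups: fix (moe_dp, moe_tp), vary ep
ep_groups = [
    submesh[dp_idx, :, tp_idx].tolist()
    for dp_idx in range(moe_dp_size)
    for tp_idx in range(moe_tp_size)
]

print(moe_dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(ep_groups)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Each rank lands in exactly one group of each type, which is what the `assert ... is None` checks in the diff enforce for the local rank.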