Commit 2f9bce6

[moe] implement submesh initialization
1 parent a613edd commit 2f9bce6

File tree

3 files changed: +98 -27 lines changed

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

Lines changed: 61 additions & 7 deletions
@@ -1,6 +1,7 @@
 import warnings
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
+import numpy as np
 
 import torch
 import torch.distributed as dist
@@ -64,6 +65,14 @@ def __init__(
             overlap_communication = True
             warnings.warn(WARN_STR + " Please make sure of this.")
 
+        self.param_info = param_info
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.dp_pg = dp_process_group
+
+        if use_pipeline:
+            reinitialize_optimizer(optimizer, model)
+
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
@@ -116,17 +125,16 @@ def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False,
             raise NotImplementedError
 
         world_size = dist.get_world_size()
-
-        self.moe_dp_size = world_size // (ep_size * moe_tp_size)
+        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
         self.ep_size = ep_size
         self.moe_tp_size = moe_tp_size
 
-        self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
-        self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
+        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
+            raise ValueError(
+                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
+            )
 
-        self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
-        self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
-        self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
+        self._init_moe_param_comm()
 
         self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
 
@@ -136,6 +144,52 @@ def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False,
 
         self.force_overlap_comm = force_overlap_comm
 
+    def _init_moe_param_comm(self):
+        self.moe_dp_group = None
+        self.ep_group = None
+        self.moe_tp_group = None
+
+        # create submesh for ep, moe_dp, moe_tp
+        ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
+            [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
+        )
+
+        global_rank = self.pg_mesh.rank
+        pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+
+        # create groups from submesh
+        for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
+            # axis 0 is dp, axis 1 is tp, axis 2 is sp
+            submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+
+            # hardcode here since we only have 3 axis
+            # moe_dp_group
+            for ep_idx in range(self.ep_size):
+                for moe_tp_idx in range(self.moe_tp_size):
+                    moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
+                    group = dist.new_group(moe_dp_ranks)
+                    if pp_rank == stage_idx and global_rank in moe_dp_ranks:
+                        assert self.moe_dp_group is None
+                        self.moe_dp_group = group
+            # ep_group
+            for moe_dp_idx in range(self.moe_dp_size):
+                for moe_tp_idx in range(self.moe_tp_size):
+                    ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
+                    group = dist.new_group(ep_ranks)
+                    if pp_rank == stage_idx and global_rank in ep_ranks:
+                        assert self.ep_group is None
+                        self.ep_group = group
+            # moe_tp_group
+            for moe_dp_idx in range(self.moe_dp_size):
+                for ep_idx in range(self.ep_size):
+                    moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
+                    group = dist.new_group(moe_tp_ranks)
+                    if pp_rank == stage_idx and global_rank in moe_tp_ranks:
+                        assert self.moe_tp_group is None
+                        self.moe_tp_group = group
+
+        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
+
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
             self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
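
Note on the submesh logic added above: the snippet below is a minimal, standalone sketch (plain NumPy, no distributed initialization) of how `_init_moe_param_comm` carves one pipeline stage's flat rank list into moe_dp / ep / moe_tp groups. The sizes used here (8 ranks per stage, moe_dp_size = ep_size = moe_tp_size = 2) are illustrative assumptions, not values from this commit.

    import numpy as np

    # Assumed illustrative sizes; not taken from the commit.
    moe_dp_size, ep_size, moe_tp_size = 2, 2, 2
    stage_ranks = list(range(8))  # flat rank list of a single pipeline stage

    # Same axis order as _init_moe_param_comm: (moe_dp, ep, moe_tp).
    submesh = np.array(stage_ranks).reshape(moe_dp_size, ep_size, moe_tp_size)

    # Fixing two axes and slicing along the third yields that axis's groups.
    moe_dp_groups = [submesh[:, e, t].tolist() for e in range(ep_size) for t in range(moe_tp_size)]
    ep_groups = [submesh[d, :, t].tolist() for d in range(moe_dp_size) for t in range(moe_tp_size)]
    moe_tp_groups = [submesh[d, e, :].tolist() for d in range(moe_dp_size) for e in range(ep_size)]

    print(moe_dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
    print(ep_groups)      # [[0, 2], [1, 3], [4, 6], [5, 7]]
    print(moe_tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

Each rank then keeps only the one group of each kind that contains it, which is what the `pp_rank == stage_idx and global_rank in ...` guards enforce.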

colossalai/cluster/process_group_mesh.py

Lines changed: 28 additions & 12 deletions
@@ -209,13 +209,15 @@ def create_group_along_axis(
         axis: Union[int, List[int]],
         indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
         backend: Optional[str] = None,
-    ) -> ProcessGroup:
+        return_ranks_by_group: bool = False
+    ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
         """Create all process groups along the given axis, and return the one which the current process belongs to.
 
         Args:
             axis (int): Axis along which the process groups are created.
             indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
             backend (Optional[str], optional): Backend of the process group. Defaults to None.
+            return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
 
         Returns:
             ProcessGroup: The process group along the given axis which the current process belongs to.
@@ -235,25 +237,35 @@ def create_group_along_axis(
         # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
         for ax in axis:
             reduced_shape[ax] = 1
-        target_group = None
-        # use Cartesian product to generate all combinations of coordinates
-        for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
-            coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
-            ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
-            group = self._get_group(ranks_in_group, backend=backend)
-            if self._rank in ranks_in_group:
-                target_group = group
-        return target_group
+        if return_ranks_by_group:
+            ranks_by_group = []
+            # use Cartesian product to generate all combinations of coordinates
+            for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+                coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+                ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+                ranks_by_group.append(ranks_in_group)
+            return ranks_by_group
+        else:
+            target_group = None
+            # use Cartesian product to generate all combinations of coordinates
+            for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+                coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+                ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+                group = self._get_group(ranks_in_group, backend=backend)
+                if self._rank in ranks_in_group:
+                    target_group = group
+            return target_group
 
     def get_group_along_axis(
-        self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
-    ) -> ProcessGroup:
+        self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
+    ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
         """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
 
         Args:
             axis (int or list of int): Axes along which the process groups are created.
             indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
             backend (Optional[str], optional): Backend of the process group. Defaults to None.
+            return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
 
         Returns:
             ProcessGroup: The process group along the given axis which the current process belongs to.
@@ -267,6 +279,10 @@ def get_group_along_axis(
 
         coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
         ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+
+        if return_ranks_by_group:
+            return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
+
         if ranks_in_group not in self._ranks_to_group:
            # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
            return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
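
To make the new `return_ranks_by_group` path concrete, here is a rough standalone emulation of what the MoE plugin's call `get_group_along_axis([dp_axis, tp_axis, sp_axis], return_ranks_by_group=True)` yields: one rank tuple per group, with no process groups created. The 4D mesh shape `(pp, dp, tp, sp) = (2, 2, 2, 1)` and the row-major rank layout are assumptions for illustration, not values from this commit.

    import numpy as np

    # Assumed mesh shape and row-major rank layout, mirroring ProcessGroupMesh.ravel's ordering.
    shape = (2, 2, 2, 1)  # (pp, dp, tp, sp)
    mesh = np.arange(np.prod(shape)).reshape(shape)

    # Grouping along [dp, tp, sp] leaves only the pp axis free, so each group
    # is the full rank set of one pipeline stage.
    ranks_by_group = [tuple(mesh[pp_idx].flatten().tolist()) for pp_idx in range(shape[0])]
    print(ranks_by_group)  # [(0, 1, 2, 3), (4, 5, 6, 7)]

These per-stage rank lists are exactly what `_init_moe_param_comm` reshapes into the `(moe_dp, ep, moe_tp)` submesh above.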

tests/test_shardformer/test_model/test_shard_mixtral.py

Lines changed: 9 additions & 8 deletions
@@ -29,10 +29,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
         model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
     )
-    with torch.autograd.set_detect_anomaly(True):
-        org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-            org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
-        )
+
+    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+    )
 
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
@@ -115,8 +115,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         {
             "tp_size": 1,
-            "pp_size": 1,
-            "ep_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 2,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
@@ -125,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 2,
-            "ep_size": 1,
+            "ep_size": 2,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
@@ -134,7 +135,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 2,
-            "ep_size": 1,
+            "ep_size": 2,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
