Skip to content

Commit ca5b811

Browse files
committed
feat: Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution
1 parent 2ea0b17 commit ca5b811

File tree

4 files changed

+13
-17
lines changed

4 files changed

+13
-17
lines changed

applications/ColossalMoE/colossal_moe/models/mixtral_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
110110
module = self.model.model
111111

112112
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
113-
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
113+
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
114114
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
115115
self.append_or_create_method_replacement(
116116
description=method_replacement, policy=policy, target_key=model_cls

colossalai/shardformer/policies/base_policy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
197197
"""
198198
return []
199199

200-
@staticmethod
201-
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
200+
def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
202201
"""Divide layers into stages"""
203202
quotient = num_layers // num_stages
204203
remainder = num_layers % num_stages
@@ -213,8 +212,8 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
213212
layers_per_stage[i] += 1
214213
return layers_per_stage
215214

216-
@staticmethod
217215
def get_stage_index(
216+
self,
218217
layers_per_stage: List[int],
219218
stage: int,
220219
num_model_chunks: int = 1,

colossalai/shardformer/policies/gpt2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def get_held_layers(self) -> List[nn.Module]:
175175
layers_per_stage = self.distribute_layers(
176176
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
177177
)
178-
stage_indices = Policy.get_stage_index(
178+
stage_indices = self.get_stage_index(
179179
layers_per_stage,
180180
stage_manager.stage,
181181
num_model_chunks=stage_manager.num_model_chunks,
@@ -226,8 +226,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
226226
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
227227
}
228228
else:
229-
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
230-
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
229+
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
230+
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
231231
method_replacement = {
232232
"forward": partial(
233233
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config

examples/language/openmoe/model/openmoe_policy.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,11 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
9898
module = self.model.model
9999

100100
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
101-
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
101+
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
102102
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
103-
self.append_or_create_method_replacement(description=method_replacement,
104-
policy=policy,
105-
target_key=model_cls)
103+
self.append_or_create_method_replacement(
104+
description=method_replacement, policy=policy, target_key=model_cls
105+
)
106106

107107
return
108108

@@ -126,12 +126,9 @@ def get_held_layers(self) -> List[Module]:
126126
held_layers.append(module.norm)
127127

128128
return held_layers
129-
130-
@staticmethod
131-
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
132-
"""Divide layers into stages
133129

134-
"""
130+
def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
131+
"""Divide layers into stages"""
135132
if num_layers == 24 and num_stages == 4:
136133
return [7, 7, 7, 3]
137134
elif num_layers == 24 and num_stages == 2:
@@ -142,7 +139,7 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
142139
return [8, 4]
143140
else:
144141
print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy")
145-
return Policy.distribute_layers(num_layers, num_stages)
142+
return super().distribute_layers(num_layers, num_stages)
146143

147144

148145
class OpenMoeModelPolicy(OpenMoePolicy):

0 commit comments

Comments
 (0)