Commit 09d6945

style: polish code
1 parent ca5b811 commit 09d6945

2 files changed: +45 additions, -40 deletions

examples/language/openmoe/model/openmoe_policy.py

Lines changed: 41 additions & 36 deletions
@@ -1,4 +1,3 @@
-
 import warnings
 from functools import partial
 from typing import Callable, Dict, List, Optional, Union

@@ -21,7 +20,6 @@
 
 
 class OpenMoePolicy(Policy):
-
     def config_sanity_check(self):
         pass
 
@@ -43,7 +41,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             raise NotImplementedError(
-                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
+            )
 
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
@@ -143,7 +142,6 @@ def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
 
 
 class OpenMoeModelPolicy(OpenMoePolicy):
-
     def __init__(self) -> None:
         super().__init__()
 
@@ -169,21 +167,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
 
 
 class OpenMoeForCausalLMPolicy(OpenMoePolicy):
-
     def module_policy(self):
         policy = super().module_policy()
 
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
-                OpenMoeForCausalLM:
-                    ModulePolicyDescription(sub_module_replacement=[
+                OpenMoeForCausalLM: ModulePolicyDescription(
+                    sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="lm_head",
                             target_module=Linear1D_Col,
                             kwargs=dict(gather_output=True),
                         )
-                    ])
+                    ]
+                )
             }
             policy.update(new_item)
 
@@ -208,13 +206,17 @@ def get_held_layers(self) -> List[Module]:
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
-            if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
-                    and self.pipeline_stage_manager.num_stages > 1):
+            if (
+                id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                and self.pipeline_stage_manager.num_stages > 1
+            ):
                 # tie weights
-                return [{
-                    0: llama_model.embed_tokens.weight,
-                    self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
-                }]
+                return [
+                    {
+                        0: llama_model.embed_tokens.weight,
+                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                    }
+                ]
         return []
 
 
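As a side note on the hunk above: the reformatted `get_shared_params` only reports the embedding/LM-head pair when the two parameters are literally the same tensor and more than one pipeline stage is in use. Below is a minimal, standalone sketch of that check, not part of the commit; `ToyCausalLM` and `num_stages` are made-up stand-ins for the real OpenMoE model and pipeline stage manager.

```python
# Hypothetical toy model, for illustration only -- not the OpenMoE classes above.
import torch.nn as nn


class ToyCausalLM(nn.Module):
    def __init__(self, vocab: int = 32, dim: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # tie input/output embeddings


model = ToyCausalLM()
num_stages = 2  # assumed pipeline size, standing in for pipeline_stage_manager.num_stages

if id(model.embed_tokens.weight) == id(model.lm_head.weight) and num_stages > 1:
    # Stage 0 holds the embedding, the last stage holds the head; the pair must stay in sync.
    shared_params = [{0: model.embed_tokens.weight, num_stages - 1: model.lm_head.weight}]
else:
    shared_params = []

print(len(shared_params))  # 1
```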
@@ -247,12 +249,13 @@ def openmoe_model_forward(
 
     logger = logging.get_logger(__name__)
 
-    output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
-    output_hidden_states = (output_hidden_states
-                            if output_hidden_states is not None else self.config.output_hidden_states)
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
     use_cache = use_cache if use_cache is not None else self.config.use_cache
 
-    return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # retrieve input_ids and inputs_embeds
     if stage_manager.is_first_stage():
@@ -320,7 +323,8 @@ def openmoe_model_forward(
     if self.gradient_checkpointing and self.training:
         if use_cache:
             logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
             use_cache = False
 
     # decoder layers
@@ -333,12 +337,11 @@ def openmoe_model_forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        past_key_value = (past_key_values[idx] if past_key_values is not None else None)
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
 
         if self.gradient_checkpointing and self.training:
 
             def create_custom_forward(module):
-
                 def custom_forward(*inputs):
                     # None for past_key_value
                     return module(*inputs, output_attentions, None)
@@ -384,14 +387,16 @@ def custom_forward(*inputs):
         router_z_loss = past_router_z_loss + router_z_loss
 
     if stage_manager.is_last_stage():
-        return tuple([
-            hidden_states,
-            next_cache,
-            all_hidden_states,
-            all_self_attns,
-            router_aux_loss,
-            router_z_loss,
-        ])
+        return tuple(
+            [
+                hidden_states,
+                next_cache,
+                all_hidden_states,
+                all_self_attns,
+                router_aux_loss,
+                router_z_loss,
+            ]
+        )
     # always return dict for imediate stage
     return {
         "hidden_states": hidden_states,
@@ -445,10 +450,11 @@ def llama_for_causal_lm_forward(
         "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
         ```"""
     logger = logging.get_logger(__name__)
-    output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
-    output_hidden_states = (output_hidden_states
-                            if output_hidden_states is not None else self.config.output_hidden_states)
-    return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if output_attentions:
@@ -504,7 +510,6 @@ def llama_for_causal_lm_forward(
         if chunk_head == True:
 
             def create_custom_forward(module):
-
                 def custom_forward(*inputs):
                     logits = module(inputs[0])
                     logits = logits.float()
@@ -522,8 +527,8 @@ def custom_forward(*inputs):
             for batch_idx in range(hidden_states.shape[0]):
                 loss = loss + torch.utils.checkpoint.checkpoint(
                     create_custom_forward(self.lm_head),
-                    hidden_states[batch_idx:batch_idx + 1, :],
-                    labels[batch_idx:batch_idx + 1, :],
+                    hidden_states[batch_idx : batch_idx + 1, :],
+                    labels[batch_idx : batch_idx + 1, :],
                 )
             logits = None
         else:
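The final hunks above touch the chunked-`lm_head` path of `llama_for_causal_lm_forward`. As a rough illustration of the pattern (a sketch under assumptions, not the commit's code): each sample's logits and loss are recomputed under `torch.utils.checkpoint.checkpoint`, so the full `[batch, seq, vocab]` logits tensor is never kept alive for the backward pass. All sizes and the bare `nn.Linear` standing in for `lm_head` below are made up.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

# Made-up sizes; a plain nn.Linear stands in for the model's lm_head.
vocab_size, hidden_size = 128, 16
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(4, 10, hidden_size, requires_grad=True)
labels = torch.randint(0, vocab_size, (4, 10))


def create_custom_forward(module):
    def custom_forward(*inputs):
        # Compute logits and a shifted causal-LM loss for a single sample.
        logits = module(inputs[0]).float()
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = inputs[1][..., 1:].contiguous()
        return nn.CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))

    return custom_forward


loss = torch.tensor(0.0)
for batch_idx in range(hidden_states.shape[0]):
    # Activations inside custom_forward are recomputed during backward,
    # so only one sample's logits exist at any given time.
    loss = loss + torch.utils.checkpoint.checkpoint(
        create_custom_forward(lm_head),
        hidden_states[batch_idx : batch_idx + 1, :],
        labels[batch_idx : batch_idx + 1, :],
    )
loss.backward()
print(hidden_states.grad.shape)  # torch.Size([4, 10, 16])
```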

tests/test_booster/test_plugin/test_3d_plugin.py

Lines changed: 4 additions & 4 deletions
@@ -83,7 +83,7 @@ def _criterion(outputs, inputs):
 
 @parameterize("init_method", ["none", "lazy"])
 def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
-    """check gemini plugin over model zoo
+    """check hybrid plugin over model zoo
 
     Args:
         early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
@@ -260,7 +260,7 @@ def run_grad_acc_test(test_args):
         origin_model, origin_optimizer, dataloader=dataloader
     )
     for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
-        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
+        assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
 
 
 def run_dist(rank, world_size, port, early_stop: bool = True):
@@ -271,9 +271,9 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 
 
 @rerun_if_address_is_in_use()
-def test_gemini_plugin(early_stop: bool = True):
+def test_3d_plugin(early_stop: bool = True):
     spawn(run_dist, 4, early_stop=early_stop)
 
 
 if __name__ == "__main__":
-    test_gemini_plugin(early_stop=False)
+    test_3d_plugin(early_stop=False)
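For reference, the parameter comparison touched in the second hunk of this test file follows the usual `torch.testing.assert_close` pattern: parameters of the boosted model and the reference model are compared pairwise with a loose tolerance after casting to a common dtype. A self-contained sketch with toy models (not the plugin test itself):

```python
import torch
import torch.nn as nn
from torch.testing import assert_close

torch.manual_seed(0)
model = nn.Linear(4, 4)
origin_model = nn.Linear(4, 4)
origin_model.load_state_dict(model.state_dict())  # start from identical weights

for p1, p2 in zip(model.parameters(), origin_model.parameters()):
    # atol/rtol leave headroom for small numeric drift (e.g. mixed precision, grad accumulation).
    assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
print("parameters match within tolerance")
```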
