 
 import colossalai.shardformer.layer as col_nn
 
-from ..layer.fused_ops import Bias_Gelu
 from ..modeling.gpt2 import (
     GPT2PipelineForwards,
     get_gpt2_flash_attention_forward,
     get_gpt_model_forward_for_flash_attn,
+    get_jit_fused_gpt2_mlp_forward,
     get_lm_forward_with_dist_cross_entropy,
     gpt2_sequence_parallel_forward_fn,
 )
@@ -37,10 +37,11 @@ def preprocess(self):
         """
         self.tie_weight = self.tie_weight_check()
         self.origin_attn_implement = self.model.config._attn_implementation
+        self.enable_bias_gelu_fused = self.model.config.activation_function == "gelu"
         return self.model
 
     def module_policy(self):
-        from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+        from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 
         ATTN_IMPLEMENTATION = {
             "eager": GPT2Attention,
@@ -120,6 +121,7 @@ def module_policy(self):
                             "n_fused": 1,
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
+                            "skip_bias_add": self.enable_bias_gelu_fused,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -200,39 +202,21 @@ def module_policy(self):
                 policy[GPT2Model].method_replacement = {
                     "forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
                 }
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_gpt2_mlp_forward(),
+                },
+                policy=policy,
+                target_key=GPT2MLP,
+            )
 
         if sp_mode is not None:
             policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
 
         return policy
 
     def postprocess(self):
-        import torch
-
-        from colossalai.shardformer._utils import setattr_
-
-        def bias_gelu_substitute_gpt2(module):
-            target_linear = None
-            for name, child in module.named_children():
-                bias_gelu_substitute_gpt2(child)
-                if name == "c_fc" and isinstance(child, col_nn.GPT2FusedLinearConv1D_Col):
-                    target_linear = child
-                elif target_linear is not None:
-                    if name == "act":
-                        replace_sub_module = Bias_Gelu(target_linear.bias)
-                        target_linear.bias = None
-                        setattr_(module, "act", replace_sub_module)
-
-                    target_linear = None
-
-        def trial(module):
-            if torch.distributed.get_rank() == 0:
-                print(module.__class__.__name__)
-            for name, child in module.named_children():
-                trial(child)
-
-        bias_gelu_substitute_gpt2(self.model)
-        trial(self.model)
         return self.model
 
     def get_held_layers(self) -> List[nn.Module]:
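Note: the diff above drops the old postprocess-time Bias_Gelu module substitution and instead registers a method replacement on GPT2MLP, pairing skip_bias_add on the column-parallel c_fc with a jit-fused bias+GELU forward. Below is a minimal, illustrative sketch of what such a fused forward could look like; it assumes skip_bias_add=True makes c_fc return an (output, bias) pair, and the _bias_gelu helper is a placeholder, not the actual get_jit_fused_gpt2_mlp_forward implementation in ColossalAI.

# Illustrative sketch only (assumptions noted above; not the library's code).
import torch


@torch.jit.script
def _bias_gelu(bias: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # tanh-approximated GELU applied to (x + bias) in one scripted kernel
    y = x + bias
    return y * 0.5 * (1.0 + torch.tanh(0.79788456 * y * (1.0 + 0.044715 * y * y)))


def get_jit_fused_gpt2_mlp_forward():
    # Returns a forward to be bound to transformers' GPT2MLP via method_replacement.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Assumes c_fc was built with skip_bias_add=True and returns (output, bias),
        # so the bias addition can be fused with the activation.
        hidden_states, bias = self.c_fc(hidden_states)
        hidden_states = _bias_gelu(bias, hidden_states)  # fused bias add + GELU
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states

    return forward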