Commit 58551f9

support gelu_bias_fused for gpt2
fix fix fix
1 parent 6391b86 commit 58551f9

File tree (8 files changed: +110 -29 lines)

    colossalai/shardformer/modeling/bert.py
    colossalai/shardformer/modeling/blip2.py
    colossalai/shardformer/modeling/gpt2.py
    colossalai/shardformer/modeling/vit.py
    colossalai/shardformer/policies/bert.py
    colossalai/shardformer/policies/blip2.py
    colossalai/shardformer/policies/gpt2.py
    colossalai/shardformer/policies/vit.py

colossalai/shardformer/modeling/bert.py

Lines changed: 13 additions & 0 deletions
@@ -1287,3 +1287,16 @@ def forward(
         )

     return forward
+
+
+def get_jit_fused_bert_intermediate_forward():
+    from transformers.models.bert.modeling_bert import BertIntermediate
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.dense(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        return hidden_states
+
+    return forward
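The replaced forward assumes `self.dense` has been rebuilt with `skip_bias_add` (see the policy change below), so the linear returns its un-biased output together with the bias tensor, and `GeLUFunction.apply(hidden_states, bias)` adds the bias and applies GeLU in a single JIT-compiled op. As a rough sketch of an autograd function with that calling convention (Megatron-style tanh-approximated GeLU; the actual `colossalai.kernel.jit.bias_gelu.GeLUFunction` may differ in detail):

import torch


@torch.jit.script
def _bias_gelu(bias, y):
    # Fused bias add + tanh-approximated GeLU on x = y + bias.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))


@torch.jit.script
def _bias_gelu_back(g, bias, y):
    # Gradient of the tanh-approximated GeLU, again evaluated at x = y + bias.
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))
    ff = 0.5 * x * ((1.0 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1.0 + tanh_out)
    return ff * g


class FusedBiasGeLU(torch.autograd.Function):
    # Sketch of a GeLUFunction-style op: apply(input, bias) matches the call sites in this commit.

    @staticmethod
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return _bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        # The same elementwise gradient is returned for both input and bias,
        # mirroring the Megatron-LM bias_gelu pattern.
        tmp = _bias_gelu_back(grad_output, bias, input)
        return tmp, tmp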

colossalai/shardformer/modeling/blip2.py

Lines changed: 14 additions & 0 deletions
@@ -129,3 +129,17 @@ def forward(
         return hidden_states

     return forward
+
+
+def get_jit_fused_blip2_mlp_forward():
+    from transformers.models.blip_2.modeling_blip_2 import Blip2MLP
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.fc1(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+    return forward

colossalai/shardformer/modeling/gpt2.py

Lines changed: 15 additions & 0 deletions
@@ -1310,3 +1310,18 @@ def forward(
         )

     return forward
+
+
+def get_jit_fused_gpt2_mlp_forward():
+    from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states, bias = self.c_fc(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+    return forward
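This forward only holds up because the GPT-2 policy (below) rebuilds `c_fc` with `skip_bias_add=self.enable_bias_gelu_fused`, so the projection hands back `(output, bias)` instead of adding the bias itself. A minimal single-device stand-in for that contract, using a hypothetical `SkipBiasLinear` rather than the actual ColossalAI parallel layer:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SkipBiasLinear(nn.Linear):
    # Stand-in for a linear built with skip_bias_add=True: the bias is not
    # added here but returned to the caller, so a fused bias+GeLU kernel can
    # apply it together with the activation in one pass.
    def forward(self, x):
        return F.linear(x, self.weight), self.bias


c_fc = SkipBiasLinear(768, 3072)
hidden_states = torch.randn(2, 16, 768)
out, bias = c_fc(hidden_states)            # bias is deferred, not applied
# out = JitGeLUFunction.apply(out, bias)   # fused bias add + GeLU, as in the forwards above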

colossalai/shardformer/modeling/vit.py

Lines changed: 12 additions & 0 deletions
@@ -372,3 +372,15 @@ def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Te
         return hidden_states

     return forward
+
+
+def get_jit_fused_vit_intermediate_forward():
+    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, bias = self.dense(hidden_states)
+        hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+
+        return hidden_states
+
+    return forward

colossalai/shardformer/policies/bert.py

Lines changed: 12 additions & 0 deletions
@@ -12,6 +12,7 @@
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
     get_bert_flash_attention_forward,
+    get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
 )
@@ -38,11 +39,13 @@ def config_sanity_check(self):

     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.enable_bias_gelu_fused = self.model.config.hidden_act == "gelu"
         return self.model

     def module_policy(self):
         from transformers.models.bert.modeling_bert import (
             BertEmbeddings,
+            BertIntermediate,
             BertLayer,
             BertModel,
             BertOutput,
@@ -131,6 +134,7 @@ def module_policy(self):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
+                            "skip_bias_add": self.enable_bias_gelu_fused,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -231,6 +235,14 @@ def module_policy(self):
                 policy=policy,
                 target_key=BertOutput,
             )
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_bert_intermediate_forward(),
+                },
+                policy=policy,
+                target_key=BertIntermediate,
+            )

         return policy
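`append_or_create_method_replacement` records the factory's closure in the policy so that, when the model is sharded, `BertIntermediate.forward` is swapped for the fused version on each wrapped module. As a conceptual illustration only (not the ShardFormer sharder internals), the effect is roughly a bound-method replacement:

from types import MethodType

from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertIntermediate

from colossalai.shardformer.modeling.bert import get_jit_fused_bert_intermediate_forward

# Illustration: bind the factory-produced function as this instance's forward.
# It only works once `dense` has also been replaced by a skip_bias_add linear
# that returns (output, bias).
intermediate = BertIntermediate(BertConfig())
intermediate.forward = MethodType(get_jit_fused_bert_intermediate_forward(), intermediate)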

colossalai/shardformer/policies/blip2.py

Lines changed: 12 additions & 0 deletions
@@ -3,6 +3,7 @@
 from ..modeling.blip2 import (
     forward_fn,
     get_blip2_flash_attention_forward,
+    get_jit_fused_blip2_mlp_forward,
     get_jit_fused_blip2_QFormer_output_forward,
     get_jit_fused_blip2_QFormer_self_output_forward,
 )
@@ -18,12 +19,14 @@ def config_sanity_check(self):

     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.enable_bias_gelu_fused = self.model.config.hidden_act == "gelu"
         return self.model

     def module_policy(self):
         from transformers.models.blip_2.modeling_blip_2 import (
             Blip2Attention,
             Blip2EncoderLayer,
+            Blip2MLP,
             Blip2QFormerLayer,
             Blip2QFormerModel,
             Blip2QFormerOutput,
@@ -73,6 +76,7 @@ def module_policy(self):
                     SubModuleReplacementDescription(
                         suffix="mlp.fc1",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.fc2",
@@ -359,6 +363,14 @@ def module_policy(self):
                 policy=policy,
                 target_key=Blip2QFormerOutput,
             )
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_blip2_mlp_forward(),
+                },
+                policy=policy,
+                target_key=Blip2MLP,
+            )

         return policy

colossalai/shardformer/policies/gpt2.py

Lines changed: 12 additions & 28 deletions
@@ -6,11 +6,11 @@

 import colossalai.shardformer.layer as col_nn

-from ..layer.fused_ops import Bias_Gelu
 from ..modeling.gpt2 import (
     GPT2PipelineForwards,
     get_gpt2_flash_attention_forward,
     get_gpt_model_forward_for_flash_attn,
+    get_jit_fused_gpt2_mlp_forward,
     get_lm_forward_with_dist_cross_entropy,
     gpt2_sequence_parallel_forward_fn,
 )
@@ -37,10 +37,11 @@ def preprocess(self):
         """
         self.tie_weight = self.tie_weight_check()
         self.origin_attn_implement = self.model.config._attn_implementation
+        self.enable_bias_gelu_fused = self.model.config.activation_function == "gelu"
         return self.model

     def module_policy(self):
-        from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+        from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model

         ATTN_IMPLEMENTATION = {
             "eager": GPT2Attention,
@@ -120,6 +121,7 @@ def module_policy(self):
                             "n_fused": 1,
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
+                            "skip_bias_add": self.enable_bias_gelu_fused,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -200,39 +202,21 @@ def module_policy(self):
             policy[GPT2Model].method_replacement = {
                 "forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
             }
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_gpt2_mlp_forward(),
+                },
+                policy=policy,
+                target_key=GPT2MLP,
+            )

         if sp_mode is not None:
             policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}

         return policy

     def postprocess(self):
-        import torch
-
-        from colossalai.shardformer._utils import setattr_
-
-        def bias_gelu_substitute_gpt2(module):
-            target_linear = None
-            for name, child in module.named_children():
-                bias_gelu_substitute_gpt2(child)
-                if name == "c_fc" and isinstance(child, col_nn.GPT2FusedLinearConv1D_Col):
-                    target_linear = child
-                elif target_linear is not None:
-                    if name == "act":
-                        replace_sub_module = Bias_Gelu(target_linear.bias)
-                        target_linear.bias = None
-                        setattr_(module, "act", replace_sub_module)
-
-                    target_linear = None
-
-        def trial(module):
-            if torch.distributed.get_rank() == 0:
-                print(module.__class__.__name__)
-            for name, child in module.named_children():
-                trial(child)
-
-        bias_gelu_substitute_gpt2(self.model)
-        trial(self.model)
         return self.model

     def get_held_layers(self) -> List[nn.Module]:
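One consequence of the new gate worth noting: `enable_bias_gelu_fused` is only set when the config uses the plain "gelu" activation, and Hugging Face GPT-2 configs default to "gelu_new", so the fused MLP forward and the `skip_bias_add` path stay disabled unless the activation is set explicitly. A quick check:

from transformers import GPT2Config

print(GPT2Config().activation_function)                            # "gelu_new" -> fused path off
print(GPT2Config(activation_function="gelu").activation_function)  # "gelu"     -> fused path on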

colossalai/shardformer/policies/vit.py

Lines changed: 20 additions & 1 deletion
@@ -11,6 +11,7 @@
     ViTForImageClassification_pipeline_forward,
     ViTForMaskedImageModeling_pipeline_forward,
     ViTModel_pipeline_forward,
+    get_jit_fused_vit_intermediate_forward,
     get_jit_fused_vit_output_forward,
     get_vit_flash_self_attention_forward,
 )
@@ -24,10 +25,17 @@ def config_sanity_check(self):
         pass

     def preprocess(self):
+        self.enable_bias_gelu_fused = self.model.config.hidden_act == "gelu"
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention
+        from transformers.models.vit.modeling_vit import (
+            ViTEmbeddings,
+            ViTIntermediate,
+            ViTLayer,
+            ViTOutput,
+            ViTSelfAttention,
+        )

         policy = {}

@@ -83,6 +91,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     SubModuleReplacementDescription(
                         suffix="intermediate.dense",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "skip_bias_add": self.enable_bias_gelu_fused,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dense",
@@ -115,6 +126,14 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                 policy=policy,
                 target_key=ViTOutput,
             )
+        if self.enable_bias_gelu_fused:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_vit_intermediate_forward(),
+                },
+                policy=policy,
+                target_key=ViTIntermediate,
+            )
         return policy

     def new_model_class(self):

0 commit comments

Comments
 (0)