13 changes: 13 additions & 0 deletions colossalai/shardformer/modeling/bert.py
@@ -1287,3 +1287,16 @@ def forward(
)

return forward


def get_jit_fused_bert_intermediate_forward():
from transformers.models.bert.modeling_bert import BertIntermediate

from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction

def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, bias = self.dense(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
return hidden_states

return forward
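Note (not part of the diff): `JitGeLUFunction.apply(hidden_states, bias)` fuses the bias addition that the replaced `dense` layer skips (see the `skip_bias_add` policy changes below) with the GeLU activation. The following is a minimal unfused sketch of the math the fused op is expected to compute, assuming it follows the Megatron-style tanh approximation; `bias_gelu_reference` is an illustrative name, not a ColossalAI API.

```python
import torch


def bias_gelu_reference(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Re-apply the bias the linear layer deferred, then the tanh-approximate GeLU.
    y = x + bias
    return 0.5 * y * (1.0 + torch.tanh(0.7978845608 * y * (1.0 + 0.044715 * y * y)))
```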
14 changes: 14 additions & 0 deletions colossalai/shardformer/modeling/blip2.py
@@ -129,3 +129,17 @@ def forward(
return hidden_states

return forward


def get_jit_fused_blip2_mlp_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2MLP

from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction

def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, bias = self.fc1(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
hidden_states = self.fc2(hidden_states)
return hidden_states

return forward
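For contrast (paraphrased from the upstream Hugging Face implementation, so treat the exact body as an approximation): the stock `Blip2MLP.forward` applies the bias inside `fc1` and then runs the activation as a separate call, whereas the fused version above defers the bias into `JitGeLUFunction`.

```python
def eager_blip2_mlp_forward(self, hidden_states):
    # Unfused reference path: bias is added inside fc1, activation runs separately.
    hidden_states = self.fc1(hidden_states)
    hidden_states = self.activation_fn(hidden_states)  # GeLU for the vision tower
    hidden_states = self.fc2(hidden_states)
    return hidden_states
```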
15 changes: 15 additions & 0 deletions colossalai/shardformer/modeling/gpt2.py
@@ -1310,3 +1310,18 @@ def forward(
)

return forward


def get_jit_fused_gpt2_mlp_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP

from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction

def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states, bias = self.c_fc(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)
hidden_states = self.c_proj(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states

return forward
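A quick numerical sanity check for the fused op (a sketch, not a test shipped in this PR; it assumes a CUDA device, fp16 inputs, and that the fused kernel uses a tanh-approximate GeLU, hence the loose tolerance):

```python
import torch

from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction

x = torch.randn(4, 128, 3072, device="cuda", dtype=torch.float16, requires_grad=True)
bias = torch.randn(3072, device="cuda", dtype=torch.float16, requires_grad=True)

fused = JitGeLUFunction.apply(x, bias)                          # fused bias + GeLU
eager = torch.nn.functional.gelu(x + bias, approximate="tanh")  # unfused reference
print(torch.allclose(fused, eager, rtol=1e-2, atol=1e-2))
```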
12 changes: 12 additions & 0 deletions colossalai/shardformer/modeling/vit.py
@@ -372,3 +372,15 @@ def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Te
return hidden_states

return forward


def get_jit_fused_vit_intermediate_forward():
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, bias = self.dense(hidden_states)
hidden_states = JitGeLUFunction.apply(hidden_states, bias)

return hidden_states

return forward
12 changes: 12 additions & 0 deletions colossalai/shardformer/policies/bert.py
@@ -12,6 +12,7 @@
BertPipelineForwards,
bert_sequence_parallel_forward_fn,
get_bert_flash_attention_forward,
get_jit_fused_bert_intermediate_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
)
@@ -38,11 +39,13 @@ def config_sanity_check(self):

def preprocess(self):
self.tie_weight = self.tie_weight_check()
self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
return self.model

def module_policy(self):
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertIntermediate,
BertLayer,
BertModel,
BertOutput,
@@ -131,6 +134,7 @@ def module_policy(self):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
},
),
SubModuleReplacementDescription(
@@ -153,6 +157,14 @@ def module_policy(self):
),
]
)
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_bert_intermediate_forward(),
},
policy=policy,
target_key=BertIntermediate,
)

if sp_mode == "split_gather":
self.append_or_create_method_replacement(
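The new `skip_bias_add=True` kwarg is what makes the `hidden_states, bias = self.dense(hidden_states)` unpacking in the fused forward valid: with it, the replaced linear returns its un-biased output together with the bias tensor, so the addition can be folded into the GeLU. An illustrative sketch of that contract follows (hypothetical class, not ColossalAI's actual `Linear1D_Col`):

```python
import torch
import torch.nn.functional as F


class SkipBiasLinear(torch.nn.Linear):
    def forward(self, x: torch.Tensor):
        # Return (output_without_bias, bias); the caller fuses the bias into GeLU.
        return F.linear(x, self.weight), self.bias
```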
14 changes: 14 additions & 0 deletions colossalai/shardformer/policies/blip2.py
@@ -3,6 +3,7 @@
from ..modeling.blip2 import (
forward_fn,
get_blip2_flash_attention_forward,
get_jit_fused_blip2_mlp_forward,
get_jit_fused_blip2_QFormer_output_forward,
get_jit_fused_blip2_QFormer_self_output_forward,
)
@@ -18,12 +19,16 @@ def config_sanity_check(self):

def preprocess(self):
self.tie_weight = self.tie_weight_check()
self.enable_bias_gelu_fused = (
self.shard_config.enable_jit_fused and self.model.config.vision_config.hidden_act == "gelu"
)
return self.model

def module_policy(self):
from transformers.models.blip_2.modeling_blip_2 import (
Blip2Attention,
Blip2EncoderLayer,
Blip2MLP,
Blip2QFormerLayer,
Blip2QFormerModel,
Blip2QFormerOutput,
@@ -73,6 +78,7 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="mlp.fc1",
target_module=col_nn.Linear1D_Col,
kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
),
SubModuleReplacementDescription(
suffix="mlp.fc2",
@@ -201,6 +207,14 @@ def module_policy(self):
)

policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_blip2_mlp_forward(),
},
policy=policy,
target_key=Blip2MLP,
)

if embedding_cls is not None:
self.append_or_create_submodule_replacement(
15 changes: 14 additions & 1 deletion colossalai/shardformer/policies/gpt2.py
@@ -10,6 +10,7 @@
GPT2PipelineForwards,
get_gpt2_flash_attention_forward,
get_gpt_model_forward_for_flash_attn,
get_jit_fused_gpt2_mlp_forward,
get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn,
)
@@ -36,10 +37,13 @@ def preprocess(self):
"""
self.tie_weight = self.tie_weight_check()
self.origin_attn_implement = self.model.config._attn_implementation
self.enable_bias_gelu_fused = (
self.shard_config.enable_jit_fused and self.model.config.activation_function == "gelu"
)
return self.model

def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model

ATTN_IMPLEMENTATION = {
"eager": GPT2Attention,
@@ -119,6 +123,7 @@ def module_policy(self):
"n_fused": 1,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
},
),
SubModuleReplacementDescription(
@@ -142,6 +147,14 @@ def module_policy(self):
),
],
)
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_gpt2_mlp_forward(),
},
policy=policy,
target_key=GPT2MLP,
)
if embedding_cls is not None:
# padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
self.append_or_create_submodule_replacement(
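At the user level the new path is gated only by `enable_jit_fused` together with a plain `"gelu"` activation; stock GPT-2 checkpoints typically configure `gelu_new`, which this guard leaves on the unfused path, so BERT is used in the sketch below. This assumes the usual `ShardFormer` entry point; distributed and process-group setup are omitted, and the exact `ShardConfig` arguments are assumptions rather than taken from this PR.

```python
from transformers import BertModel

from colossalai.shardformer import ShardConfig, ShardFormer

model = BertModel.from_pretrained("bert-base-uncased")  # hidden_act == "gelu", so the fused path applies
shard_config = ShardConfig(enable_tensor_parallelism=True, enable_jit_fused=True)
model, _ = ShardFormer(shard_config=shard_config).optimize(model)
```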
22 changes: 21 additions & 1 deletion colossalai/shardformer/policies/vit.py
@@ -11,6 +11,7 @@
ViTForImageClassification_pipeline_forward,
ViTForMaskedImageModeling_pipeline_forward,
ViTModel_pipeline_forward,
get_jit_fused_vit_intermediate_forward,
get_jit_fused_vit_output_forward,
get_vit_flash_self_attention_forward,
)
@@ -24,10 +25,17 @@ def config_sanity_check(self):
pass

def preprocess(self):
self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention
from transformers.models.vit.modeling_vit import (
ViTEmbeddings,
ViTIntermediate,
ViTLayer,
ViTOutput,
ViTSelfAttention,
)

policy = {}

@@ -83,6 +91,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
@@ -94,6 +105,14 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
),
],
)
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_vit_intermediate_forward(),
},
policy=policy,
target_key=ViTIntermediate,
)

# use flash attention
if self.shard_config.enable_flash_attention:
@@ -115,6 +134,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy=policy,
target_key=ViTOutput,
)

return policy

def new_model_class(self):