Commit c5e29e4

update transformers

update transformers fix

1 parent 2e2d1c1

File tree: 6 files changed, +37 -19 lines


colossalai/shardformer/modeling/mistral.py

Lines changed: 1 addition & 1 deletion

@@ -220,4 +220,4 @@ def forward(
 
         return attn_output, attn_weights, past_key_value
 
-    return forward
+    return forward
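The hunk above touches the tail of a forward-factory: the modeling shim builds a replacement forward as a closure and returns it so a policy can later bind it onto attention modules. A minimal sketch of that pattern follows; the function name, signature, and body are illustrative assumptions, not the repository's actual code.

# Illustrative closure pattern only; not ColossalAI's real replacement forward.
def get_custom_attention_forward(shard_config=None):
    def forward(self, hidden_states, attention_mask=None, past_key_value=None, **kwargs):
        # A real implementation would compute attention here, possibly consulting
        # shard_config for flash-attention or sequence-parallel settings.
        attn_output, attn_weights = hidden_states, None  # placeholder computation
        return attn_output, attn_weights, past_key_value

    return forward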

colossalai/shardformer/policies/llama.py

Lines changed: 13 additions & 6 deletions

@@ -36,13 +36,20 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
-
+        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention, LlamaDecoderLayer, LlamaModel
+        ATTN_IMPLEMENTATION = {
+            "eager": LlamaAttention,
+            "flash_attention_2": LlamaFlashAttention2,
+            "sdpa": LlamaSdpaAttention,
+        }
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D

@@ -93,7 +100,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
         elif sp_mode == "all_to_all":
             decoder_attribute_replacement = {

@@ -102,15 +109,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             if getattr(self.model.config, "num_key_value_heads", False):
                 decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
 
-            policy[LlamaAttention] = ModulePolicyDescription(
+            policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
             self.append_or_create_method_replacement(
                 description={

@@ -221,7 +228,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
         if self.pipeline_stage_manager is None:
             # replace llama model forward method
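The recurring change in this policy (and in the Mistral and OPT policies below) is that the target attention class is no longer hard-coded to the eager LlamaAttention; it is resolved from the attention implementation recorded in the Hugging Face config. A standalone sketch of that dispatch is shown below; it assumes a transformers release around 4.36 where LlamaFlashAttention2 and LlamaSdpaAttention still exist as separate classes, so this exact import may fail on newer releases that dropped them.

# Sketch only: map the config's attention implementation to the module class a
# shardformer policy should target. Assumes transformers ~4.36, where the flash
# and SDPA variants are separate nn.Module subclasses.
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaFlashAttention2,
    LlamaSdpaAttention,
)

ATTN_IMPLEMENTATION = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
}


def resolve_attn_cls(config):
    # _attn_implementation is set by transformers to "eager", "sdpa",
    # or "flash_attention_2" when the model is instantiated.
    return ATTN_IMPLEMENTATION[config._attn_implementation]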

colossalai/shardformer/policies/mistral.py

Lines changed: 11 additions & 7 deletions

@@ -26,13 +26,21 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
+        from transformers.models.mistral.modeling_mistral import MistralAttention, MistralFlashAttention2, MistralDecoderLayer, MistralModel
+
+        ATTN_IMPLEMENTATION = {
+            "eager": MistralAttention,
+            "flash_attention_2": MistralFlashAttention2,
+        }
 
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D

@@ -128,10 +136,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_mistral_flash_attention_forward(),
+                    "forward": get_mistral_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
-                target_key=MistralAttention,
+                target_key=attn_cls,
             )
 
         return policy

@@ -143,9 +151,6 @@ def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict)
         method_replacement = {"forward": partial(new_forward)}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
 
-    def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
-        method_replacement = {"forward": partial(new_forward)}
-        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
 
 
 class MistralModelPolicy(MistralPolicy):

@@ -155,7 +160,6 @@ def __init__(self) -> None:
     def module_policy(self):
         policy = super().module_policy()
         from transformers.models.mistral.modeling_mistral import MistralModel
-
         self.set_forward(model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy)
         return policy
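For context, config._attn_implementation is populated by transformers before the policy's preprocess step reads it. A hedged usage example follows: attn_implementation is a real from_pretrained argument available since transformers 4.36, while the checkpoint name is only a placeholder.

# Example of where the value read by the policy comes from.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # placeholder checkpoint
    attn_implementation="eager",  # "flash_attention_2" additionally needs flash-attn and a CUDA GPU
)
print(model.config._attn_implementation)  # -> "eager"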

colossalai/shardformer/policies/opt.py

Lines changed: 11 additions & 3 deletions

@@ -44,13 +44,21 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self):
-        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
+        from transformers.models.opt.modeling_opt import OPTAttention, OptFlashAttention2, OPTDecoder, OPTDecoderLayer
+
+        ATTN_IMPLEMENTATION = {
+            "eager": OPTAttention,
+            "flash_attention_2": OptFlashAttention2,
+        }
 
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D

@@ -81,7 +89,7 @@ def module_policy(self):
                 ]
             )
 
-            policy[OPTAttention] = ModulePolicyDescription(
+            policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement={
                     "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                     "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,

@@ -151,7 +159,7 @@ def module_policy(self):
                     "forward": get_opt_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
-                target_key=OPTAttention,
+                target_key=attn_cls,
             )
         if not self.shard_config.pipeline_stage_manager:
             self.append_or_create_method_replacement(
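What targeting attn_cls ultimately means is that every attention submodule which is an instance of the resolved class gets its forward swapped. A rough, self-contained illustration of that mechanism follows; it is not ColossalAI's append_or_create_method_replacement, just the underlying idea.

# Rebind forward on every submodule that matches the resolved attention class.
import types

import torch.nn as nn


def replace_forward(model: nn.Module, target_cls: type, new_forward) -> int:
    replaced = 0
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
            replaced += 1
    return replaced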

tests/kit/model_zoo/transformers/llama.py

Lines changed: 0 additions & 1 deletion

@@ -65,7 +65,6 @@ def data_gen_for_casual_lm():
         num_attention_heads=4,
         max_position_embeddings=128,
         num_labels=16,
-        attn_implementation="eager",
     )
 
     if hasattr(config, "pad_token_id"):
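With the attn_implementation="eager" pin dropped from the test config, transformers is free to auto-select the implementation when the model is built (typically "sdpa" when available, otherwise "eager"), and the updated policies then resolve whichever class that turns out to be. An illustrative check is sketched below; the config sizes are placeholders, not the test's exact values.

# Build a tiny Llama model without pinning attn_implementation and inspect what
# transformers resolved. Sizes below are illustrative only.
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=128,
)
model = LlamaForCausalLM(config)
print(model.config._attn_implementation)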

tests/test_shardformer/test_model/test_shard_mistral.py

Lines changed: 1 addition & 1 deletion

@@ -156,7 +156,7 @@ def check_mistral(rank, world_size, port):
     run_mistral_test()
 
 
-@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.")
+@pytest.mark.skip("something wrong with pipeline parallelism")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
