@@ -36,13 +36,20 @@ def config_sanity_check(self):

    def preprocess(self):
        self.tie_weight = self.tie_weight_check()
+       self.origin_attn_implement = self.model.config._attn_implementation
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-       from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
-
+       from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention, LlamaDecoderLayer, LlamaModel
+       ATTN_IMPLEMENTATION = {
+           "eager": LlamaAttention,
+           "flash_attention_2": LlamaFlashAttention2,
+           "sdpa": LlamaSdpaAttention,
+       }
        policy = {}

+       attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
        embedding_cls = None
        if self.shard_config.enable_tensor_parallelism:
            embedding_cls = VocabParallelEmbedding1D
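The hunk above records the attention backend that transformers chose at load time (`config._attn_implementation`) and maps it to the matching attention class, so the policy can later key its replacements on whichever class the model actually contains. A minimal sketch of that dispatch in isolation, assuming a transformers release (roughly 4.36-4.40) that still ships `LlamaFlashAttention2` and `LlamaSdpaAttention` as separate subclasses; `resolve_attn_cls` is a hypothetical helper, not part of this patch:

```python
# Sketch only: the same eager/flash/sdpa dispatch outside the policy class.
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaFlashAttention2,
    LlamaSdpaAttention,
)

ATTN_IMPLEMENTATION = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
}

def resolve_attn_cls(attn_implementation: str) -> type:
    """Map the value stored on model.config._attn_implementation to the
    concrete attention class transformers instantiated, so forward/attribute
    replacements can be keyed on the right class."""
    return ATTN_IMPLEMENTATION[attn_implementation]

# e.g. a model loaded with attn_implementation="sdpa" records "sdpa" on its config
assert resolve_attn_cls("sdpa") is LlamaSdpaAttention
```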
@@ -93,7 +100,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                },
                policy=policy,
-               target_key=LlamaAttention,
+               target_key=attn_cls,
            )
        elif sp_mode == "all_to_all":
            decoder_attribute_replacement = {
@@ -102,15 +109,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
            if getattr(self.model.config, "num_key_value_heads", False):
                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size

-           policy[LlamaAttention] = ModulePolicyDescription(
+           policy[attn_cls] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
            )
            self.append_or_create_method_replacement(
                description={
                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                },
                policy=policy,
-               target_key=LlamaAttention,
+               target_key=attn_cls,
            )
            self.append_or_create_method_replacement(
                description={
@@ -221,7 +228,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                    "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
                },
                policy=policy,
-               target_key=LlamaAttention,
+               target_key=attn_cls,
            )
            if self.pipeline_stage_manager is None:
                # replace llama model forward method
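Keying the replacements on `attn_cls` rather than on `LlamaAttention` matters because the policy targets the exact class found in the module tree: a model loaded with `attn_implementation="flash_attention_2"` holds `LlamaFlashAttention2` instances, which a policy keyed only on `LlamaAttention` would not match. A small illustration of that mismatch, using a hypothetical exact-type lookup rather than ColossalAI's actual sharder traversal:

```python
# Sketch only: show which module names an exact-type policy key would match.
import torch.nn as nn

def modules_matching(model: nn.Module, target_cls: type) -> list:
    """Collect names of submodules whose exact type equals the policy's target key."""
    return [name for name, mod in model.named_modules() if type(mod) is target_cls]

# If the checkpoint was loaded with attn_implementation="flash_attention_2",
# every decoder layer's self_attn is a LlamaFlashAttention2 instance, so:
#   modules_matching(model, LlamaAttention)       -> []  (replacement silently skipped)
#   modules_matching(model, LlamaFlashAttention2) -> ["model.layers.0.self_attn", ...]
# which is why the policy now resolves attn_cls from config._attn_implementation.
```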