@@ -146,3 +146,150 @@ def forward(
             past_key_value,
             world_size=world_size,
         )
+
+
+class PatchedQwen2AttentionAscend(nn.Module):
+
+    def _load_weights(self, loader, rank: int, world_size: int,
+                      device: torch.device):
+        """load weights."""
+        for mod_name in ['q_proj', 'k_proj', 'v_proj']:
+            colwise_parallelize_linear(getattr(self, mod_name),
+                                       loader,
+                                       rank=rank,
+                                       world_size=world_size,
+                                       prefix=mod_name)
+        for mod_name in ['o_proj']:
+            rowwise_parallelize_linear(getattr(self, mod_name),
+                                       loader,
+                                       rank=rank,
+                                       world_size=world_size,
+                                       prefix=mod_name)
+
+    @classmethod
+    def _distribute_output_fn(cls, outputs, **kwargs):
+        """Distribution output hook."""
+        dist.all_reduce(outputs[0])
+        return outputs
+
+    def _contiguous_batching_forward_impl(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        world_size: int = 1,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
183+ """Rewrite implementation of forward.
184+
185+ Add continuous batching support. Add paged attention support. TP
186+ support.
187+ """
+        context = self.context.context
+        kv_seq_length = context.kv_seq_length
+        q_seq_length = context.q_seq_length
+        q_start_loc = context.q_start_loc
+        block_offsets = context.block_offsets
+        max_q_seq_length = context.max_q_seq_length
+        max_kv_seq_length = context.max_kv_seq_length
+
+        num_heads = self.num_heads // world_size
+        num_kv_heads = self.num_key_value_heads // world_size
+        head_dim = self.head_dim
+        hidden_size = num_heads * head_dim
+
+        def __qkv_proj(hidden_states):
+            """qkv proj."""
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+            return query_states, key_states, value_states
+
+        def __rotary_emb_fn(query_states, key_states, value_states):
+            if hasattr(self, 'rotary_emb'):
+                cos, sin = self.rotary_emb(value_states,
+                                           seq_len=max_kv_seq_length)
+                query_states, key_states = apply_rotary_pos_emb(
+                    query_states,
+                    key_states,
+                    cos,
+                    sin,
+                    position_ids,
+                    context.position_ids_1d,
+                    context=context)
+            return query_states, key_states, value_states
+
+        query_states, key_states, value_states = __qkv_proj(hidden_states)
+
+        query_states = query_states.view(-1, num_heads, head_dim)
+        key_states = key_states.view(-1, num_kv_heads, head_dim)
+        value_states = value_states.view(-1, num_kv_heads, head_dim)
+
+        query_states, key_states, value_states = __rotary_emb_fn(
+            query_states, key_states, value_states)
+
+        fill_kv_cache(
+            key_states,
+            value_states,
+            past_key_value[0],
+            past_key_value[1],
+            q_start_loc,
+            q_seq_length,
+            kv_seq_length=kv_seq_length,
+            max_q_seq_length=max_q_seq_length,
+            block_offsets=block_offsets,
+            context=context,
+        )
+
+        attn_output = query_states
+
+        use_sliding_windows = (getattr(self.config, 'sliding_window', None)
+                               is not None and self.config.use_sliding_window)
+        window_size = self.config.sliding_window
+        if not use_sliding_windows:
+            window_size = -1
+        paged_attention_fwd(
+            query_states,
+            key_states,
+            value_states,
+            past_key_value[0],
+            past_key_value[1],
+            attn_output,
+            block_offsets,
+            q_start_loc=q_start_loc,
+            q_seqlens=q_seq_length,
+            kv_seqlens=kv_seq_length,
+            max_seqlen=max_q_seq_length,
+            window_size=window_size,
+            context=context,
+        )
+
+        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
+                                          hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        """Rewrite of forward."""
+        world_size = 1
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+        return self._contiguous_batching_forward_impl(
+            hidden_states,
+            position_ids,
+            past_key_value,
+            world_size=world_size,
+        )
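
Reviewer note on the tensor-parallel layout: _load_weights splits the q/k/v projections column-wise and o_proj row-wise across ranks, so each rank produces a partial attention output that _distribute_output_fn sums with dist.all_reduce. The single-process sketch below checks that summing the row-parallel partials reproduces the unsharded result; it uses only plain PyTorch, and the single qkv stand-in layer and the manual weight chunking are illustrative assumptions, not the lmdeploy colwise_parallelize_linear / rowwise_parallelize_linear helpers.

# Minimal single-process sketch of the col-wise / row-wise split, assuming a
# single stand-in projection. Summing the partial outputs plays the role of
# dist.all_reduce across real tensor-parallel ranks.
import torch
import torch.nn as nn

world_size = 2
hidden = 16
x = torch.randn(3, hidden)

qkv = nn.Linear(hidden, hidden, bias=False)     # stands in for q/k/v_proj
o_proj = nn.Linear(hidden, hidden, bias=False)

# Reference: unsharded forward.
ref = o_proj(qkv(x))

# Column-parallel: each rank keeps a slice of the output features.
qkv_shards = qkv.weight.chunk(world_size, dim=0)
# Row-parallel: each rank keeps the matching slice of the input features.
o_shards = o_proj.weight.chunk(world_size, dim=1)

# Each rank computes a partial result; the sum emulates all_reduce.
partials = [
    nn.functional.linear(nn.functional.linear(x, w_qkv), w_o)
    for w_qkv, w_o in zip(qkv_shards, o_shards)
]
tp_out = torch.stack(partials).sum(dim=0)

assert torch.allclose(ref, tp_out, atol=1e-4)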
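The forward path also flattens every request in the batch onto a single token axis (view(-1, num_heads, head_dim)) and leaves the per-request bookkeeping to q_start_loc, q_seq_length, and block_offsets, which fill_kv_cache and paged_attention_fwd consume. The sketch below illustrates that bookkeeping with plain tensors; the cache shape and the write/gather loops are assumptions made for illustration only and do not mirror the actual lmdeploy kernels.

# Rough illustration of the continuous-batching / paged-cache metadata:
# queries from all requests are packed along one token axis, q_start_loc and
# q_seq_length slice that axis per request, and block_offsets maps each
# request's logical blocks to physical cache blocks. Shapes are assumed.
import torch

num_kv_heads, head_dim, block_size = 2, 4, 4

# Two requests with 3 and 5 new tokens, packed into one tensor of 8 tokens.
q_seq_length = torch.tensor([3, 5])
q_start_loc = torch.tensor([0, 3])   # cumulative offsets into the packed axis
packed_keys = torch.randn(int(q_seq_length.sum()), num_kv_heads, head_dim)

# Assumed physical cache layout: [num_blocks, block_size, heads, head_dim].
key_cache = torch.zeros(8, block_size, num_kv_heads, head_dim)

# Each row lists the physical blocks owned by one request, in logical order.
block_offsets = torch.tensor([[5, 2], [7, 1]])

# Write each request's new keys into its own blocks (prefill, empty history).
for req, (start, length) in enumerate(zip(q_start_loc, q_seq_length)):
    keys = packed_keys[start:start + length]
    for i, token_key in enumerate(keys):
        block = block_offsets[req, i // block_size]
        key_cache[block, i % block_size] = token_key

# Reading a request's keys back is a gather over its block table.
req0_blocks = key_cache[block_offsets[0]].reshape(-1, num_kv_heads, head_dim)
assert torch.allclose(req0_blocks[:3], packed_keys[0:3])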