@@ -296,23 +296,31 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         query_states, key_states, value_states = (
             self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
         )
-
-        cos, sin = position_embeddings
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-
-        attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -323,13 +331,17 @@ def forward(
             query_states,
             key_states,
             value_states,
-            attention_mask,
+            attention_mask=attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
-            is_causal=self.is_causal,
+            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
+            cu_seq_lens_k=cu_seqlens,
+            max_length_q=max_seqlen,
+            max_length_k=max_seqlen,
+            is_causal=False,
             **kwargs,
         )
-        attn_output = attn_output.squeeze(0)
+
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
         return attn_output
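
The removed mask-building loop and the newly forwarded cu_seqlens/max_seqlen arguments encode the same constraint: tokens may only attend within their own packed segment. Below is a minimal standalone sketch of that equivalence; the helper name build_block_diagonal_mask and the example segment lengths are illustrative, not part of the diff.

import torch

def build_block_diagonal_mask(cu_seqlens: torch.Tensor, seq_length: int) -> torch.Tensor:
    # Old eager/SDPA path: only positions inside the same packed segment may attend to each other.
    mask = torch.zeros([1, 1, seq_length, seq_length], dtype=torch.bool)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
    return mask

# Three segments of lengths 2, 3, and 1 packed into one sequence of length 6.
cu_seqlens = torch.tensor([0, 2, 5, 6])
seq_length = int(cu_seqlens[-1])
mask = build_block_diagonal_mask(cu_seqlens, seq_length)

# New flash-attention path: no dense mask, just the cumulative lengths and the
# longest segment, as the diff now forwards them to the attention interface.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
print(mask[0, 0].int())  # 6x6 block-diagonal pattern
print(max_seqlen)        # 3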