 from einops import rearrange, repeat
 from transformers.utils import logging
 
-from fla.layers.utils import pad_input, unpad_input
+from fla.layers.utils import get_unpad_data, pad_input, unpad_input
 from fla.modules import RMSNorm
 from fla.ops.deltaformer import delta_pre_attn
 
@@ -121,27 +121,87 @@ def forward(
         if self.qk_norm:
             q, k = self.q_norm(q), self.k_norm(k)
 
-        if attention_mask is not None:
-            # Use varlen FlashAttention path. Pre-attention currently supports fixed length only → fallback by padding.
-            q_full = q
-            k_full = k
-            v_full = v
-            beta_full = beta
+        cache_has_content = past_key_values is not None and past_key_values.get_seq_length(self.layer_idx) > 0
+
+        if not cache_has_content or q_len > 1:
+            # Prefill: compute U for current block
+            if attention_mask is not None:
+                _, cu_seqlens_k, _ = get_unpad_data(attention_mask)
+                u = delta_pre_attn(
+                    rearrange(q, 'b t h d -> b h t d'),
+                    rearrange(k, 'b t h d -> b h t d'),
+                    rearrange(v, 'b t h d -> b h t d'),
+                    beta,
+                    cu_seqlens=cu_seqlens_k,
+                )
+            else:
+                u = delta_pre_attn(
+                    rearrange(q, 'b t h d -> b h t d'),
+                    rearrange(k, 'b t h d -> b h t d'),
+                    rearrange(v, 'b t h d -> b h t d'),
+                    beta,
+                )
+            u = rearrange(u, 'b h t d -> b t h d')
+
+            k_eff, u_eff = k, u
+            if use_cache and past_key_values is not None:
+                k_flat = k.flatten(-2, -1)
+                u_flat = u.flatten(-2, -1)
+                k_cached_flat, u_cached_flat = past_key_values.update(
+                    attn_state=(k_flat, u_flat),
+                    layer_idx=self.layer_idx,
+                    offset=q_len,
+                )['attn_state']
+                if cache_has_content:
+                    k_eff = rearrange(k_cached_flat, 'b t (h d) -> b t h d', h=self.num_kv_heads * self.num_kv_groups)
+                    u_eff = rearrange(u_cached_flat, 'b t (h d) -> b t h d', h=self.num_heads)
         else:
-            q_full, k_full, v_full, beta_full = q, k, v, beta
-
-        # Compute u via DeltaFormer pre-attention (fixed-length kernel).
-        u = delta_pre_attn(
-            rearrange(q_full, 'b t h d -> b h t d'),
-            rearrange(k_full, 'b t h d -> b h t d'),
-            rearrange(v_full, 'b t h d -> b h t d'),
-            beta_full,
-        )
-        u = rearrange(u, 'b h t d -> b t h d')
-
-        # Second stage: standard FlashAttention but using u as values
+            state = past_key_values[self.layer_idx]
+            k_cached_flat, u_cached_flat = state['attn_state']
+            T_prev = k_cached_flat.shape[1]
+            k_prev = rearrange(k_cached_flat, 'b t (h d) -> b t h d', h=self.num_kv_heads * self.num_kv_groups)
+            u_prev = rearrange(u_cached_flat, 'b t (h d) -> b t h d', h=self.num_heads)
+
+            if attention_mask is not None:
+                attn_mask_prev = attention_mask[:, :T_prev]
+                q_padded, (k_padded_prev, u_padded_prev), indices_q, cu_seqlens, max_seq_lens = unpad_input(
+                    q,
+                    (k_prev, u_prev),
+                    attn_mask_prev,
+                    q_len,
+                )
+                cu_seqlens_q, cu_seqlens_k = cu_seqlens
+                max_seqlen_q, max_seqlen_k = max_seq_lens
+                s = flash_attn_varlen_func(
+                    q_padded, k_padded_prev, u_padded_prev,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_k=max_seqlen_k,
+                    causal=False,
+                    window_size=(-1, -1)
+                )
+                s = pad_input(s, indices_q, batch_size, q_len)
+            else:
+                s = flash_attn_func(q, k_prev, u_prev, causal=False, window_size=(-1, -1))
+
+            u_cur = v - rearrange(beta, 'b h t -> b t h 1') * s
+            k_eff = torch.cat([k_prev, k], dim=1)
+            u_eff = torch.cat([u_prev, u_cur], dim=1)
+
+            past_key_values.update(
+                attn_state=(k_eff.flatten(-2, -1), u_eff.flatten(-2, -1)),
+                layer_idx=self.layer_idx,
+                offset=q_len,
+            )
+
         if attention_mask is not None:
-            q_padded, (k_padded, u_padded), indices_q, cu_seqlens, max_seq_lens = unpad_input(q, (k, u), attention_mask, q_len)
+            q_padded, (k_padded, u_padded), indices_q, cu_seqlens, max_seq_lens = unpad_input(
+                q,
+                (k_eff, u_eff),
+                attention_mask,
+                q_len,
+            )
             cu_seqlens_q, cu_seqlens_k = cu_seqlens
             max_seqlen_q, max_seqlen_k = max_seq_lens
             o = flash_attn_varlen_func(
@@ -155,7 +215,7 @@ def forward(
             )
             o = pad_input(o, indices_q, batch_size, q_len)
         else:
-            o = flash_attn_func(q, k, u, causal=True, window_size=(-1, -1))
+            o = flash_attn_func(q, k_eff, u_eff, causal=True, window_size=(-1, -1))
 
         o = o.reshape(batch_size, q_len, -1)
         o = self.o_proj(o)
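
For context on what the new branches compute, here is a minimal single-head PyTorch sketch of the DeltaFormer pre-attention recurrence and of the cached one-token decode step. It assumes `delta_pre_attn` realizes u_t = v_t − β_t · softmax(q_t K_{<t}^T / √d) U_{<t} with the standard 1/√d softmax scale (the default `flash_attn_func` applies in the decode branch above). The names `delta_pre_attn_reference` and `decode_step` are illustrative only, not part of this PR or the fla API.

```python
# A sketch only: single head, no batching, eager PyTorch in place of the fused kernels.
import torch


def delta_pre_attn_reference(q, k, v, beta):
    # q, k, v: [T, d]; beta: [T].  u_t = v_t - beta_t * softmax(q_t K_{<t}^T / sqrt(d)) @ U_{<t}
    T, d = q.shape
    u = v.clone()                                                  # u_0 = v_0 (empty prefix)
    for t in range(1, T):
        attn = torch.softmax(q[t] @ k[:t].T / d ** 0.5, dim=-1)    # weights over the strict prefix
        u[t] = v[t] - beta[t] * (attn @ u[:t])
    return u


def decode_step(q_t, k_t, v_t, beta_t, k_prev, u_prev):
    # One cached decode step: recover u_t from (K_prev, U_prev) with non-causal attention
    # over the cached prefix, then append -- mirroring the `else` branch in the diff.
    d = q_t.shape[-1]
    attn = torch.softmax(q_t @ k_prev.T / d ** 0.5, dim=-1)
    u_t = v_t - beta_t * (attn @ u_prev)
    return torch.cat([k_prev, k_t[None]]), torch.cat([u_prev, u_t[None]])


if __name__ == "__main__":
    # The step-wise cached path should reproduce the full-sequence result.
    T, d = 8, 16
    q, k, v = (torch.randn(T, d) for _ in range(3))
    beta = torch.rand(T)
    u_full = delta_pre_attn_reference(q, k, v, beta)

    k_c, u_c = k[:1], v[:1]                                        # cache after the first token
    for t in range(1, T):
        k_c, u_c = decode_step(q[t], k[t], v[t], beta[t], k_c, u_c)
    torch.testing.assert_close(u_c, u_full)
```

The consistency check at the bottom is the property the prefill/decode split relies on: running the recurrence token by token against the cached (K, U) state yields the same U as the full-sequence kernel.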