Commit c17dfec

added varlen and generation support
1 parent c5fceb2 commit c17dfec

3 files changed (+348, -150 lines)

fla/layers/deltaformer.py

Lines changed: 81 additions & 21 deletions
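Note on the varlen path: the diff below imports get_unpad_data and feeds the resulting cu_seqlens into delta_pre_attn during prefill. The real helper lives in fla.layers.utils; the sketch here only illustrates what such a helper is assumed to derive from a 0/1 padding mask (flat token indices, cumulative sequence lengths, max length), and is not the library implementation.

import torch
import torch.nn.functional as F

def get_unpad_data_sketch(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of the real tokens once padding is removed.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Cumulative sequence lengths in the packed layout: [0, len_0, len_0 + len_1, ...].
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    return indices, cu_seqlens, max_seqlen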
@@ -10,7 +10,7 @@
 from einops import rearrange, repeat
 from transformers.utils import logging
 
-from fla.layers.utils import pad_input, unpad_input
+from fla.layers.utils import get_unpad_data, pad_input, unpad_input
 from fla.modules import RMSNorm
 from fla.ops.deltaformer import delta_pre_attn
 
@@ -121,27 +121,87 @@ def forward(
         if self.qk_norm:
             q, k = self.q_norm(q), self.k_norm(k)
 
-        if attention_mask is not None:
-            # Use varlen FlashAttention path. Pre-attention currently supports fixed length only → fallback by padding.
-            q_full = q
-            k_full = k
-            v_full = v
-            beta_full = beta
+        cache_has_content = past_key_values is not None and past_key_values.get_seq_length(self.layer_idx) > 0
+
+        if not cache_has_content or q_len > 1:
+            # Prefill: compute U for current block
+            if attention_mask is not None:
+                _, cu_seqlens_k, _ = get_unpad_data(attention_mask)
+                u = delta_pre_attn(
+                    rearrange(q, 'b t h d -> b h t d'),
+                    rearrange(k, 'b t h d -> b h t d'),
+                    rearrange(v, 'b t h d -> b h t d'),
+                    beta,
+                    cu_seqlens=cu_seqlens_k,
+                )
+            else:
+                u = delta_pre_attn(
+                    rearrange(q, 'b t h d -> b h t d'),
+                    rearrange(k, 'b t h d -> b h t d'),
+                    rearrange(v, 'b t h d -> b h t d'),
+                    beta,
+                )
+            u = rearrange(u, 'b h t d -> b t h d')
+
+            k_eff, u_eff = k, u
+            if use_cache and past_key_values is not None:
+                k_flat = k.flatten(-2, -1)
+                u_flat = u.flatten(-2, -1)
+                k_cached_flat, u_cached_flat = past_key_values.update(
+                    attn_state=(k_flat, u_flat),
+                    layer_idx=self.layer_idx,
+                    offset=q_len,
+                )['attn_state']
+                if cache_has_content:
+                    k_eff = rearrange(k_cached_flat, 'b t (h d) -> b t h d', h=self.num_kv_heads * self.num_kv_groups)
+                    u_eff = rearrange(u_cached_flat, 'b t (h d) -> b t h d', h=self.num_heads)
         else:
-            q_full, k_full, v_full, beta_full = q, k, v, beta
-
-        # Compute u via DeltaFormer pre-attention (fixed-length kernel).
-        u = delta_pre_attn(
-            rearrange(q_full, 'b t h d -> b h t d'),
-            rearrange(k_full, 'b t h d -> b h t d'),
-            rearrange(v_full, 'b t h d -> b h t d'),
-            beta_full,
-        )
-        u = rearrange(u, 'b h t d -> b t h d')
-
-        # Second stage: standard FlashAttention but using u as values
+            state = past_key_values[self.layer_idx]
+            k_cached_flat, u_cached_flat = state['attn_state']
+            T_prev = k_cached_flat.shape[1]
+            k_prev = rearrange(k_cached_flat, 'b t (h d) -> b t h d', h=self.num_kv_heads * self.num_kv_groups)
+            u_prev = rearrange(u_cached_flat, 'b t (h d) -> b t h d', h=self.num_heads)
+
+            if attention_mask is not None:
+                attn_mask_prev = attention_mask[:, :T_prev]
+                q_padded, (k_padded_prev, u_padded_prev), indices_q, cu_seqlens, max_seq_lens = unpad_input(
+                    q,
+                    (k_prev, u_prev),
+                    attn_mask_prev,
+                    q_len,
+                )
+                cu_seqlens_q, cu_seqlens_k = cu_seqlens
+                max_seqlen_q, max_seqlen_k = max_seq_lens
+                s = flash_attn_varlen_func(
+                    q_padded, k_padded_prev, u_padded_prev,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_k=max_seqlen_k,
+                    causal=False,
+                    window_size=(-1, -1)
+                )
+                s = pad_input(s, indices_q, batch_size, q_len)
+            else:
+                s = flash_attn_func(q, k_prev, u_prev, causal=False, window_size=(-1, -1))
+
+            u_cur = v - rearrange(beta, 'b h t -> b t h 1') * s
+            k_eff = torch.cat([k_prev, k], dim=1)
+            u_eff = torch.cat([u_prev, u_cur], dim=1)
+
+            past_key_values.update(
+                attn_state=(k_eff.flatten(-2, -1), u_eff.flatten(-2, -1)),
+                layer_idx=self.layer_idx,
+                offset=q_len,
+            )
+
         if attention_mask is not None:
-            q_padded, (k_padded, u_padded), indices_q, cu_seqlens, max_seq_lens = unpad_input(q, (k, u), attention_mask, q_len)
+            q_padded, (k_padded, u_padded), indices_q, cu_seqlens, max_seq_lens = unpad_input(
+                q,
+                (k_eff, u_eff),
+                attention_mask,
+                q_len,
+            )
             cu_seqlens_q, cu_seqlens_k = cu_seqlens
             max_seqlen_q, max_seqlen_k = max_seq_lens
             o = flash_attn_varlen_func(
@@ -155,7 +215,7 @@ def forward(
             )
             o = pad_input(o, indices_q, batch_size, q_len)
         else:
-            o = flash_attn_func(q, k, u, causal=True, window_size=(-1, -1))
+            o = flash_attn_func(q, k_eff, u_eff, causal=True, window_size=(-1, -1))
 
         o = o.reshape(batch_size, q_len, -1)
         o = self.o_proj(o)
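For reference, the single-token generation branch added above amounts to applying the delta rule incrementally: attend non-causally with the new query over the cached keys, read the cached u as values, then subtract the beta-weighted read-out from the current value. A minimal dense sketch of that one step, under assumed shapes and ignoring GQA head repetition and padding:

import torch

def decode_u_sketch(q, k_prev, u_prev, v, beta):
    # Assumed shapes: q, v: (B, 1, H, D); k_prev, u_prev: (B, T, H, D); beta: (B, H, 1).
    scale = q.shape[-1] ** -0.5
    # Non-causal softmax attention of the new query over cached keys, with cached u as values,
    # mirroring s = flash_attn_func(q, k_prev, u_prev, causal=False) in the diff.
    attn = torch.softmax(torch.einsum('bqhd,bthd->bhqt', q * scale, k_prev), dim=-1)
    s = torch.einsum('bhqt,bthd->bqhd', attn, u_prev)
    # u_cur = v - beta * s, matching rearrange(beta, 'b h t -> b t h 1') with t = 1.
    return v - beta.transpose(1, 2).unsqueeze(-1) * s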

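The generation path only stays numerically consistent with prefill if delta_pre_attn computes the same recurrence causally over the block. As a debugging aid, here is a slow dense reference for what the kernel is assumed to compute, inferred from the decode branch rather than from the Triton kernel itself; the scaled softmax read-out over strictly earlier positions is an assumption.

import torch

def delta_pre_attn_dense_sketch(q, k, v, beta):
    # Assumed shapes, matching the call sites above: q, k, v: (B, H, T, D); beta: (B, H, T).
    B, H, T, D = q.shape
    scale = D ** -0.5
    u = torch.empty_like(v)
    u[:, :, 0] = v[:, :, 0]  # the first position has no prefix to correct against
    for t in range(1, T):
        # Read the already-computed u_{<t} through softmax attention over k_{<t} ...
        attn = torch.softmax((q[:, :, t:t + 1] * scale) @ k[:, :, :t].transpose(-1, -2), dim=-1)  # (B, H, 1, t)
        s = (attn @ u[:, :, :t]).squeeze(2)                                                       # (B, H, D)
        # ... and apply the beta-weighted delta correction to the current value.
        u[:, :, t] = v[:, :, t] - beta[:, :, t].unsqueeze(-1) * s
    return u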