
Commit cd7e7b3

[GDN] Support fused_recurrent w/ finegrained decay
1 parent 8384557

File tree

1 file changed: 74 additions, 31 deletions

fla/ops/gated_delta_rule/fused_recurrent.py
@@ -12,6 +12,9 @@


 @triton.heuristics({
+    'USE_G': lambda args: args['g'] is not None,
+    'USE_GK': lambda args: args['gk'] is not None,
+    'USE_GV': lambda args: args['gv'] is not None,
     'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
     'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
@@ -22,6 +25,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
     k,
     v,
     g,
+    gk,
+    gv,
     beta,
     o,
     h0,
@@ -36,13 +41,16 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
     V: tl.constexpr,
     BK: tl.constexpr,
     BV: tl.constexpr,
-    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
-    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
-    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
+    USE_G: tl.constexpr,
+    USE_GK: tl.constexpr,
+    USE_GV: tl.constexpr,
     USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
-    IS_VARLEN: tl.constexpr
+    IS_BETA_HEADWISE: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
 ):
-    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
     i_n, i_hv = i_nh // HV, i_nh % HV
     i_h = i_hv // (HV // H)
     if IS_VARLEN:
@@ -58,11 +66,16 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
     p_q = q + (bos * H + i_h) * K + o_k
     p_k = k + (bos * H + i_h) * K + o_k
     p_v = v + (bos * HV + i_hv) * V + o_v
+    if USE_G:
+        p_g = g + bos * HV + i_hv
+    if USE_GK:
+        p_gk = gk + (bos * HV + i_hv) * K + o_k
+    if USE_GV:
+        p_gv = gv + (bos * HV + i_hv) * V + o_v
     if IS_BETA_HEADWISE:
         p_beta = beta + (bos * HV + i_hv) * V + o_v
     else:
         p_beta = beta + bos * HV + i_hv
-    p_g = g + bos * HV + i_hv
     p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

     mask_k = o_k < K
@@ -78,14 +91,22 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
         b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
         b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
         b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
-        b_g = tl.load(p_g).to(tl.float32)

         if USE_QK_L2NORM_IN_KERNEL:
             b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
             b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+
         b_q = b_q * scale
         # [BK, BV]
-        b_h *= exp(b_g)
+        if USE_G:
+            b_g = tl.load(p_g).to(tl.float32)
+            b_h *= exp(b_g)
+        if USE_GK:
+            b_gk = tl.load(p_gk).to(tl.float32)
+            b_h *= b_gk[:, None]
+        if USE_GV:
+            b_gv = tl.load(p_gv).to(tl.float32)
+            b_h *= b_gv[None, :]
         # [BV]
         b_v -= tl.sum(b_h * b_k[:, None], 0)
         if IS_BETA_HEADWISE:
@@ -101,10 +122,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(

         p_q += H*K
         p_k += H*K
-        p_o += HV*V
         p_v += HV*V
-        p_g += HV
+        if USE_G:
+            p_g += HV
+        if USE_GK:
+            p_gk += HV*K
+        if USE_GV:
+            p_gv += HV*V
         p_beta += HV * (V if IS_BETA_HEADWISE else 1)
+        p_o += HV*V

     if STORE_FINAL_STATE:
         p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
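
Taken together, the kernel hunks above let the state be gated by up to three decays per step: a scalar per-head `g` (exponentiated in-kernel), a per-key-dimension `gk`, and a per-value-dimension `gv`, all applied before the usual delta-rule update. As a reading aid, here is a minimal single-head PyTorch sketch of that recurrence; the name `step_ref` is mine, `gk_t`/`gv_t` are assumed to arrive as raw decay factors exactly as the kernel consumes them, and the readout mirrors the unchanged tail of the kernel:

    import torch

    def step_ref(S, q_t, k_t, v_t, beta_t, scale,
                 g_t=None, gk_t=None, gv_t=None):
        # S: [K, V] state; q_t, k_t: [K]; v_t: [V]; beta_t, g_t: scalars;
        # gk_t: [K] and gv_t: [V] carry the fine-grained decay factors.
        if g_t is not None:
            S = S * torch.exp(g_t)       # scalar per-head decay, exp'd as in the kernel
        if gk_t is not None:
            S = S * gk_t[:, None]        # decay each key dimension separately
        if gv_t is not None:
            S = S * gv_t[None, :]        # decay each value dimension separately
        u_t = v_t - (S * k_t[:, None]).sum(0)          # delta-rule residual
        S = S + beta_t * k_t[:, None] * u_t[None, :]   # rank-1 state update
        o_t = (S * (scale * q_t)[:, None]).sum(0)      # readout: S^T (scale * q_t)
        return o_t, S

This also explains why the pointer-advance hunk only bumps `p_g`, `p_gk`, and `p_gv` when the corresponding flag is set: an absent decay simply skips its gating step.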
@@ -115,11 +141,13 @@ def fused_recurrent_gated_delta_rule_fwd(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    g: torch.Tensor,
-    beta: torch.Tensor,
-    scale: float,
-    initial_state: torch.Tensor,
-    output_final_state: bool,
+    g: Optional[torch.Tensor] = None,
+    gk: Optional[torch.Tensor] = None,
+    gv: Optional[torch.Tensor] = None,
+    beta: Optional[torch.Tensor] = None,
+    scale: float = None,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -138,12 +166,14 @@ def fused_recurrent_gated_delta_rule_fwd(
     else:
         final_state = None

-    grid = (NK, NV, N * HV)
+    grid = (NV, NK, N * HV)
     fused_recurrent_gated_delta_rule_fwd_kernel[grid](
         q=q,
         k=k,
         v=v,
         g=g,
+        gk=gk,
+        gv=gv,
         beta=beta,
         o=o,
         h0=initial_state,
@@ -176,25 +206,29 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        g: torch.Tensor,
-        beta: torch.Tensor,
-        scale: float,
-        initial_state: torch.Tensor,
-        output_final_state: bool,
+        g: Optional[torch.Tensor] = None,
+        gk: Optional[torch.Tensor] = None,
+        gv: Optional[torch.Tensor] = None,
+        beta: Optional[torch.Tensor] = None,
+        scale: float = None,
+        initial_state: torch.Tensor = None,
+        output_final_state: bool = False,
+        use_qk_l2norm_in_kernel: bool = False,
         cu_seqlens: Optional[torch.LongTensor] = None,
-        use_qk_l2norm_in_kernel: bool = False
     ):
         o, final_state = fused_recurrent_gated_delta_rule_fwd(
             q=q,
             k=k,
             v=v,
             g=g,
+            gk=gk,
+            gv=gv,
             beta=beta,
             scale=scale,
             initial_state=initial_state,
             output_final_state=output_final_state,
             use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
-            cu_seqlens=cu_seqlens
+            cu_seqlens=cu_seqlens,
         )

         return o, final_state
@@ -213,13 +247,15 @@ def fused_recurrent_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    g: torch.Tensor,
-    beta: torch.Tensor = None,
+    g: Optional[torch.Tensor] = None,
+    gk: Optional[torch.Tensor] = None,
+    gv: Optional[torch.Tensor] = None,
+    beta: Optional[torch.Tensor] = None,
     scale: float = None,
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
-    cu_seqlens: Optional[torch.LongTensor] = None,
     use_qk_l2norm_in_kernel: bool = False,
+    cu_seqlens: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
@@ -231,7 +267,11 @@ def fused_recurrent_gated_delta_rule(
             values of shape `[B, T, HV, V]`.
             GVA is applied if `HV > H`.
         g (torch.Tensor):
-            g (decays) of shape `[B, T, HV]`.
+            g (decays) of shape `[B, T, HV]`. Default: `None`.
+        gk (torch.Tensor):
+            gk (decays) of shape `[B, T, HV, K]`. Default: `None`.
+        gv (torch.Tensor):
+            gv (decays) of shape `[B, T, HV, V]`. Default: `None`.
         beta (torch.Tensor):
             betas of shape `[B, T, HV]`.
         scale (Optional[float]):
@@ -243,6 +283,8 @@ def fused_recurrent_gated_delta_rule(
             Default: `None`.
         output_final_state (Optional[bool]):
             Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`.
+        use_qk_l2norm_in_kernel (Optional[bool]):
+            Whether to use L2 normalization in the kernel. Default: `False`.
         cu_seqlens (torch.LongTensor):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
@@ -275,7 +317,7 @@ def fused_recurrent_gated_delta_rule(
         >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
         # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
         >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
-        >>> o_var, ht_var = fused_gated_recurrent_delta_rule(
+        >>> o, ht = fused_gated_recurrent_delta_rule(
             q, k, v, g, beta,
             initial_state=h0,
             output_final_state=True,
@@ -295,20 +337,21 @@ def fused_recurrent_gated_delta_rule(
         )
     if scale is None:
         scale = k.shape[-1] ** -0.5
-    else:
-        assert scale > 0, "scale must be positive"
    if beta is None:
        beta = torch.ones_like(q[..., 0])
+
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
+        gk,
+        gv,
        beta,
        scale,
        initial_state,
        output_final_state,
+        use_qk_l2norm_in_kernel,
        cu_seqlens,
-        use_qk_l2norm_in_kernel
    )
    return o, final_state
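
To exercise the new fine-grained path end to end, a hedged usage sketch: the import path follows the patched file, tensor shapes follow the updated docstring, and the concrete sizes, CUDA device (required by the Triton kernel), and `torch.rand` decay parameterization are illustrative assumptions only:

    import torch
    from fla.ops.gated_delta_rule.fused_recurrent import fused_recurrent_gated_delta_rule

    B, T, H, HV, K, V = 2, 64, 2, 4, 128, 128    # HV > H exercises GVA
    q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(B, T, HV, V, device='cuda', dtype=torch.bfloat16)
    # per-key decay factors in (0, 1); a gv of shape [B, T, HV, V] works the same way
    gk = torch.rand(B, T, HV, K, device='cuda', dtype=torch.bfloat16)
    beta = torch.rand(B, T, HV, device='cuda', dtype=torch.bfloat16)

    o, ht = fused_recurrent_gated_delta_rule(
        q, k, v,
        gk=gk,                    # new in this commit, alongside g and gv
        beta=beta,
        output_final_state=True,
    )
    assert o.shape == (B, T, HV, V)
    assert ht.shape == (B, HV, K, V)

Since `g`, `gk`, and `gv` are now all optional, callers can mix and match decay granularities per model.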
