12
12
13
13
14
14
@triton .heuristics ({
15
+ 'USE_G' : lambda args : args ['g' ] is not None ,
16
+ 'USE_GK' : lambda args : args ['gk' ] is not None ,
17
+ 'USE_GV' : lambda args : args ['gv' ] is not None ,
15
18
'USE_INITIAL_STATE' : lambda args : args ['h0' ] is not None ,
16
19
'STORE_FINAL_STATE' : lambda args : args ['ht' ] is not None ,
17
20
'IS_VARLEN' : lambda args : args ['cu_seqlens' ] is not None
@@ -22,6 +25,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
22
25
k ,
23
26
v ,
24
27
g ,
28
+ gk ,
29
+ gv ,
25
30
beta ,
26
31
o ,
27
32
h0 ,
@@ -36,13 +41,16 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
36
41
V : tl .constexpr ,
37
42
BK : tl .constexpr ,
38
43
BV : tl .constexpr ,
39
- USE_INITIAL_STATE : tl .constexpr , # whether to use initial state
40
- STORE_FINAL_STATE : tl .constexpr , # whether to store final state
41
- IS_BETA_HEADWISE : tl .constexpr , # whether beta is headwise vector or scalar ,
44
+ USE_G : tl .constexpr ,
45
+ USE_GK : tl .constexpr ,
46
+ USE_GV : tl .constexpr ,
42
47
USE_QK_L2NORM_IN_KERNEL : tl .constexpr ,
43
- IS_VARLEN : tl .constexpr
48
+ IS_BETA_HEADWISE : tl .constexpr ,
49
+ USE_INITIAL_STATE : tl .constexpr ,
50
+ STORE_FINAL_STATE : tl .constexpr ,
51
+ IS_VARLEN : tl .constexpr ,
44
52
):
45
- i_k , i_v , i_nh = tl .program_id (0 ), tl .program_id (1 ), tl .program_id (2 )
53
+ i_v , i_k , i_nh = tl .program_id (0 ), tl .program_id (1 ), tl .program_id (2 )
46
54
i_n , i_hv = i_nh // HV , i_nh % HV
47
55
i_h = i_hv // (HV // H )
48
56
if IS_VARLEN :
@@ -58,11 +66,16 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
58
66
p_q = q + (bos * H + i_h ) * K + o_k
59
67
p_k = k + (bos * H + i_h ) * K + o_k
60
68
p_v = v + (bos * HV + i_hv ) * V + o_v
69
+ if USE_G :
70
+ p_g = g + bos * HV + i_hv
71
+ if USE_GK :
72
+ p_gk = gk + (bos * HV + i_hv ) * K + o_k
73
+ if USE_GV :
74
+ p_gv = gv + (bos * HV + i_hv ) * V + o_v
61
75
if IS_BETA_HEADWISE :
62
76
p_beta = beta + (bos * HV + i_hv ) * V + o_v
63
77
else :
64
78
p_beta = beta + bos * HV + i_hv
65
- p_g = g + bos * HV + i_hv
66
79
p_o = o + ((i_k * all + bos ) * HV + i_hv ) * V + o_v
67
80
68
81
mask_k = o_k < K
@@ -78,14 +91,22 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
78
91
b_q = tl .load (p_q , mask = mask_k , other = 0 ).to (tl .float32 )
79
92
b_k = tl .load (p_k , mask = mask_k , other = 0 ).to (tl .float32 )
80
93
b_v = tl .load (p_v , mask = mask_v , other = 0 ).to (tl .float32 )
81
- b_g = tl .load (p_g ).to (tl .float32 )
82
94
83
95
if USE_QK_L2NORM_IN_KERNEL :
84
96
b_q = b_q / (tl .sqrt (tl .sum (b_q * b_q )) + 1e-6 )
85
97
b_k = b_k / (tl .sqrt (tl .sum (b_k * b_k )) + 1e-6 )
98
+
86
99
b_q = b_q * scale
87
100
# [BK, BV]
88
- b_h *= exp (b_g )
101
+ if USE_G :
102
+ b_g = tl .load (p_g ).to (tl .float32 )
103
+ b_h *= exp (b_g )
104
+ if USE_GK :
105
+ b_gk = tl .load (p_gk ).to (tl .float32 )
106
+ b_h *= b_gk [:, None ]
107
+ if USE_GV :
108
+ b_gv = tl .load (p_gv ).to (tl .float32 )
109
+ b_h *= b_gv [None , :]
89
110
# [BV]
90
111
b_v -= tl .sum (b_h * b_k [:, None ], 0 )
91
112
if IS_BETA_HEADWISE :
@@ -101,10 +122,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
101
122
102
123
p_q += H * K
103
124
p_k += H * K
104
- p_o += HV * V
105
125
p_v += HV * V
106
- p_g += HV
126
+ if USE_G :
127
+ p_g += HV
128
+ if USE_GK :
129
+ p_gk += HV * K
130
+ if USE_GV :
131
+ p_gv += HV * V
107
132
p_beta += HV * (V if IS_BETA_HEADWISE else 1 )
133
+ p_o += HV * V
108
134
109
135
if STORE_FINAL_STATE :
110
136
p_ht = ht + i_nh * K * V + o_k [:, None ] * V + o_v [None , :]
@@ -115,11 +141,13 @@ def fused_recurrent_gated_delta_rule_fwd(
115
141
q : torch .Tensor ,
116
142
k : torch .Tensor ,
117
143
v : torch .Tensor ,
118
- g : torch .Tensor ,
119
- beta : torch .Tensor ,
120
- scale : float ,
121
- initial_state : torch .Tensor ,
122
- output_final_state : bool ,
144
+ g : Optional [torch .Tensor ] = None ,
145
+ gk : Optional [torch .Tensor ] = None ,
146
+ gv : Optional [torch .Tensor ] = None ,
147
+ beta : Optional [torch .Tensor ] = None ,
148
+ scale : float = None ,
149
+ initial_state : torch .Tensor = None ,
150
+ output_final_state : bool = False ,
123
151
use_qk_l2norm_in_kernel : bool = False ,
124
152
cu_seqlens : Optional [torch .LongTensor ] = None ,
125
153
) -> Tuple [torch .Tensor , torch .Tensor ]:
@@ -138,12 +166,14 @@ def fused_recurrent_gated_delta_rule_fwd(
138
166
else :
139
167
final_state = None
140
168
141
- grid = (NK , NV , N * HV )
169
+ grid = (NV , NK , N * HV )
142
170
fused_recurrent_gated_delta_rule_fwd_kernel [grid ](
143
171
q = q ,
144
172
k = k ,
145
173
v = v ,
146
174
g = g ,
175
+ gk = gk ,
176
+ gv = gv ,
147
177
beta = beta ,
148
178
o = o ,
149
179
h0 = initial_state ,
@@ -176,25 +206,29 @@ def forward(
176
206
q : torch .Tensor ,
177
207
k : torch .Tensor ,
178
208
v : torch .Tensor ,
179
- g : torch .Tensor ,
180
- beta : torch .Tensor ,
181
- scale : float ,
182
- initial_state : torch .Tensor ,
183
- output_final_state : bool ,
209
+ g : Optional [torch .Tensor ] = None ,
210
+ gk : Optional [torch .Tensor ] = None ,
211
+ gv : Optional [torch .Tensor ] = None ,
212
+ beta : Optional [torch .Tensor ] = None ,
213
+ scale : float = None ,
214
+ initial_state : torch .Tensor = None ,
215
+ output_final_state : bool = False ,
216
+ use_qk_l2norm_in_kernel : bool = False ,
184
217
cu_seqlens : Optional [torch .LongTensor ] = None ,
185
- use_qk_l2norm_in_kernel : bool = False
186
218
):
187
219
o , final_state = fused_recurrent_gated_delta_rule_fwd (
188
220
q = q ,
189
221
k = k ,
190
222
v = v ,
191
223
g = g ,
224
+ gk = gk ,
225
+ gv = gv ,
192
226
beta = beta ,
193
227
scale = scale ,
194
228
initial_state = initial_state ,
195
229
output_final_state = output_final_state ,
196
230
use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel ,
197
- cu_seqlens = cu_seqlens
231
+ cu_seqlens = cu_seqlens ,
198
232
)
199
233
200
234
return o , final_state
@@ -213,13 +247,15 @@ def fused_recurrent_gated_delta_rule(
213
247
q : torch .Tensor ,
214
248
k : torch .Tensor ,
215
249
v : torch .Tensor ,
216
- g : torch .Tensor ,
217
- beta : torch .Tensor = None ,
250
+ g : Optional [torch .Tensor ] = None ,
251
+ gk : Optional [torch .Tensor ] = None ,
252
+ gv : Optional [torch .Tensor ] = None ,
253
+ beta : Optional [torch .Tensor ] = None ,
218
254
scale : float = None ,
219
255
initial_state : torch .Tensor = None ,
220
256
output_final_state : bool = False ,
221
- cu_seqlens : Optional [torch .LongTensor ] = None ,
222
257
use_qk_l2norm_in_kernel : bool = False ,
258
+ cu_seqlens : Optional [torch .LongTensor ] = None ,
223
259
) -> Tuple [torch .Tensor , torch .Tensor ]:
224
260
r"""
225
261
Args:
@@ -231,7 +267,11 @@ def fused_recurrent_gated_delta_rule(
231
267
values of shape `[B, T, HV, V]`.
232
268
GVA is applied if `HV > H`.
233
269
g (torch.Tensor):
234
- g (decays) of shape `[B, T, HV]`.
270
+ g (decays) of shape `[B, T, HV]`. Default: `None`.
271
+ gk (torch.Tensor):
272
+ gk (decays) of shape `[B, T, HV, K]`. Default: `None`.
273
+ gv (torch.Tensor):
274
+ gv (decays) of shape `[B, T, HV, V]`. Default: `None`.
235
275
beta (torch.Tensor):
236
276
betas of shape `[B, T, HV]`.
237
277
scale (Optional[float]):
@@ -243,6 +283,8 @@ def fused_recurrent_gated_delta_rule(
243
283
Default: `None`.
244
284
output_final_state (Optional[bool]):
245
285
Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`.
286
+ use_qk_l2norm_in_kernel (Optional[bool]):
287
+ Whether to use L2 normalization in the kernel. Default: `False`.
246
288
cu_seqlens (torch.LongTensor):
247
289
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
248
290
consistent with the FlashAttention API.
@@ -275,7 +317,7 @@ def fused_recurrent_gated_delta_rule(
275
317
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
276
318
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
277
319
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
278
- >>> o_var, ht_var = fused_gated_recurrent_delta_rule(
320
+ >>> o, ht = fused_gated_recurrent_delta_rule(
279
321
q, k, v, g, beta,
280
322
initial_state=h0,
281
323
output_final_state=True,
@@ -295,20 +337,21 @@ def fused_recurrent_gated_delta_rule(
295
337
)
296
338
if scale is None :
297
339
scale = k .shape [- 1 ] ** - 0.5
298
- else :
299
- assert scale > 0 , "scale must be positive"
300
340
if beta is None :
301
341
beta = torch .ones_like (q [..., 0 ])
342
+
302
343
o , final_state = FusedRecurrentFunction .apply (
303
344
q ,
304
345
k ,
305
346
v ,
306
347
g ,
348
+ gk ,
349
+ gv ,
307
350
beta ,
308
351
scale ,
309
352
initial_state ,
310
353
output_final_state ,
354
+ use_qk_l2norm_in_kernel ,
311
355
cu_seqlens ,
312
- use_qk_l2norm_in_kernel
313
356
)
314
357
return o , final_state
0 commit comments