@@ -113,51 +113,20 @@ def flash_bwd_prep(
     return flash_bwd_prep
 
 
-def make_dq_layout(dQ):
-    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
-    return T.Layout(dQ.shape,
-                    lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
-
-
-@tilelang.jit(
-    out_idx=[1], pass_configs={
-        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
-def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
-    dtype = "float16"
-    accum_dtype = "float"
-    shape = [batch, seq_len, heads, dim_qk]
-    blk = 64
-
-    @T.prim_func
-    def flash_bwd_post(
-            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
-            dQ_out: T.Tensor(shape, dtype),  # type: ignore
-    ):
-        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
-            T.annotate_layout({dQ: make_dq_layout(dQ)})
-            T.copy(
-                dQ[bz, bx * blk:(bx + 1) * blk, by, :],
-                dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
-            )
-
-    return flash_bwd_post
-
-
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_atomic_add(batch,
-                             heads,
-                             seq_len,
-                             dim_qk,
-                             dim_v,
-                             is_causal,
-                             block_M,
-                             block_N,
-                             threads=256,
-                             num_stages=2,
-                             groups=1):
+def flashattn_bwd(batch,
+                  heads,
+                  seq_len,
+                  dim_qk,
+                  dim_v,
+                  is_causal,
+                  block_M,
+                  block_N,
+                  threads=256,
+                  num_stages=2,
+                  groups=1):
     sm_scale = (1.0 / dim_qk)**0.5
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
@@ -196,10 +165,13 @@ def flash_bwd(
             dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
             dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
             dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
+            dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
 
             T.annotate_layout({
-                dQ: make_dq_layout(dQ),
                 K_shared: tilelang.layout.make_swizzled_layout(K_shared),
+                dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
+                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
+                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
             })
 
             T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
@@ -244,129 +216,12 @@ def flash_bwd(
                 T.clear(dq)
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
                 T.wait_wgmma(0)
-                for i, j in T.Parallel(block_N, dim_qk):
-                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                T.copy(dq, dq_shared)
+                T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared)
             T.copy(dv, dv_shared)
             T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
             T.copy(dk, dk_shared)
-            for i, j in T.Parallel(block_M, dim_qk):
-                T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j])
-
-    return flash_bwd
-
-
-@tilelang.jit(pass_configs={
-    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
-def flashattn_bwd_split(batch,
-                        heads,
-                        seq_len,
-                        dim_qk,
-                        dim_v,
-                        is_causal,
-                        block_M,
-                        block_N,
-                        threads=256,
-                        num_stages=2,
-                        groups=1):
-    sm_scale = (1.0 / dim_qk)**0.5
-    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
-    head_kv = heads // groups
-    q_shape = [batch, seq_len, heads, dim_qk]
-    k_shape = [batch, seq_len, head_kv, dim_qk]
-    v_shape = [batch, seq_len, head_kv, dim_v]
-    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
-    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
-    dtype = "float16"
-    accum_dtype = "float"
-
-    @T.prim_func
-    def flash_bwd(
-            Q: T.Tensor(q_shape, dtype),  # type: ignore
-            K: T.Tensor(k_shape, dtype),  # type: ignore
-            V: T.Tensor(v_shape, dtype),  # type: ignore
-            dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
-            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
-            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
-            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
-            dK: T.Tensor(dk_shape, dtype),  # type: ignore
-            dV: T.Tensor(dv_shape, dtype),  # type: ignore
-    ):
-        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
-            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
-            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
-            q = T.alloc_shared([block_N, dim_qk], dtype)
-            V_shared = T.alloc_shared([block_M, dim_v], dtype)
-            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
-            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
-            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
-            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
-            lse_shared = T.alloc_shared([block_N], accum_dtype)
-            delta = T.alloc_shared([block_N], accum_dtype)
-            do = T.alloc_shared([block_N, dim_v], dtype)
-            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
-            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
-            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
-            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
-            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
-
-            T.annotate_layout({
-                dQ: make_dq_layout(dQ),
-                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
-                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
-                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
-            })
-
-            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
-            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
-            T.clear(dv)
-            T.clear(dk)
-            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
-            loop_ed = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
-                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
-                T.clear(qkT)
-                T.gemm(
-                    K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
-                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
-                T.clear(dsT)
-                T.gemm(
-                    V_shared,
-                    do,
-                    dsT,
-                    transpose_B=True,
-                    policy=T.GemmWarpPolicy.FullRow,
-                    wg_wait=-1)
-                T.wait_wgmma(1)
-
-                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
-                for i, j in T.Parallel(block_M, block_N):
-                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
-                if is_causal:
-                    for i, j in T.Parallel(block_M, block_N):
-                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
-                                                   0)
-                T.wait_wgmma(0)
-                T.copy(qkT, qkT_cast)
-                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
-
-                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
-
-                for i, j in T.Parallel(block_M, block_N):
-                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
-                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)
-
-                T.copy(dsT_cast, dsT_shared)
-                T.clear(dq)
-                T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
-                T.wait_wgmma(0)
-                for i, j in T.Parallel(block_N, dim_qk):
-                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
-
-            T.copy(dv, dv_shared)
-            T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
-            T.copy(dk, dk_shared)
-            T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
+            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)
 
     return flash_bwd
 
@@ -403,54 +258,30 @@ def maybe_contiguous(x):
         block_M = 128
         block_N = 32
         mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
-        mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
         delta = mod_prep(o, do)
 
-        if ctx.use_atomic:
-            kernel = flashattn_bwd_atomic_add(
-                BATCH,
-                H,
-                N_CTX,
-                D_HEAD_QK,
-                D_HEAD_V,
-                ctx.causal,
-                block_M,
-                block_N,
-                threads=256,
-                num_stages=2,
-                groups=groups)
-            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
-            shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
-            shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
-            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
-            dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
-            dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
-            kernel(q, k, v, do, lse, delta, dq, dk, dv)
-            dq = mod_post(dq)
-            dk = dk.to(torch.float16)
-            dv = dv.to(torch.float16)
-        else:
-            kernel = flashattn_bwd_split(
-                BATCH,
-                H,
-                N_CTX,
-                D_HEAD_QK,
-                D_HEAD_V,
-                ctx.causal,
-                block_M,
-                block_N,
-                threads=256,
-                num_stages=2,
-                groups=groups)
-            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
-            shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
-            shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
-            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
-            dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
-            dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
-            kernel(q, k, v, do, lse, delta, dq, dk, dv)
-            dq = mod_post(dq)
-            dk, dv = dk.sum(0), dv.sum(0)
+        kernel = flashattn_bwd(
+            BATCH,
+            H,
+            N_CTX,
+            D_HEAD_QK,
+            D_HEAD_V,
+            ctx.causal,
+            block_M,
+            block_N,
+            threads=256,
+            num_stages=2,
+            groups=groups)
+        shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
+        shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
+        shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
+        dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
+        dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
+        dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
+        kernel(q, k, v, do, lse, delta, dq, dk, dv)
+        dq = dq.to(torch.float16)
+        dk = dk.to(torch.float16)
+        dv = dv.to(torch.float16)
 
         return dq, dk, dv, None, None, None
 
@@ -489,8 +320,7 @@ def main(BATCH: int = 1,
          D_HEAD_QK: int = 192,
          D_HEAD_V: int = 128,
          groups: int = 16,
-         causal: bool = False,
-         use_atomic: bool = True):
+         causal: bool = False):
     flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
     flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
     total_flops = 3 * flops_per_qk + 2 * flops_per_v
@@ -510,7 +340,7 @@ def main(BATCH: int = 1,
     dO = (
         torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
                     device="cuda").normal_().requires_grad_())
-    O = attention(Q, K, V, causal, groups, use_atomic)
+    O = attention(Q, K, V, causal, groups)
     O.backward(dO, retain_graph=True)
     dQ, Q.grad = Q.grad.clone(), None
     dK, K.grad = K.grad.clone(), None
@@ -553,20 +383,6 @@ def run1():
     parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
     parser.add_argument('--causal', action='store_true', help='Causal flag')
     parser.add_argument('--groups', type=int, default=16, help='groups')
-    parser.add_argument(
-        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-    parser.add_argument(
-        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
     args = parser.parse_args()
 
-    # Handle backward compatibility and logic
-    if args.use_split:
-        use_atomic = False
-    elif args.use_atomic:
-        use_atomic = True
-    else:
-        # Default: use atomic
-        use_atomic = True
-
-    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
-         use_atomic)
+    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
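
For orientation, the driver logic this commit converges on can be sketched as follows. This is a minimal sketch, not part of the diff: it assumes the kernels defined in this file are importable (the module name `flash_bwd_module` below is hypothetical), and it uses random stand-ins for the forward outputs `o`/`lse` purely to illustrate the expected shapes and dtypes.

    import torch

    # Hypothetical import: assumes this file's definitions are available as `flash_bwd_module`.
    from flash_bwd_module import flashattn_bwd, flashattn_bwd_preprocess

    BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, groups = 1, 32, 1024, 192, 128, 16
    HEAD_KV = H // groups  # grouped-query attention: each KV head serves `groups` Q heads

    q = torch.randn(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda")
    k = torch.randn(BATCH, N_CTX, HEAD_KV, D_HEAD_QK, dtype=torch.half, device="cuda")
    v = torch.randn(BATCH, N_CTX, HEAD_KV, D_HEAD_V, dtype=torch.half, device="cuda")
    do = torch.randn(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda")
    # o and lse normally come from the forward kernel; random stand-ins here only
    # illustrate their shapes and dtypes.
    o = torch.randn(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda")
    lse = torch.randn(BATCH, H, N_CTX, dtype=torch.float32, device="cuda")

    delta = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)(o, do)
    kernel = flashattn_bwd(
        BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, False, 128, 32,
        threads=256, num_stages=2, groups=groups)

    # Gradients are accumulated in float32 via tile-wide atomic adds inside the single
    # backward kernel, then cast to fp16; no separate dQ postprocess pass and no
    # per-group dK/dV reduction are needed anymore.
    dq = torch.zeros(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.float32, device="cuda")
    dk = torch.zeros(BATCH, N_CTX, HEAD_KV, D_HEAD_QK, dtype=torch.float32, device="cuda")
    dv = torch.zeros(BATCH, N_CTX, HEAD_KV, D_HEAD_V, dtype=torch.float32, device="cuda")
    kernel(q, k, v, do, lse, delta, dq, dk, dv)
    dq, dk, dv = dq.half(), dk.half(), dv.half()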