@@ -149,110 +149,7 @@ def flash_bwd_post(
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd_atomic_add(batch,
-                             heads,
-                             seq_len,
-                             dim,
-                             is_causal,
-                             block_M,
-                             block_N,
-                             threads=128,
-                             num_stages=2):
-    sm_scale = (1.0 / dim)**0.5
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
-    shape = [batch, seq_len, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
-
-    @T.prim_func
-    def flash_bwd(
-            Q: T.Tensor(shape, dtype),  # type: ignore
-            K: T.Tensor(shape, dtype),  # type: ignore
-            V: T.Tensor(shape, dtype),  # type: ignore
-            dO: T.Tensor(shape, dtype),  # type: ignore
-            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
-            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
-            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
-            dK: T.Tensor(shape, accum_dtype),  # type: ignore
-            dV: T.Tensor(shape, accum_dtype),  # type: ignore
-    ):
-        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
-            K_shared = T.alloc_shared([block_M, dim], dtype)
-            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
-            q = T.alloc_shared([block_N, dim], dtype)
-            V_shared = T.alloc_shared([block_M, dim], dtype)
-            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
-            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
-            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
-            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
-            lse_shared = T.alloc_shared([block_N], accum_dtype)
-            delta = T.alloc_shared([block_N], accum_dtype)
-            do = T.alloc_shared([block_N, dim], dtype)
-            dv = T.alloc_fragment([block_M, dim], accum_dtype)
-            dk = T.alloc_fragment([block_M, dim], accum_dtype)
-            dq = T.alloc_fragment([block_N, dim], accum_dtype)
-            dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
-            dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
-
-            T.annotate_layout({
-                dQ: make_dq_layout(dQ),
-                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
-            })
-            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
-            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
-            T.clear(dv)
-            T.clear(dk)
-            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
-            loop_ed = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
-                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
-                T.clear(qkT)
-                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
-                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
-                for i, j in T.Parallel(block_M, block_N):
-                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
-                if is_causal:
-                    for i, j in T.Parallel(block_M, block_N):
-                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
-                                                   0)
-                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
-                T.clear(dsT)
-                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
-                T.copy(qkT, qkT_cast)
-                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
-
-                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
-
-                for i, j in T.Parallel(block_M, block_N):
-                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
-                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
-
-                T.copy(dsT_cast, dsT_shared)
-                T.clear(dq)
-                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
-                for i, j in T.Parallel(block_N, dim):
-                    if k * block_N + i < seq_len:
-                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
-            T.copy(dv, dv_shared)
-            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared)
-            T.copy(dk, dk_shared)
-            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared)
-
-    return flash_bwd
-
-
-@tilelang.jit(pass_configs={
-    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
-def flashattn_bwd_split(batch,
-                        heads,
-                        seq_len,
-                        dim,
-                        is_causal,
-                        block_M,
-                        block_N,
-                        threads=128,
-                        num_stages=2):
+def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
@@ -271,9 +168,13 @@ def flash_bwd(
             dK: T.Tensor(shape, dtype),  # type: ignore
             dV: T.Tensor(shape, dtype),  # type: ignore
     ):
-        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
+        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
             K_shared = T.alloc_shared([block_M, dim], dtype)
             dsT_shared = T.alloc_shared([block_M, block_N], dtype)
+            # should not store K to local if dim is large
+            # K_local = T.alloc_fragment([block_M, dim], dtype)
+            # K_local_T = T.alloc_fragment([block_M, dim], dtype)
+            # V_local = T.alloc_fragment([block_M, dim], dtype)
             q = T.alloc_shared([block_N, dim], dtype)
             V_shared = T.alloc_shared([block_M, dim], dtype)
             qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
@@ -301,7 +202,7 @@ def flash_bwd(
             T.clear(dk)
             loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
             loop_ed = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
                 T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                 T.clear(qkT)
                 T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -328,8 +229,7 @@ def flash_bwd(
                 T.clear(dq)
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                 for i, j in T.Parallel(block_N, dim):
-                    if k * block_N + i < seq_len:
-                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+                    T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
             T.copy(dv, dv_shared)
             T.copy(dk, dk_shared)
             T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
@@ -341,14 +241,13 @@ def flash_bwd(
 class _attention(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, q, k, v, causal, use_atomic=True):
+    def forward(ctx, q, k, v, causal):
         BATCH, N_CTX, H, D_HEAD = q.shape
         block_M = 64
         block_N = 64 if D_HEAD <= 128 else 32
         o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
         ctx.save_for_backward(q, k, v, o, lse)
         ctx.causal = causal
-        ctx.use_atomic = use_atomic
         return o
 
     @staticmethod
@@ -367,29 +266,14 @@ def maybe_contiguous(x):
         kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
         kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
         delta = kernel_prep(o, do)
-
-        if ctx.use_atomic:
-            kernel = flashattn_bwd_atomic_add(
-                BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
-            shape = [BATCH, N_CTX, H, D_HEAD]
-            dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
-            dk = torch.zeros(shape, dtype=torch.float32, device=q.device)
-            dv = torch.zeros(shape, dtype=torch.float32, device=q.device)
-            kernel(q, k, v, do, lse, delta, dq, dk, dv)
-            dq = kernel_post(dq)
-            dk = dk.to(torch.float16)
-            dv = dv.to(torch.float16)
-        else:
-            kernel = flashattn_bwd_split(
-                BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
-            shape = [BATCH, N_CTX, H, D_HEAD]
-            dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
-            dk = torch.empty(shape, dtype=torch.float16, device=q.device)
-            dv = torch.empty(shape, dtype=torch.float16, device=q.device)
-            kernel(q, k, v, do, lse, delta, dq, dk, dv)
-            dq = kernel_post(dq)
-
-        return dq, dk, dv, None, None
+        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
+        shape = [BATCH, N_CTX, H, D_HEAD]
+        dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
+        dk = torch.empty(shape, dtype=torch.float16, device=q.device)
+        dv = torch.empty(shape, dtype=torch.float16, device=q.device)
+        kernel(q, k, v, do, lse, delta, dq, dk, dv)
+        dq = kernel_post(dq)
+        return dq, dk, dv, None
 
 
 attention = _attention.apply
@@ -415,9 +299,7 @@ def main(
     N_CTX: int = 1024,
     D_HEAD: int = 64,
     causal: bool = False,
-    use_atomic: bool = True,
 ):
-    print(f"Test with use_atomic: {use_atomic}")
     flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
     total_flops = 5 * flops_per_matmul
     if causal:
@@ -428,7 +310,7 @@ def main(
     K = torch.empty_like(Q).normal_().requires_grad_()
     V = torch.empty_like(Q).normal_().requires_grad_()
     dO = torch.randn_like(Q)
-    O = attention(Q, K, V, causal, use_atomic)
+    O = attention(Q, K, V, causal)
     O.backward(dO, retain_graph=True)
     dQ, Q.grad = Q.grad.clone(), None
     dK, K.grad = K.grad.clone(), None
@@ -444,7 +326,6 @@ def main(
     assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
     assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
     assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
-    print('All checks passed.✅')
 
     def run():
         O_ref.backward(dO, retain_graph=True)
@@ -468,20 +349,6 @@ def run1():
     parser.add_argument('--h', type=int, default=32, help='Number of heads')
     parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
     parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
-    parser.add_argument('--causal', action='store_true', help='Causal flag')
-    parser.add_argument(
-        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
-    parser.add_argument(
-        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
+    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
     args = parser.parse_args()
-
-    # Handle backward compatibility and logic
-    if args.use_split:
-        use_atomic = False
-    elif args.use_atomic:
-        use_atomic = True
-    else:
-        # Default: use atomic
-        use_atomic = True
-
-    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal, use_atomic)
+    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)