@@ -147,7 +147,118 @@ def flash_bwd_post(
 @tilelang.jit(pass_configs={
     tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 })
-def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
+def flashattn_bwd_atomic_add(batch,
+                             heads,
+                             seq_len,
+                             dim_qk,
+                             dim_v,
+                             is_causal,
+                             block_M,
+                             block_N,
+                             threads=256,
+                             num_stages=2,
+                             groups=1):
+    sm_scale = (1.0 / dim_qk)**0.5
+    scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
+    head_kv = heads // groups
+    q_shape = [batch, seq_len, heads, dim_qk]
+    k_shape = [batch, seq_len, head_kv, dim_qk]
+    v_shape = [batch, seq_len, head_kv, dim_v]
+    dtype = "float16"
+    accum_dtype = "float"
+
+    @T.prim_func
+    def flash_bwd(
+            Q: T.Tensor(q_shape, dtype),  # type: ignore
+            K: T.Tensor(k_shape, dtype),  # type: ignore
+            V: T.Tensor(v_shape, dtype),  # type: ignore
+            dO: T.Tensor([batch, seq_len, heads, dim_v], dtype),  # type: ignore
+            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
+            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
+            dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
+            dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
+            dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
+    ):
+        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
+            K_shared = T.alloc_shared([block_M, dim_qk], dtype)
+            dsT_shared = T.alloc_shared([block_M, block_N], dtype)
+            q = T.alloc_shared([block_N, dim_qk], dtype)
+            V_shared = T.alloc_shared([block_M, dim_v], dtype)
+            qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
+            dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
+            qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
+            dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
+            lse_shared = T.alloc_shared([block_N], accum_dtype)
+            delta = T.alloc_shared([block_N], accum_dtype)
+            do = T.alloc_shared([block_N, dim_v], dtype)
+            dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
+            dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
+            dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
+            dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
+            dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
+
+            T.annotate_layout({
+                dQ: make_dq_layout(dQ),
+                K_shared: tilelang.layout.make_swizzled_layout(K_shared),
+            })
+
+            T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
+            T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
+            T.clear(dv)
+            T.clear(dk)
+            loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
+            loop_ed = T.ceildiv(seq_len, block_N)
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
+                T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
+                T.clear(qkT)
+                T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
+                for i, j in T.Parallel(block_M, block_N):
+                    qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
+                if is_causal:
+                    for i, j in T.Parallel(block_M, block_N):
+                        qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
+                                                   0)
+                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
+                T.clear(dsT)
+                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                T.copy(qkT, qkT_cast)
+                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
+
+                T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
+
+                for i, j in T.Parallel(block_M, block_N):
+                    dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
+                T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
+
+                T.copy(dsT_cast, dsT_shared)
+                T.clear(dq)
+                T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
+                for i, j in T.Parallel(block_N, dim_qk):
+                    if k * block_N + i < seq_len:
+                        T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
+            T.copy(dv, dv_shared)
+            T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
+            T.copy(dk, dk_shared)
+            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)
+
+    return flash_bwd
+
+
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
+def flashattn_bwd_split(batch,
+                        heads,
+                        seq_len,
+                        dim_qk,
+                        dim_v,
+                        is_causal,
+                        block_M,
+                        block_N,
+                        threads=256,
+                        num_stages=2,
+                        groups=1):
     sm_scale = (1.0 / dim_qk)**0.5
     scale = (1.0 / dim_qk)**0.5 * 1.44269504  # log2(e)
     head_kv = heads // groups
@@ -171,7 +282,7 @@ def flash_bwd(
             dK: T.Tensor(dk_shape, dtype),  # type: ignore
             dV: T.Tensor(dv_shape, dtype),  # type: ignore
     ):
-        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
+        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
             K_shared = T.alloc_shared([block_M, dim_qk], dtype)
             dsT_shared = T.alloc_shared([block_M, block_N], dtype)
             q = T.alloc_shared([block_N, dim_qk], dtype)
@@ -202,20 +313,20 @@ def flash_bwd(
             T.clear(dk)
             loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
             loop_ed = T.ceildiv(seq_len, block_N)
-            for k in T.Pipelined(loop_st, loop_ed, num_stages=1):
+            for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
                 T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
                 T.clear(qkT)
                 T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
+                T.clear(dsT)
+                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                 T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
                 for i, j in T.Parallel(block_M, block_N):
                     qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
                         qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                                    0)
-                T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
-                T.clear(dsT)
-                T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                 T.copy(qkT, qkT_cast)
                 T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

@@ -244,7 +355,7 @@ def flash_bwd(
 class _attention(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, q, k, v, causal, groups=1):
+    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
         BATCH, N_CTX, H, D_HEAD_QK = q.shape
         D_HEAD_V = v.shape[-1]
         block_M = 128
@@ -253,6 +364,7 @@ def forward(ctx, q, k, v, causal, groups=1):
         o, lse = mod(q, k, v)
         ctx.save_for_backward(q, k, v, o, lse)
         ctx.causal = causal
+        ctx.use_atomic = use_atomic
         return o

     @staticmethod
@@ -268,23 +380,59 @@ def maybe_contiguous(x):
             return x

         do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
-        block_M = 64
+        block_M = 128
         block_N = 32
         mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
         mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
         delta = mod_prep(o, do)
-        kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
-                               groups)
-        shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
-        shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
-        shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
-        dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
-        dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
-        dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
-        kernel(q, k, v, do, lse, delta, dq, dk, dv)
-        dq = mod_post(dq)
-        dk, dv = dk.sum(0), dv.sum(0)
-        return dq, dk, dv, None, None
+
+        if ctx.use_atomic:
+            kernel = flashattn_bwd_atomic_add(
+                BATCH,
+                H,
+                N_CTX,
+                D_HEAD_QK,
+                D_HEAD_V,
+                ctx.causal,
+                block_M,
+                block_N,
+                threads=256,
+                num_stages=2,
+                groups=groups)
+            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
+            shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
+            shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
+            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
+            dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
+            dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
+            kernel(q, k, v, do, lse, delta, dq, dk, dv)
+            dq = mod_post(dq)
+            dk = dk.to(torch.float16)
+            dv = dv.to(torch.float16)
+        else:
+            kernel = flashattn_bwd_split(
+                BATCH,
+                H,
+                N_CTX,
+                D_HEAD_QK,
+                D_HEAD_V,
+                ctx.causal,
+                block_M,
+                block_N,
+                threads=256,
+                num_stages=2,
+                groups=groups)
+            shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
+            shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
+            shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
+            dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
+            dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
+            dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
+            kernel(q, k, v, do, lse, delta, dq, dk, dv)
+            dq = mod_post(dq)
+            dk, dv = dk.sum(0), dv.sum(0)
+
+        return dq, dk, dv, None, None, None


 attention = _attention.apply
@@ -321,7 +469,8 @@ def main(BATCH: int = 1,
          D_HEAD_QK: int = 192,
          D_HEAD_V: int = 128,
          groups: int = 16,
-         causal: bool = False):
+         causal: bool = False,
+         use_atomic: bool = True):
     flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
     flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
     total_flops = 3 * flops_per_qk + 2 * flops_per_v
@@ -341,7 +490,7 @@ def main(BATCH: int = 1,
     dO = (
         torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
                     device="cuda").normal_().requires_grad_())
-    O = attention(Q, K, V, causal, groups)
+    O = attention(Q, K, V, causal, groups, use_atomic)
     O.backward(dO, retain_graph=True)
     dQ, Q.grad = Q.grad.clone(), None
     dK, K.grad = K.grad.clone(), None
@@ -382,7 +531,22 @@ def run1():
     parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
     parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
     parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', action='store_true', help='Causal flag')
     parser.add_argument('--groups', type=int, default=16, help='groups')
+    parser.add_argument(
+        '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
+    parser.add_argument(
+        '--use_split', action='store_true', default=False, help='Use split for dK/dV')
     args = parser.parse_args()
-    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
+
+    # Handle backward compatibility and logic
+    if args.use_split:
+        use_atomic = False
+    elif args.use_atomic:
+        use_atomic = True
+    else:
+        # Default: use atomic
+        use_atomic = True
+
+    main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
+         use_atomic)
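
Since the change leaves both backward paths selectable at runtime (atomic fp32 accumulation of dK/dV versus a per-group split buffer summed after the kernel), a gradient parity check against a plain PyTorch reference is a natural smoke test. The sketch below is illustrative and not part of the commit: it assumes it runs in the module that defines `attention`, and the helper names (`naive_attention`, `check_backward`), shapes, and tolerances are placeholders.

import torch


def naive_attention(q, k, v, causal, groups):
    # q: [B, N, H, Dqk]; k, v: [B, N, H // groups, D]. KV head i serves Q heads
    # [i * groups, (i + 1) * groups), matching the kernel's `bx // groups` indexing.
    B, N, H, Dqk = q.shape
    kr = k.repeat_interleave(groups, dim=2).float()
    vr = v.repeat_interleave(groups, dim=2).float()
    scores = torch.einsum("bnhd,bmhd->bhnm", q.float(), kr) / Dqk**0.5
    if causal:
        mask = torch.tril(torch.ones(N, N, dtype=torch.bool, device=q.device))
        scores = scores.masked_fill(~mask, float("-inf"))
    p = scores.softmax(dim=-1)
    return torch.einsum("bhnm,bmhd->bnhd", p, vr).half()


def check_backward(use_atomic, B=1, H=32, N=256, Dqk=192, Dv=128, groups=16, causal=True):
    q = torch.randn(B, N, H, Dqk, dtype=torch.half, device="cuda", requires_grad=True)
    k = torch.randn(B, N, H // groups, Dqk, dtype=torch.half, device="cuda", requires_grad=True)
    v = torch.randn(B, N, H // groups, Dv, dtype=torch.half, device="cuda", requires_grad=True)
    do = torch.randn(B, N, H, Dv, dtype=torch.half, device="cuda")

    # Gradients from the TileLang backward (atomic or split path).
    attention(q, k, v, causal, groups, use_atomic).backward(do)
    grads = [t.grad.clone() for t in (q, k, v)]
    for t in (q, k, v):
        t.grad = None

    # Gradients from the naive reference; fp16 tolerances are illustrative.
    naive_attention(q, k, v, causal, groups).backward(do)
    for got, ref in zip(grads, (q.grad, k.grad, v.grad)):
        torch.testing.assert_close(got, ref, rtol=2e-2, atol=2e-2)


if torch.cuda.is_available():
    check_backward(use_atomic=True)   # atomic-add path
    check_backward(use_atomic=False)  # split + reduce path

Both paths should agree with the reference within fp16 tolerance: the atomic path accumulates dK/dV in fp32 via `T.atomic_add` and casts back to fp16, while the split path materializes a `[groups, B, N, H_kv, D]` buffer and sums over the group dimension after the kernel.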