 from tilelang.profiler import do_bench
 import tilelang.language as T
 import argparse
+from typing import Optional
 
 
 def get_bwd_configs():
@@ -23,7 +24,7 @@ def get_bwd_configs():
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_fwd(
         batch,
         heads,
@@ -143,7 +144,7 @@ def flash_fwd(
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -183,7 +184,7 @@ def make_dq_layout(dQ):
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -208,7 +209,7 @@ def flash_bwd_post(
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
-    compile_flags=["--use_fast_math", "-O3", "-DENABLE_BF16"])
+    compile_flags=["-O3", "-DENABLE_BF16"])
 def flashattn_bwd(batch,
                   heads,
                   seq_len,
@@ -311,8 +312,7 @@ def flash_bwd(
                 T.clear(dq)
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                 for i, j in T.Parallel(block_N, dim):
-                    if k * block_N + i < seq_len:
-                        T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
+                    T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
 
             T.copy(dv, dv_shared)
             T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared)
@@ -405,7 +405,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: int | None = None,
+                sliding_window: Optional[int] = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:
 
     key = key.transpose(1, 2).contiguous()
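The last hunk (together with the new `from typing import Optional` import) swaps the PEP 604 union `int | None` for `typing.Optional[int]`, which describes the same type but also evaluates on Python versions older than 3.10. A minimal sketch of the equivalence, assuming a hypothetical stand-in function rather than the real `ref_program`:

```python
from typing import Optional


# Hypothetical stand-in for ref_program's signature: Optional[int] and
# "int | None" are the same type, but only Optional[int] works at
# definition time on Python < 3.10 without
# "from __future__ import annotations".
def window_size(sliding_window: Optional[int] = None) -> int:
    # Treat "no sliding window" as an unbounded window.
    return sliding_window if sliding_window is not None else -1


print(window_size())     # -1
print(window_size(128))  # 128
```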