44
55import  tilelang .language  as  T 
66from  tvm  import  ir 
7- from  tvm .tir  import  PrimExpr , Buffer , BufferLoad ,  BufferRegion , Var , op 
8- from  typing  import  List ,  Union ,  Optional 
7+ from  tvm .tir  import  PrimExpr , Buffer , BufferRegion , Var , op 
8+ from  typing  import  Optional 
99
1010_MEMORY_ORDER_ID_MAP  =  {
1111    "relaxed" : 0 ,
1717}
1818
1919
20- def  atomic_max (dst : Buffer , value : PrimExpr , memory_order : Optional [str ] =  None , return_prev : bool  =  False ) ->  PrimExpr :
20+ def  atomic_max (dst : Buffer ,
21+                value : PrimExpr ,
22+                memory_order : Optional [str ] =  None ,
23+                return_prev : bool  =  False ) ->  PrimExpr :
2124    """ 
2225    Perform an atomic maximum on the value stored at dst with an optional memory-order. 
2326
@@ -61,7 +64,10 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None,
6164                             _MEMORY_ORDER_ID_MAP [memory_order ])
6265
6366
64- def  atomic_min (dst : Buffer , value : PrimExpr , memory_order : Optional [str ] =  None , return_prev : bool  =  False ) ->  PrimExpr :
67+ def  atomic_min (dst : Buffer ,
68+                value : PrimExpr ,
69+                memory_order : Optional [str ] =  None ,
70+                return_prev : bool  =  False ) ->  PrimExpr :
6571    """ 
6672    Atomically update the value at dst to the minimum of its current value and value. 
6773
@@ -107,7 +113,10 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None,
107113                             _MEMORY_ORDER_ID_MAP [memory_order ])
108114
109115
110- def  atomic_add (dst : Buffer , value : PrimExpr , memory_order : Optional [str ] =  None , return_prev : bool  =  False ) ->  PrimExpr :
116+ def  atomic_add (dst : Buffer ,
117+                value : PrimExpr ,
118+                memory_order : Optional [str ] =  None ,
119+                return_prev : bool  =  False ) ->  PrimExpr :
111120    """ 
112121    Atomically add `value` into `dst`, returning a handle to the operation. 
113122
@@ -210,7 +219,8 @@ def _to_region(data, access_type):
210219    # Note: tile-region-based atomic operations don't support return_prev yet 
211220    # This would need to be implemented in the tile runtime 
212221    if  return_prev :
213-         raise  NotImplementedError ("return_prev is not supported for tile-region-based atomic operations" )
222+         raise  NotImplementedError (
223+             "return_prev is not supported for tile-region-based atomic operations" )
214224
215225    return  T .call_intrin ("handle" , op .Op .get ("tl.atomicadd" ), value , dst )
216226
@@ -249,19 +259,7 @@ def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> Pri
249259        >>>             atomic_addx2(global_grads[i, j:j+2], grads[i, j:j+2]) 
250260    """ 
251261    func_name  =  "AtomicAddx2Ret"  if  return_prev  else  "AtomicAddx2" 
252-     return_type  =  "handle"   # For vector operations, we need to determine the appropriate return type 
253- 
254-     if  return_prev :
255-         # For return types, we need to infer the vector type based on dst.dtype 
256-         if  "half"  in  str (dst .dtype ).lower ():
257-             return_type  =  "half2" 
258-         elif  "bfloat16"  in  str (dst .dtype ).lower ():
259-             return_type  =  "__nv_bfloat162" 
260-         elif  "float"  in  str (dst .dtype ).lower ():
261-             return_type  =  "float2" 
262-         else :
263-             return_type  =  "handle"   # Fallback 
264- 
262+     return_type  =  dst .dtype  if  return_prev  else  "handle" 
265263    return  T .call_extern (return_type , func_name , T .address_of (dst ), T .address_of (value ))
266264
267265
@@ -299,15 +297,7 @@ def atomic_addx4(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> Pri
299297        >>> atomic_addx4(rgba_dst, rgba_add)  # Atomic blend of all 4 channels 
300298    """ 
301299    func_name  =  "AtomicAddx4Ret"  if  return_prev  else  "AtomicAddx4" 
302-     return_type  =  "handle" 
303- 
304-     if  return_prev :
305-         # For float4 operations 
306-         if  "float"  in  str (dst .dtype ).lower ():
307-             return_type  =  "float4" 
308-         else :
309-             return_type  =  "handle"   # Fallback 
310- 
300+     return_type  =  "float4"  if  "float"  in  str (dst .dtype ).lower () else  "handle" 
311301    return  T .call_extern (return_type , func_name , T .address_of (dst ), T .address_of (value ))
312302
313303
@@ -402,4 +392,4 @@ def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> P
402392        >>> atomic_store(log_counter, 0)  # Reset counter atomically 
403393    """ 
404394    return  T .call_extern ("handle" , "AtomicStore" , T .address_of (dst ), src ,
405-                          _MEMORY_ORDER_ID_MAP [memory_order ])
395+                          _MEMORY_ORDER_ID_MAP [memory_order ])
0 commit comments