Skip to content

Commit dede99f

Browse files
committed
Refactor atomic operations in CUDA templates for improved readability
- Reformatted atomic operation implementations in atomic.h for better code clarity. - Adjusted function signatures in tilelang's atomic.py to enhance readability by aligning parameters. - Cleaned up unnecessary whitespace and comments in customize.py to streamline the codebase.
1 parent 1d3579f commit dede99f

File tree

5 files changed

+37
-59
lines changed

5 files changed

+37
-59
lines changed

src/tl_templates/cuda/atomic.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,12 @@ TL_DEVICE T1 AtomicMaxRet(T1 *address, T2 val,
5959
using NT1 = typename normalize_atomic_type<T1>::type;
6060
if constexpr (std::is_same_v<NT1, half> ||
6161
std::is_same_v<NT1, __nv_bfloat16>) {
62-
return static_cast<T1>(atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
62+
return static_cast<T1>(
63+
atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
6364
} else {
6465
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
65-
return static_cast<T1>(aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
66+
return static_cast<T1>(
67+
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
6668
}
6769
}
6870

@@ -85,10 +87,12 @@ TL_DEVICE T1 AtomicMinRet(T1 *address, T2 val,
8587
using NT1 = typename normalize_atomic_type<T1>::type;
8688
if constexpr (std::is_same_v<NT1, half> ||
8789
std::is_same_v<NT1, __nv_bfloat16>) {
88-
return static_cast<T1>(atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
90+
return static_cast<T1>(
91+
atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
8992
} else {
9093
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
91-
return static_cast<T1>(aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
94+
return static_cast<T1>(
95+
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
9296
}
9397
}
9498

@@ -111,10 +115,12 @@ TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val,
111115
using NT1 = typename normalize_atomic_type<T1>::type;
112116
if constexpr (std::is_same_v<NT1, half> ||
113117
std::is_same_v<NT1, __nv_bfloat16>) {
114-
return static_cast<T1>(atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
118+
return static_cast<T1>(
119+
atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
115120
} else {
116121
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
117-
return static_cast<T1>(aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
122+
return static_cast<T1>(
123+
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
118124
}
119125
}
120126

src/tl_templates/cuda/common.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
#include <cuda_runtime.h>
55
#endif
66

7+
#include "atomic.h"
78
#include <cutlass/fast_math.h>
89
#include <cutlass/numeric_types.h>
910
#include <math_constants.h>
10-
#include "atomic.h"
1111

1212
using cutlass::bfloat16_t;
1313
using cutlass::half_t;
@@ -138,7 +138,6 @@ TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
138138
return smem_int;
139139
}
140140

141-
142141
// DP4A
143142
template <typename InDatatype, typename OutDatatype>
144143
TL_DEVICE /**

src/transform/legalize_safe_memory_access.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ class SafeMemorysRewriter : public StmtExprMutator {
235235

236236
bool IsLocalBuffer(const Buffer &buffer) {
237237
String scope = buffer.scope();
238-
return scope == "local" || scope == "local.fragment" || scope == "local.var";
238+
return scope == "local" || scope == "local.fragment" ||
239+
scope == "local.var";
239240
}
240241

241242
bool isSharedBuffer(const Buffer &buffer) {

tilelang/language/atomic.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
import tilelang.language as T
66
from tvm import ir
7-
from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, Var, op
8-
from typing import List, Union, Optional
7+
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
8+
from typing import Optional
99

1010
_MEMORY_ORDER_ID_MAP = {
1111
"relaxed": 0,
@@ -17,7 +17,10 @@
1717
}
1818

1919

20-
def atomic_max(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None, return_prev: bool = False) -> PrimExpr:
20+
def atomic_max(dst: Buffer,
21+
value: PrimExpr,
22+
memory_order: Optional[str] = None,
23+
return_prev: bool = False) -> PrimExpr:
2124
"""
2225
Perform an atomic maximum on the value stored at dst with an optional memory-order.
2326
@@ -61,7 +64,10 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None,
6164
_MEMORY_ORDER_ID_MAP[memory_order])
6265

6366

64-
def atomic_min(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None, return_prev: bool = False) -> PrimExpr:
67+
def atomic_min(dst: Buffer,
68+
value: PrimExpr,
69+
memory_order: Optional[str] = None,
70+
return_prev: bool = False) -> PrimExpr:
6571
"""
6672
Atomically update the value at dst to the minimum of its current value and value.
6773
@@ -107,7 +113,10 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None,
107113
_MEMORY_ORDER_ID_MAP[memory_order])
108114

109115

110-
def atomic_add(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None, return_prev: bool = False) -> PrimExpr:
116+
def atomic_add(dst: Buffer,
117+
value: PrimExpr,
118+
memory_order: Optional[str] = None,
119+
return_prev: bool = False) -> PrimExpr:
111120
"""
112121
Atomically add `value` into `dst`, returning a handle to the operation.
113122
@@ -210,7 +219,8 @@ def _to_region(data, access_type):
210219
# Note: tile-region-based atomic operations don't support return_prev yet
211220
# This would need to be implemented in the tile runtime
212221
if return_prev:
213-
raise NotImplementedError("return_prev is not supported for tile-region-based atomic operations")
222+
raise NotImplementedError(
223+
"return_prev is not supported for tile-region-based atomic operations")
214224

215225
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst)
216226

@@ -249,19 +259,7 @@ def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> Pri
249259
>>> atomic_addx2(global_grads[i, j:j+2], grads[i, j:j+2])
250260
"""
251261
func_name = "AtomicAddx2Ret" if return_prev else "AtomicAddx2"
252-
return_type = "handle" # For vector operations, we need to determine the appropriate return type
253-
254-
if return_prev:
255-
# For return types, we need to infer the vector type based on dst.dtype
256-
if "half" in str(dst.dtype).lower():
257-
return_type = "half2"
258-
elif "bfloat16" in str(dst.dtype).lower():
259-
return_type = "__nv_bfloat162"
260-
elif "float" in str(dst.dtype).lower():
261-
return_type = "float2"
262-
else:
263-
return_type = "handle" # Fallback
264-
262+
return_type = dst.dtype if return_prev else "handle"
265263
return T.call_extern(return_type, func_name, T.address_of(dst), T.address_of(value))
266264

267265

@@ -299,15 +297,7 @@ def atomic_addx4(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> Pri
299297
>>> atomic_addx4(rgba_dst, rgba_add) # Atomic blend of all 4 channels
300298
"""
301299
func_name = "AtomicAddx4Ret" if return_prev else "AtomicAddx4"
302-
return_type = "handle"
303-
304-
if return_prev:
305-
# For float4 operations
306-
if "float" in str(dst.dtype).lower():
307-
return_type = "float4"
308-
else:
309-
return_type = "handle" # Fallback
310-
300+
return_type = "float4" if "float" in str(dst.dtype).lower() else "handle"
311301
return T.call_extern(return_type, func_name, T.address_of(dst), T.address_of(value))
312302

313303

@@ -402,4 +392,4 @@ def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> P
402392
>>> atomic_store(log_counter, 0) # Reset counter atomically
403393
"""
404394
return T.call_extern("handle", "AtomicStore", T.address_of(dst), src,
405-
_MEMORY_ORDER_ID_MAP[memory_order])
395+
_MEMORY_ORDER_ID_MAP[memory_order])

tilelang/language/customize.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
# Copyright (c) Tile-AI Corporation.
2-
# Licensed under the MIT License.
31
"""The language interface for tl programs."""
42

53
import tilelang.language as T
6-
from tvm import ir
7-
from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, Var, op
8-
from typing import List, Union, Optional
9-
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store
10-
4+
from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, op
5+
from typing import List, Union
6+
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
117

128

139
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
@@ -97,16 +93,6 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
9793
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
9894

9995

100-
101-
102-
103-
104-
105-
106-
107-
108-
109-
11096
def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
11197
"""Perform a 4-element dot product with accumulation (DP4A).
11298
@@ -163,7 +149,3 @@ def view(src: Buffer,
163149
if dtype is None:
164150
dtype = src.dtype
165151
return T.Tensor(shape, dtype, src.data)
166-
167-
168-
169-

0 commit comments

Comments
 (0)