Commit e11b4c4

Update atomic operations and tests for improved functionality
- Updated atomic operations in the CUDA templates to remove unnecessary address_of calls, improving performance and readability.
- Refactored the atomic operation signatures in tilelang's atomic.py to accept references instead of pointers.
- Added new atomic operations and corresponding test cases for atomic add, max, min, and load/store functionality to the testing suite.
- Updated the TVM subproject to the latest commit for better compatibility.
1 parent 864b8ee · commit e11b4c4
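
At the device-template level, the core change is that the element itself, rather than its address, is passed to the atomic wrappers. A minimal caller-side sketch, assuming src/tl_templates/cuda/atomic.h is on the include path; the kernel and the names out, in, idx, n are illustrative and not taken from this commit:

#include "tl_templates/cuda/atomic.h"  // include path assumed

// Each thread accumulates one input element into the output buffer.
__global__ void accumulate(float *out, const float *in, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    // New form: the element binds to the templates' T1 &ref parameter.
    AtomicAdd(out[idx], in[idx]);
    // Old form (before this commit) passed the address: AtomicAdd(&out[idx], in[idx]);
  }
}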

5 files changed: +421 additions, -52 deletions


src/op/atomic_add.cc

Lines changed: 1 addition & 4 deletions
@@ -293,10 +293,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   if (dst_predicate.defined())
     dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));
 
-  Call address_of_value =
-      tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value});
-
-  new_args.push_back(address_of_value);
+  new_args.push_back(dst_value);
   new_args.push_back(src_value);
 
   Call atomicadd_call =
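
For orientation, the argument construction now reads roughly as below (a sketch only: new_args, dst_value, and src_value come from the surrounding MakeSIMTLoop body, and the op and return dtype of the final call are not shown in this hunk):

// Before: the destination BufferLoad was wrapped so the intrinsic received a handle.
//   Call address_of_value =
//       tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value});
//   new_args.push_back(address_of_value);
// After: the BufferLoad is pushed directly, so the generated CUDA call site passes
// the element itself and binds to the reference parameter of the device templates
// (see src/tl_templates/cuda/atomic.h below).
new_args.push_back(dst_value);
new_args.push_back(src_value);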

src/tl_templates/cuda/atomic.h

Lines changed: 33 additions & 27 deletions
@@ -41,9 +41,10 @@ template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
 #endif
 
 template <typename T1, typename T2>
-TL_DEVICE void AtomicMax(T1 *address, T2 val,
+TL_DEVICE void AtomicMax(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
+  T1 *address = &ref;
   if constexpr (std::is_same_v<NT1, half> ||
                 std::is_same_v<NT1, __nv_bfloat16>) {
     atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
@@ -54,9 +55,10 @@ TL_DEVICE void AtomicMax(T1 *address, T2 val,
 }
 
 template <typename T1, typename T2>
-TL_DEVICE T1 AtomicMaxRet(T1 *address, T2 val,
+TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
+  T1 *address = &ref;
   if constexpr (std::is_same_v<NT1, half> ||
                 std::is_same_v<NT1, __nv_bfloat16>) {
     return static_cast<T1>(
@@ -69,9 +71,10 @@ TL_DEVICE T1 AtomicMaxRet(T1 *address, T2 val,
 }
 
 template <typename T1, typename T2>
-TL_DEVICE void AtomicMin(T1 *address, T2 val,
+TL_DEVICE void AtomicMin(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
+  T1 *address = &ref;
   if constexpr (std::is_same_v<NT1, half> ||
                 std::is_same_v<NT1, __nv_bfloat16>) {
     atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
@@ -82,9 +85,10 @@ TL_DEVICE void AtomicMin(T1 *address, T2 val,
 }
 
 template <typename T1, typename T2>
-TL_DEVICE T1 AtomicMinRet(T1 *address, T2 val,
+TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
+  T1 *address = &ref;
   if constexpr (std::is_same_v<NT1, half> ||
                 std::is_same_v<NT1, __nv_bfloat16>) {
     return static_cast<T1>(
@@ -97,9 +101,10 @@ TL_DEVICE T1 AtomicMinRet(T1 *address, T2 val,
 }
 
 template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 *address, T2 val,
+TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
+  T1 *address = &ref;
   if constexpr (std::is_same_v<NT1, half> ||
                 std::is_same_v<NT1, __nv_bfloat16>) {
     atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
@@ -110,9 +115,10 @@ TL_DEVICE void AtomicAdd(T1 *address, T2 val,
 }
 
 template <typename T1, typename T2>
-TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val,
+TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
+  T1 *address = &ref;
   if constexpr (std::is_same_v<NT1, half> ||
                 std::is_same_v<NT1, __nv_bfloat16>) {
     return static_cast<T1>(
@@ -124,60 +130,60 @@ TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val,
   }
 }
 
-TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) {
-  atomicAdd(reinterpret_cast<half2 *>(address),
+TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val) {
+  atomicAdd(reinterpret_cast<half2 *>(ref),
             static_cast<half2>(*reinterpret_cast<half2 *>(val)));
 }
 
-TL_DEVICE half2 AtomicAddx2Ret(half_t *address, half_t *val) {
-  return atomicAdd(reinterpret_cast<half2 *>(address),
+TL_DEVICE half2 AtomicAddx2Ret(half_t *ref, half_t *val) {
+  return atomicAdd(reinterpret_cast<half2 *>(ref),
                    static_cast<half2>(*reinterpret_cast<half2 *>(val)));
 }
 
 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
-TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) {
+TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val) {
   atomicAdd(
-      reinterpret_cast<__nv_bfloat162 *>(address),
+      reinterpret_cast<__nv_bfloat162 *>(ref),
       static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
 }
 
-TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *address, bfloat16_t *val) {
+TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val) {
   return atomicAdd(
-      reinterpret_cast<__nv_bfloat162 *>(address),
+      reinterpret_cast<__nv_bfloat162 *>(ref),
       static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
 }
 #endif
 
 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
-TL_DEVICE void AtomicAddx2(float *address, float *val) {
-  atomicAdd(reinterpret_cast<float2 *>(address),
+TL_DEVICE void AtomicAddx2(float *ref, float *val) {
+  atomicAdd(reinterpret_cast<float2 *>(ref),
             static_cast<float2>(*reinterpret_cast<float2 *>(val)));
 }
 
-TL_DEVICE float2 AtomicAddx2Ret(float *address, float *val) {
-  return atomicAdd(reinterpret_cast<float2 *>(address),
+TL_DEVICE float2 AtomicAddx2Ret(float *ref, float *val) {
+  return atomicAdd(reinterpret_cast<float2 *>(ref),
                    static_cast<float2>(*reinterpret_cast<float2 *>(val)));
 }
 
-TL_DEVICE void AtomicAddx4(float *address, float *val) {
-  atomicAdd(reinterpret_cast<float4 *>(address),
+TL_DEVICE void AtomicAddx4(float *ref, float *val) {
+  atomicAdd(reinterpret_cast<float4 *>(ref),
            static_cast<float4>(*reinterpret_cast<float4 *>(val)));
 }
 
-TL_DEVICE float4 AtomicAddx4Ret(float *address, float *val) {
-  return atomicAdd(reinterpret_cast<float4 *>(address),
+TL_DEVICE float4 AtomicAddx4Ret(float *ref, float *val) {
+  return atomicAdd(reinterpret_cast<float4 *>(ref),
                    static_cast<float4>(*reinterpret_cast<float4 *>(val)));
 }
 #endif
 
-template <typename T> TL_DEVICE T AtomicLoad(T *address, int memory_order) {
-  cuda::atomic_ref<T, cuda::thread_scope_device> aref(*address);
+template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
+  cuda::atomic_ref<T, cuda::thread_scope_device> aref(ref);
   return aref.load(cuda::memory_order(memory_order));
 }
 
 template <typename T1, typename T2>
-TL_DEVICE void AtomicStore(T1 *address, T2 value, int memory_order) {
+TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) {
   using NT1 = typename normalize_atomic_type<T1>::type;
-  cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
+  cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(ref);
   aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
-}
+}
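
Since AtomicLoad and AtomicStore now take references and forward to cuda::atomic_ref, a caller can express acquire/release publication directly on buffer elements. A hedged sketch: the flag/payload pattern, function names, and memory orders below are illustrative rather than part of this commit, and it assumes <cuda/atomic> and this header are available:

#include <cuda/atomic>
#include "tl_templates/cuda/atomic.h"  // include path assumed

// Writes the payload, then releases it by storing 1 into the flag.
__device__ void publish(int &flag, float &payload, float value) {
  payload = value;
  AtomicStore(flag, 1, int(cuda::memory_order_release));
}

// Spins until publish() has run, then reads the payload.
__device__ float consume(int &flag, const float &payload) {
  while (AtomicLoad(flag, int(cuda::memory_order_acquire)) == 0) {
    // wait for the release store to become visible
  }
  return payload;
}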

src/transform/atomicadd_vectorize.cc

Lines changed: 1 addition & 6 deletions
@@ -219,13 +219,8 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
       // bx * stride_x + (i % (stride_x / (tx_extent *
       // vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
       // (stride / vector_size_)) * vector_size_]
-      const CallNode *addr_call = node->args[1].as<CallNode>();
-      if (!addr_call || addr_call->op != builtin::address_of() ||
-          addr_call->args.size() != 1) {
-        return StmtExprMutator::VisitExpr_(node);
-      }
       const BufferLoadNode *old_dst_node =
-          addr_call->args[0].as<BufferLoadNode>();
+          node->args[1].as<BufferLoadNode>();
       const BufferLoadNode *old_value_node =
           node->args[2].as<BufferLoadNode>();
       if (!old_dst_node || !old_value_node) {
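
The matching simplification on the vectorizer side, assembled from the hunk above as a sketch (node is the visited CallNode; everything else is as it appears in the diff): the destination argument is matched as a BufferLoad directly, with no address_of unwrapping.

// args[1] is the destination element, args[2] the value to add.
const BufferLoadNode *old_dst_node = node->args[1].as<BufferLoadNode>();
const BufferLoadNode *old_value_node = node->args[2].as<BufferLoadNode>();
if (!old_dst_node || !old_value_node) {
  // Not the expected AtomicAdd(dst[i], src[i]) shape; leave the call unchanged.
  return StmtExprMutator::VisitExpr_(node);
}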
