Commit 1d4b718
[BugFix] Add memory order argument for non-vectorized atomic add (#1081)
* [BugFix] Add memory order argument for non-vectorized atomic add
* [Lint]
* [BugFix] Memory order
* [Lint]
* [BugFix] Argument in cuda template
* [Lint]
1 parent 792e5d5 commit 1d4b718

File tree: 6 files changed, +50 -17 lines

* src/op/atomic_add.cc
* src/op/atomic_add.h
* src/op/builtin.cc
* src/tl_templates/cuda/atomic.h
* src/transform/atomicadd_vectorize.cc
* tilelang/language/atomic.py

src/op/atomic_add.cc (6 additions, 1 deletion)
@@ -58,8 +58,12 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
   if (args.size() >= 3) {
     node->use_tma = Downcast<IntImm>(args[2]);
   }
+  node->memory_order = IntImm(0);
   if (args.size() >= 4) {
-    node->coalesced_width = Downcast<IntImm>(args[3]);
+    node->memory_order = Downcast<IntImm>(args[3]);
+  }
+  if (args.size() >= 5) {
+    node->coalesced_width = Downcast<IntImm>(args[4]);
   }
   data_ = std::move(node);
 }

@@ -285,6 +289,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {

   new_args.push_back(dst_value);
   new_args.push_back(src_value);
+  new_args.push_back(memory_order);

   Call atomicadd_call =
       tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args);
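Read positionally: args[3], which used to carry coalesced_width, now carries the memory-order ID (defaulting to 0, i.e. relaxed), and coalesced_width shifts to args[4]. A minimal stand-in sketch of the new layout in plain C++, with ints in place of TVM's IntImm/Downcast machinery (AtomicAddArgs and ParseTrailingArgs are illustrative names, not from the codebase):

#include <vector>

// Illustrative stand-in for the tl.atomicadd optional-argument layout.
struct AtomicAddArgs {
  int use_tma = 0;
  int memory_order = 0; // 0 == int(cuda::memory_order_relaxed)
  int coalesced_width = 0;
};

AtomicAddArgs ParseTrailingArgs(const std::vector<int> &args) {
  AtomicAddArgs node;
  if (args.size() >= 3)
    node.use_tma = args[2];
  node.memory_order = 0; // default: relaxed
  if (args.size() >= 4)
    node.memory_order = args[3]; // slot 3 now holds the memory order
  if (args.size() >= 5)
    node.coalesced_width = args[4]; // coalesced_width moved from slot 3 to 4
  return node;
}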

src/op/atomic_add.h (6 additions, 2 deletions)
@@ -22,6 +22,7 @@ class AtomicAddNode : public TileOperatorNode {
       dst_range;            ///< Access ranges for source and destination
   IntImm use_tma;           ///< Whether to use TMA for memory operations
   IntImm coalesced_width;   ///< Width for memory coalescing optimization
+  IntImm memory_order;      ///< Memory order for atomic operations

   mutable ParallelOp par_op_; ///< Associated parallel operation
   static constexpr const char *_type_key = "tl.AtomicAdd";

@@ -41,15 +42,17 @@ class AtomicAddNode : public TileOperatorNode {
         .def_ro("src_range", &AtomicAddNode::src_range)
         .def_ro("dst_range", &AtomicAddNode::dst_range)
         .def_ro("use_tma", &AtomicAddNode::use_tma)
-        .def_ro("coalesced_width", &AtomicAddNode::coalesced_width);
+        .def_ro("coalesced_width", &AtomicAddNode::coalesced_width)
+        .def_ro("memory_order", &AtomicAddNode::memory_order);
   }

   bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
     return equal(src, other->src) && equal(dst, other->dst) &&
            equal(src_range, other->src_range) &&
            equal(dst_range, other->dst_range) &&
            equal(use_tma, other->use_tma) &&
-           equal(coalesced_width, other->coalesced_width);
+           equal(coalesced_width, other->coalesced_width) &&
+           equal(memory_order, other->memory_order);
   }

   void SHashReduce(SHashReducer hash_reduce) const {

@@ -59,6 +62,7 @@ class AtomicAddNode : public TileOperatorNode {
     hash_reduce(dst_range);
     hash_reduce(use_tma);
     hash_reduce(coalesced_width);
+    hash_reduce(memory_order);
   }

   static constexpr bool _type_has_method_sequal_reduce = true;
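The header change follows the standard discipline for reflected TVM nodes: a new field must be registered with def_ro and folded into both SEqualReduce and SHashReduce, otherwise two nodes differing only in memory_order would compare and hash as identical. A plain-C++ analogy of that keep-in-sync invariant (AtomicAddKey and HashKey are illustrative, not TVM API):

#include <cstddef>
#include <functional>

// Structural equality and hashing must cover the same field set.
struct AtomicAddKey {
  int use_tma;
  int coalesced_width;
  int memory_order; // the newly added field

  bool operator==(const AtomicAddKey &o) const {
    return use_tma == o.use_tma && coalesced_width == o.coalesced_width &&
           memory_order == o.memory_order;
  }
};

std::size_t HashKey(const AtomicAddKey &k) {
  std::size_t h = std::hash<int>{}(k.use_tma);
  h = h * 31 + std::hash<int>{}(k.coalesced_width);
  h = h * 31 + std::hash<int>{}(k.memory_order); // kept in sync with operator==
  return h;
}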

src/op/builtin.cc (1 addition, 1 deletion)
@@ -296,7 +296,7 @@ TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
     Integer(CallEffectKind::kOpaque));

 TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op)
-    .set_num_inputs(2)
+    .set_num_inputs(3)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

src/tl_templates/cuda/atomic.h (27 additions, 12 deletions)
@@ -105,8 +105,9 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   T1 *address = &ref;
-  if constexpr (std::is_same_v<NT1, half> ||
-                std::is_same_v<NT1, __nv_bfloat16>) {
+  if constexpr ((std::is_same_v<NT1, half> ||
+                 std::is_same_v<NT1, __nv_bfloat16>)&&memory_order ==
+                    int(cuda::memory_order_relaxed)) {
     atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
   } else {
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);

@@ -119,8 +120,9 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   T1 *address = &ref;
-  if constexpr (std::is_same_v<NT1, half> ||
-                std::is_same_v<NT1, __nv_bfloat16>) {
+  if constexpr ((std::is_same_v<NT1, half> ||
+                 std::is_same_v<NT1, __nv_bfloat16>)&&memory_order ==
+                    int(cuda::memory_order_relaxed)) {
     return static_cast<T1>(
         atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
   } else {

@@ -130,47 +132,60 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
   }
 }

-TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val) {
+// TODO add memory_order for vectorized atomic add
+TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val,
+                           int memory_order = int(cuda::memory_order_relaxed)) {
   atomicAdd(reinterpret_cast<half2 *>(ref),
             static_cast<half2>(*reinterpret_cast<half2 *>(val)));
 }

-TL_DEVICE half2 AtomicAddx2Ret(half_t *ref, half_t *val) {
+TL_DEVICE half2
+AtomicAddx2Ret(half_t *ref, half_t *val,
+               int memory_order = int(cuda::memory_order_relaxed)) {
   return atomicAdd(reinterpret_cast<half2 *>(ref),
                    static_cast<half2>(*reinterpret_cast<half2 *>(val)));
 }

 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
-TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val) {
+TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val,
+                           int memory_order = int(cuda::memory_order_relaxed)) {
   atomicAdd(
       reinterpret_cast<__nv_bfloat162 *>(ref),
       static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
 }

-TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val) {
+TL_DEVICE __nv_bfloat162
+AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val,
+               int memory_order = int(cuda::memory_order_relaxed)) {
   return atomicAdd(
       reinterpret_cast<__nv_bfloat162 *>(ref),
       static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
 }
 #endif

 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
-TL_DEVICE void AtomicAddx2(float *ref, float *val) {
+TL_DEVICE void AtomicAddx2(float *ref, float *val,
+                           int memory_order = int(cuda::memory_order_relaxed)) {
   atomicAdd(reinterpret_cast<float2 *>(ref),
             static_cast<float2>(*reinterpret_cast<float2 *>(val)));
 }

-TL_DEVICE float2 AtomicAddx2Ret(float *ref, float *val) {
+TL_DEVICE float2
+AtomicAddx2Ret(float *ref, float *val,
+               int memory_order = int(cuda::memory_order_relaxed)) {
   return atomicAdd(reinterpret_cast<float2 *>(ref),
                    static_cast<float2>(*reinterpret_cast<float2 *>(val)));
 }

-TL_DEVICE void AtomicAddx4(float *ref, float *val) {
+TL_DEVICE void AtomicAddx4(float *ref, float *val,
+                           int memory_order = int(cuda::memory_order_relaxed)) {
   atomicAdd(reinterpret_cast<float4 *>(ref),
             static_cast<float4>(*reinterpret_cast<float4 *>(val)));
 }

-TL_DEVICE float4 AtomicAddx4Ret(float *ref, float *val) {
+TL_DEVICE float4
+AtomicAddx4Ret(float *ref, float *val,
+               int memory_order = int(cuda::memory_order_relaxed)) {
   return atomicAdd(reinterpret_cast<float4 *>(ref),
                    static_cast<float4>(*reinterpret_cast<float4 *>(val)));
 }
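The guard added to the scalar paths is the substance of the fix: the native half/__nv_bfloat16 atomicAdd intrinsic only provides relaxed semantics, so it is now taken only when the caller actually asked for relaxed ordering, and any stronger order falls through to cuda::atomic_ref, which honors it. A standalone sketch of the same two paths in plain CUDA with libcu++ (the kernel accumulate and its release flag are hypothetical, not part of the templates):

#include <cuda/atomic>

__global__ void accumulate(float *sum, const float *x, int n, bool release) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n)
    return;
  if (!release) {
    // Fast path: hardware atomicAdd intrinsic, relaxed semantics only.
    atomicAdd(sum, x[i]);
  } else {
    // Ordered path: cuda::atomic_ref applies the requested memory order,
    // making this thread's prior writes visible to an acquiring reader.
    cuda::atomic_ref<float, cuda::thread_scope_device> aref(*sum);
    aref.fetch_add(x[i], cuda::memory_order_release);
  }
}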

src/transform/atomicadd_vectorize.cc (5 additions, 0 deletions)
@@ -227,6 +227,10 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
     if (legal_vectorize) {
       const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
       const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
+      // The default memory order is relaxed
+      // Ref: src/tl_templates/cuda/atomic.h::AtomicAdd
+      const IntImm memory_order =
+          node->args.size() >= 3 ? Downcast<IntImm>(node->args[2]) : IntImm(0);

       Call address_of_dst =
           Call(DataType::Handle(), builtin::address_of(), {dst_node});

@@ -242,6 +246,7 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
       }
       new_args.push_back(address_of_dst);
       new_args.push_back(address_of_value);
+      new_args.push_back(memory_order);

       Call new_call =
           tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
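Note the asymmetry this leaves in place: the vectorizer now forwards the memory-order ID, and the x2/x4 templates accept it, but per the TODO in atomic.h they still lower to the native, relaxed vector atomicAdd. A sketch of what the float2 case boils down to (add_pair is a hypothetical stand-in for tl::AtomicAddx2; the float2 overload of atomicAdd requires sm_90, matching the guarded template):

__device__ void add_pair(float *dst, float *val, int memory_order) {
  (void)memory_order; // accepted for signature compatibility, not yet honored
  atomicAdd(reinterpret_cast<float2 *>(dst),
            *reinterpret_cast<float2 *>(val)); // relaxed vector add, sm_90+
}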

tilelang/language/atomic.py (5 additions, 1 deletion)
@@ -227,7 +227,11 @@ def _to_region(data, access_type):
         raise NotImplementedError(
             "return_prev is not supported for tile-region-based atomic operations")

-    return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma)
+    if memory_order is None:
+        return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma, 0)
+    else:
+        return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma,
+                             _MEMORY_ORDER_ID_MAP[memory_order])


 def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:
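On the CUDA side these integer IDs are cast back with int(cuda::memory_order_*), so the 0 passed when memory_order is None is exactly cuda::memory_order_relaxed; presumably _MEMORY_ORDER_ID_MAP assigns the remaining names the matching C++ enumerator values (relaxed = 0 through seq_cst = 5). A quick standalone check of those values against libcu++ (compile with nvcc; a verification sketch, not project code):

#include <cuda/atomic>
#include <cstdio>

int main() {
  // Prints the enumerator values the templates cast back from the ID.
  std::printf("relaxed=%d acquire=%d release=%d acq_rel=%d seq_cst=%d\n",
              int(cuda::memory_order_relaxed), int(cuda::memory_order_acquire),
              int(cuda::memory_order_release), int(cuda::memory_order_acq_rel),
              int(cuda::memory_order_seq_cst));
  return 0;
}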
