Skip to content

Commit 7d389a4

Browse files
authored
[Bugfix] Correctly construct the argument list for atomic add based on the vector size (#1137)
* atomic_fix
* atomic_fix
1 parent 853f9c3 commit 7d389a4

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/transform/atomicadd_vectorize.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -231,21 +231,25 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
231231
// Ref: src/tl_templates/cuda/atomic.h::AtomicAdd
232232
const IntImm memory_order =
233233
node->args.size() >= 3 ? Downcast<IntImm>(node->args[2]) : IntImm(0);
234-
234+
Array<PrimExpr> new_args;
235235
Call address_of_dst =
236236
Call(DataType::Handle(), builtin::address_of(), {dst_node});
237237
Call address_of_value =
238238
Call(DataType::Handle(), builtin::address_of(), {value_node});
239-
Array<PrimExpr> new_args;
240239
if (vector_size_ == 4) {
241240
new_args.push_back(StringImm("AtomicAddx4"));
241+
new_args.push_back(address_of_dst);
242+
new_args.push_back(address_of_value);
242243
} else if (vector_size_ == 2) {
243244
new_args.push_back(StringImm("AtomicAddx2"));
245+
new_args.push_back(address_of_dst);
246+
new_args.push_back(address_of_value);
244247
} else {
245248
new_args.push_back(StringImm("AtomicAdd"));
249+
new_args.push_back(dst_node);
250+
new_args.push_back(value_node);
246251
}
247-
new_args.push_back(address_of_dst);
248-
new_args.push_back(address_of_value);
252+
249253
new_args.push_back(memory_order);
250254

251255
Call new_call =

0 commit comments

Comments
 (0)