File tree: 1 file changed, +8 -4 lines changed.
Original file line number | Diff line number | Diff line change
@@ -231,21 +231,25 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
231231 // Ref: src/tl_templates/cuda/atomic.h::AtomicAdd
232232 const IntImm memory_order =
233233 node->args .size () >= 3 ? Downcast<IntImm>(node->args [2 ]) : IntImm (0 );
234-
234+ Array<PrimExpr> new_args;
235235 Call address_of_dst =
236236 Call (DataType::Handle (), builtin::address_of (), {dst_node});
237237 Call address_of_value =
238238 Call (DataType::Handle (), builtin::address_of (), {value_node});
239- Array<PrimExpr> new_args;
240239 if (vector_size_ == 4 ) {
241240 new_args.push_back (StringImm (" AtomicAddx4" ));
241+ new_args.push_back (address_of_dst);
242+ new_args.push_back (address_of_value);
242243 } else if (vector_size_ == 2 ) {
243244 new_args.push_back (StringImm (" AtomicAddx2" ));
245+ new_args.push_back (address_of_dst);
246+ new_args.push_back (address_of_value);
244247 } else {
245248 new_args.push_back (StringImm (" AtomicAdd" ));
249+ new_args.push_back (dst_node);
250+ new_args.push_back (value_node);
246251 }
247- new_args.push_back (address_of_dst);
248- new_args.push_back (address_of_value);
252+
249253 new_args.push_back (memory_order);
250254
251255 Call new_call =
You can’t perform that action at this time.
0 commit comments