@@ -105,8 +105,9 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
105105 int memory_order = int (cuda::memory_order_relaxed)) {
106106 using NT1 = typename normalize_atomic_type<T1>::type;
107107 T1 *address = &ref;
108- if constexpr (std::is_same_v<NT1, half> ||
109- std::is_same_v<NT1, __nv_bfloat16>) {
108+ if constexpr ((std::is_same_v<NT1, half> ||
109+ std::is_same_v<NT1, __nv_bfloat16>)&&memory_order ==
110+ int (cuda::memory_order_relaxed)) {
110111 atomicAdd (reinterpret_cast <NT1 *>(address), static_cast <NT1>(val));
111112 } else {
112113 cuda::atomic_ref<NT1, cuda::thread_scope_device> aref (*address);
@@ -119,8 +120,9 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
119120 int memory_order = int (cuda::memory_order_relaxed)) {
120121 using NT1 = typename normalize_atomic_type<T1>::type;
121122 T1 *address = &ref;
122- if constexpr (std::is_same_v<NT1, half> ||
123- std::is_same_v<NT1, __nv_bfloat16>) {
123+ if constexpr ((std::is_same_v<NT1, half> ||
124+ std::is_same_v<NT1, __nv_bfloat16>)&&memory_order ==
125+ int (cuda::memory_order_relaxed)) {
124126 return static_cast <T1>(
125127 atomicAdd (reinterpret_cast <NT1 *>(address), static_cast <NT1>(val)));
126128 } else {
@@ -130,47 +132,60 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
130132 }
131133}
132134
133- TL_DEVICE void AtomicAddx2 (half_t *ref, half_t *val) {
135+ // TODO add memory_order for vectorized atomic add
136+ TL_DEVICE void AtomicAddx2 (half_t *ref, half_t *val,
137+ int memory_order = int (cuda::memory_order_relaxed)) {
134138 atomicAdd (reinterpret_cast <half2 *>(ref),
135139 static_cast <half2>(*reinterpret_cast <half2 *>(val)));
136140}
137141
138- TL_DEVICE half2 AtomicAddx2Ret (half_t *ref, half_t *val) {
142+ TL_DEVICE half2
143+ AtomicAddx2Ret (half_t *ref, half_t *val,
144+ int memory_order = int (cuda::memory_order_relaxed)) {
139145 return atomicAdd (reinterpret_cast <half2 *>(ref),
140146 static_cast <half2>(*reinterpret_cast <half2 *>(val)));
141147}
142148
143149#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
144- TL_DEVICE void AtomicAddx2 (bfloat16_t *ref, bfloat16_t *val) {
150+ TL_DEVICE void AtomicAddx2 (bfloat16_t *ref, bfloat16_t *val,
151+ int memory_order = int (cuda::memory_order_relaxed)) {
145152 atomicAdd (
146153 reinterpret_cast <__nv_bfloat162 *>(ref),
147154 static_cast <__nv_bfloat162>(*reinterpret_cast <__nv_bfloat162 *>(val)));
148155}
149156
150- TL_DEVICE __nv_bfloat162 AtomicAddx2Ret (bfloat16_t *ref, bfloat16_t *val) {
157+ TL_DEVICE __nv_bfloat162
158+ AtomicAddx2Ret (bfloat16_t *ref, bfloat16_t *val,
159+ int memory_order = int (cuda::memory_order_relaxed)) {
151160 return atomicAdd (
152161 reinterpret_cast <__nv_bfloat162 *>(ref),
153162 static_cast <__nv_bfloat162>(*reinterpret_cast <__nv_bfloat162 *>(val)));
154163}
155164#endif
156165
157166#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
158- TL_DEVICE void AtomicAddx2 (float *ref, float *val) {
167+ TL_DEVICE void AtomicAddx2 (float *ref, float *val,
168+ int memory_order = int (cuda::memory_order_relaxed)) {
159169 atomicAdd (reinterpret_cast <float2 *>(ref),
160170 static_cast <float2>(*reinterpret_cast <float2 *>(val)));
161171}
162172
163- TL_DEVICE float2 AtomicAddx2Ret (float *ref, float *val) {
173+ TL_DEVICE float2
174+ AtomicAddx2Ret (float *ref, float *val,
175+ int memory_order = int (cuda::memory_order_relaxed)) {
164176 return atomicAdd (reinterpret_cast <float2 *>(ref),
165177 static_cast <float2>(*reinterpret_cast <float2 *>(val)));
166178}
167179
168- TL_DEVICE void AtomicAddx4 (float *ref, float *val) {
180+ TL_DEVICE void AtomicAddx4 (float *ref, float *val,
181+ int memory_order = int (cuda::memory_order_relaxed)) {
169182 atomicAdd (reinterpret_cast <float4 *>(ref),
170183 static_cast <float4>(*reinterpret_cast <float4 *>(val)));
171184}
172185
173- TL_DEVICE float4 AtomicAddx4Ret (float *ref, float *val) {
186+ TL_DEVICE float4
187+ AtomicAddx4Ret (float *ref, float *val,
188+ int memory_order = int (cuda::memory_order_relaxed)) {
174189 return atomicAdd (reinterpret_cast <float4 *>(ref),
175190 static_cast <float4>(*reinterpret_cast <float4 *>(val)));
176191}
0 commit comments