@@ -41,9 +41,10 @@ template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
4141#endif 
4242
4343template  <typename  T1, typename  T2>
44- TL_DEVICE void  AtomicMax (T1 *address , T2 val,
44+ TL_DEVICE void  AtomicMax (T1 &ref , T2 val,
4545                         int  memory_order = int (cuda::memory_order_relaxed)) {
4646  using  NT1 = typename  normalize_atomic_type<T1>::type;
47+   T1 *address = &ref;
4748  if  constexpr  (std::is_same_v<NT1, half> ||
4849                std::is_same_v<NT1, __nv_bfloat16>) {
4950    atomicMax (reinterpret_cast <NT1 *>(address), static_cast <NT1>(val));
@@ -54,9 +55,10 @@ TL_DEVICE void AtomicMax(T1 *address, T2 val,
5455}
5556
5657template  <typename  T1, typename  T2>
57- TL_DEVICE T1 AtomicMaxRet (T1 *address , T2 val,
58+ TL_DEVICE T1 AtomicMaxRet (T1 &ref , T2 val,
5859                          int  memory_order = int (cuda::memory_order_relaxed)) {
5960  using  NT1 = typename  normalize_atomic_type<T1>::type;
61+   T1 *address = &ref;
6062  if  constexpr  (std::is_same_v<NT1, half> ||
6163                std::is_same_v<NT1, __nv_bfloat16>) {
6264    return  static_cast <T1>(
@@ -69,9 +71,10 @@ TL_DEVICE T1 AtomicMaxRet(T1 *address, T2 val,
6971}
7072
7173template  <typename  T1, typename  T2>
72- TL_DEVICE void  AtomicMin (T1 *address , T2 val,
74+ TL_DEVICE void  AtomicMin (T1 &ref , T2 val,
7375                         int  memory_order = int (cuda::memory_order_relaxed)) {
7476  using  NT1 = typename  normalize_atomic_type<T1>::type;
77+   T1 *address = &ref;
7578  if  constexpr  (std::is_same_v<NT1, half> ||
7679                std::is_same_v<NT1, __nv_bfloat16>) {
7780    atomicMin (reinterpret_cast <NT1 *>(address), static_cast <NT1>(val));
@@ -82,9 +85,10 @@ TL_DEVICE void AtomicMin(T1 *address, T2 val,
8285}
8386
8487template  <typename  T1, typename  T2>
85- TL_DEVICE T1 AtomicMinRet (T1 *address , T2 val,
88+ TL_DEVICE T1 AtomicMinRet (T1 &ref , T2 val,
8689                          int  memory_order = int (cuda::memory_order_relaxed)) {
8790  using  NT1 = typename  normalize_atomic_type<T1>::type;
91+   T1 *address = &ref;
8892  if  constexpr  (std::is_same_v<NT1, half> ||
8993                std::is_same_v<NT1, __nv_bfloat16>) {
9094    return  static_cast <T1>(
@@ -97,9 +101,10 @@ TL_DEVICE T1 AtomicMinRet(T1 *address, T2 val,
97101}
98102
99103template  <typename  T1, typename  T2>
100- TL_DEVICE void  AtomicAdd (T1 *address , T2 val,
104+ TL_DEVICE void  AtomicAdd (T1 &ref , T2 val,
101105                         int  memory_order = int (cuda::memory_order_relaxed)) {
102106  using  NT1 = typename  normalize_atomic_type<T1>::type;
107+   T1 *address = &ref;
103108  if  constexpr  (std::is_same_v<NT1, half> ||
104109                std::is_same_v<NT1, __nv_bfloat16>) {
105110    atomicAdd (reinterpret_cast <NT1 *>(address), static_cast <NT1>(val));
@@ -110,9 +115,10 @@ TL_DEVICE void AtomicAdd(T1 *address, T2 val,
110115}
111116
112117template  <typename  T1, typename  T2>
113- TL_DEVICE T1 AtomicAddRet (T1 *address , T2 val,
118+ TL_DEVICE T1 AtomicAddRet (T1 &ref , T2 val,
114119                          int  memory_order = int (cuda::memory_order_relaxed)) {
115120  using  NT1 = typename  normalize_atomic_type<T1>::type;
121+   T1 *address = &ref;
116122  if  constexpr  (std::is_same_v<NT1, half> ||
117123                std::is_same_v<NT1, __nv_bfloat16>) {
118124    return  static_cast <T1>(
@@ -124,60 +130,60 @@ TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val,
124130  }
125131}
126132
127- TL_DEVICE void  AtomicAddx2 (half_t  *address , half_t  *val) {
128-   atomicAdd (reinterpret_cast <half2 *>(address ),
133+ TL_DEVICE void  AtomicAddx2 (half_t  *ref , half_t  *val) {
134+   atomicAdd (reinterpret_cast <half2 *>(ref ),
129135            static_cast <half2>(*reinterpret_cast <half2 *>(val)));
130136}
131137
132- TL_DEVICE half2 AtomicAddx2Ret (half_t  *address , half_t  *val) {
133-   return  atomicAdd (reinterpret_cast <half2 *>(address ),
138+ TL_DEVICE half2 AtomicAddx2Ret (half_t  *ref , half_t  *val) {
139+   return  atomicAdd (reinterpret_cast <half2 *>(ref ),
134140                   static_cast <half2>(*reinterpret_cast <half2 *>(val)));
135141}
136142
137143#if  (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
138- TL_DEVICE void  AtomicAddx2 (bfloat16_t  *address , bfloat16_t  *val) {
144+ TL_DEVICE void  AtomicAddx2 (bfloat16_t  *ref , bfloat16_t  *val) {
139145  atomicAdd (
140-       reinterpret_cast <__nv_bfloat162 *>(address ),
146+       reinterpret_cast <__nv_bfloat162 *>(ref ),
141147      static_cast <__nv_bfloat162>(*reinterpret_cast <__nv_bfloat162 *>(val)));
142148}
143149
144- TL_DEVICE __nv_bfloat162 AtomicAddx2Ret (bfloat16_t  *address , bfloat16_t  *val) {
150+ TL_DEVICE __nv_bfloat162 AtomicAddx2Ret (bfloat16_t  *ref , bfloat16_t  *val) {
145151  return  atomicAdd (
146-       reinterpret_cast <__nv_bfloat162 *>(address ),
152+       reinterpret_cast <__nv_bfloat162 *>(ref ),
147153      static_cast <__nv_bfloat162>(*reinterpret_cast <__nv_bfloat162 *>(val)));
148154}
149155#endif 
150156
151157#if  (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
152- TL_DEVICE void  AtomicAddx2 (float  *address , float  *val) {
153-   atomicAdd (reinterpret_cast <float2 *>(address ),
158+ TL_DEVICE void  AtomicAddx2 (float  *ref , float  *val) {
159+   atomicAdd (reinterpret_cast <float2 *>(ref ),
154160            static_cast <float2>(*reinterpret_cast <float2 *>(val)));
155161}
156162
157- TL_DEVICE float2 AtomicAddx2Ret (float  *address , float  *val) {
158-   return  atomicAdd (reinterpret_cast <float2 *>(address ),
163+ TL_DEVICE float2 AtomicAddx2Ret (float  *ref , float  *val) {
164+   return  atomicAdd (reinterpret_cast <float2 *>(ref ),
159165                   static_cast <float2>(*reinterpret_cast <float2 *>(val)));
160166}
161167
162- TL_DEVICE void  AtomicAddx4 (float  *address , float  *val) {
163-   atomicAdd (reinterpret_cast <float4 *>(address ),
168+ TL_DEVICE void  AtomicAddx4 (float  *ref , float  *val) {
169+   atomicAdd (reinterpret_cast <float4 *>(ref ),
164170            static_cast <float4>(*reinterpret_cast <float4 *>(val)));
165171}
166172
167- TL_DEVICE float4 AtomicAddx4Ret (float  *address , float  *val) {
168-   return  atomicAdd (reinterpret_cast <float4 *>(address ),
173+ TL_DEVICE float4 AtomicAddx4Ret (float  *ref , float  *val) {
174+   return  atomicAdd (reinterpret_cast <float4 *>(ref ),
169175                   static_cast <float4>(*reinterpret_cast <float4 *>(val)));
170176}
171177#endif 
172178
173- template  <typename  T> TL_DEVICE T AtomicLoad (T *address , int  memory_order) {
174-   cuda::atomic_ref<T, cuda::thread_scope_device> aref (*address );
179+ template  <typename  T> TL_DEVICE T AtomicLoad (T &ref , int  memory_order) {
180+   cuda::atomic_ref<T, cuda::thread_scope_device> aref (ref );
175181  return  aref.load (cuda::memory_order (memory_order));
176182}
177183
178184template  <typename  T1, typename  T2>
179- TL_DEVICE void  AtomicStore (T1 *address , T2 value, int  memory_order) {
185+ TL_DEVICE void  AtomicStore (T1 &ref , T2 value, int  memory_order) {
180186  using  NT1 = typename  normalize_atomic_type<T1>::type;
181-   cuda::atomic_ref<NT1, cuda::thread_scope_device> aref (*address );
187+   cuda::atomic_ref<NT1, cuda::thread_scope_device> aref (ref );
182188  aref.store (cuda_cast<NT1>(value), cuda::memory_order (memory_order));
183- }
189+ }
0 commit comments