@@ -44,6 +44,7 @@ void myHipMemsetD64(void* dst, unsigned long long value, size_t count) {
4444 count);
4545}
4646
47+ // Legacy API: Using SymmMemObjPtr + offset
4748template <typename T>
4849__global__ void AtomicFetchThreadKernel (int myPe, const SymmMemObjPtr memObj) {
4950 constexpr int sendPe = 0 ;
@@ -56,16 +57,38 @@ __global__ void AtomicFetchThreadKernel(int myPe, const SymmMemObjPtr memObj) {
5657 T ret = ShmemAtomicTypeFetchThread<T>(memObj, 2 * sizeof (T), 1 , 0 , AMO_FETCH_ADD, recvPe);
5758 __threadfence_system ();
5859 if (ret == gridDim.x * blockDim.x ) {
59- printf (" globalTid: %d ret = %lu atomic fetch is ok!~\n " , globalTid, (uint64_t )ret);
60+ printf (" Legacy API: globalTid: %d ret = %lu atomic fetch is ok!~\n " , globalTid, (uint64_t )ret);
6061 }
61-
62- // __syncthreads();
6362 } else {
6463 while (AtomicLoadRelaxed (reinterpret_cast <T*>(memObj->localPtr ) + 2 ) !=
6564 gridDim.x * blockDim.x + 1 ) {
6665 }
6766 if (globalTid == 0 ) {
68- printf (" atomic fetch is ok!~\n " );
67+ printf (" Legacy API: atomic fetch is ok!~\n " );
68+ }
69+ }
70+ }
71+
72+ // New API: Using pure addresses
73+ template <typename T>
74+ __global__ void AtomicFetchThreadKernel_PureAddr (int myPe, T* localBuff) {
75+ constexpr int sendPe = 0 ;
76+ constexpr int recvPe = 1 ;
77+
78+ int globalTid = blockIdx.x * blockDim.x + threadIdx.x ;
79+
80+ if (myPe == sendPe) {
81+ T* dest = localBuff + 2 ;
82+ T ret = ShmemAtomicTypeFetchThread<T>(dest, 1 , 0 , AMO_FETCH_ADD, recvPe);
83+ __threadfence_system ();
84+ if (ret == gridDim.x * blockDim.x ) {
85+ printf (" Pure Address API: globalTid: %d ret = %lu atomic fetch is ok!~\n " , globalTid, (uint64_t )ret);
86+ }
87+ } else {
88+ while (AtomicLoadRelaxed (localBuff + 2 ) != gridDim.x * blockDim.x + 1 ) {
89+ }
90+ if (globalTid == 0 ) {
91+ printf (" Pure Address API: atomic fetch is ok!~\n " );
6992 }
7093 }
7194}
@@ -89,76 +112,175 @@ void testAtomicFetchThread() {
89112 int numEle = threadNum * blockNum;
90113 int buffSize = numEle * sizeof (uint64_t );
91114
115+ if (myPe == 0 ) {
116+ printf (" =================================================================\n " );
117+ printf (" Testing both Legacy and Pure Address APIs (Atomic Fetch)\n " );
118+ printf (" =================================================================\n " );
119+ }
120+
92121 void * buff = ShmemMalloc (buffSize);
93- myHipMemsetD64 (buff, myPe, numEle);
94- HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
95- printf (" before rank[%d] %lu %lu\n " , myPe, *(reinterpret_cast <uint64_t *>(buff)),
96- *(reinterpret_cast <uint64_t *>(buff) + numEle - 1 ));
97122 SymmMemObjPtr buffObj = ShmemQueryMemObjPtr (buff);
98123 assert (buffObj.IsValid ());
99124
100- for (int iteration = 0 ; iteration < 10 ; iteration++) {
125+ // Run atomic fetch operations for different types
126+ for (int iteration = 0 ; iteration < 3 ; iteration++) {
101127 if (myPe == 0 ) {
102- printf (" ========== Iteration %d ==========\n " , iteration + 1 );
128+ printf (" \n ========== Iteration %d ==========\n " , iteration + 1 );
103129 }
104130
105- // Run uint64 atomic nonfetch
131+ // ===== Test 1: Legacy API with uint64_t =====
132+ if (myPe == 0 ) {
133+ printf (" \n --- Test 1: Legacy API (uint64_t) ---\n " );
134+ }
106135 myHipMemsetD64 (buff, myPe, numEle);
107136 HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
108- printf ( " before rank[%d] uint64: %lu %lu \n " , myPe, *( reinterpret_cast < uint64_t *>(buff)),
109- *( reinterpret_cast < uint64_t *>(buff)));
137+ MPI_Barrier (MPI_COMM_WORLD);
138+
110139 AtomicFetchThreadKernel<uint64_t ><<<blockNum, threadNum>>>(myPe, buffObj);
111140 HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
112141 MPI_Barrier (MPI_COMM_WORLD);
113- printf (" after rank[%d] uint64: %lu %lu\n " , myPe, *(reinterpret_cast <uint64_t *>(buff)),
114- *(reinterpret_cast <uint64_t *>(buff) + 2 ));
142+
143+ if (myPe == 0 ) {
144+ uint64_t result = *(reinterpret_cast <uint64_t *>(buff) + 2 );
145+ printf (" ✓ Legacy API uint64_t test completed. Result at index 2: %lu\n " , result);
146+ }
115147
116- // Test int64_t atomic nonfetch
117- buffSize = numEle * sizeof (int64_t );
148+ // ===== Test 2: Pure Address API with uint64_t =====
149+ if (myPe == 0 ) {
150+ printf (" \n --- Test 2: Pure Address API (uint64_t) ---\n " );
151+ }
118152 myHipMemsetD64 (buff, myPe, numEle);
119153 HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
120- printf (" before rank[%d] int64: %ld %ld\n " , myPe, *(reinterpret_cast <int64_t *>(buff)),
121- *(reinterpret_cast <int64_t *>(buff)));
122- // Run int64 atomic nonfetch
154+ MPI_Barrier (MPI_COMM_WORLD);
155+
156+ AtomicFetchThreadKernel_PureAddr<uint64_t ><<<blockNum, threadNum>>>(
157+ myPe, reinterpret_cast <uint64_t *>(buff));
158+ HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
159+ MPI_Barrier (MPI_COMM_WORLD);
160+
161+ if (myPe == 0 ) {
162+ uint64_t result = *(reinterpret_cast <uint64_t *>(buff) + 2 );
163+ printf (" ✓ Pure Address API uint64_t test completed. Result at index 2: %lu\n " , result);
164+ }
165+
166+ // ===== Test 3: Legacy API with int64_t =====
167+ if (myPe == 0 ) {
168+ printf (" \n --- Test 3: Legacy API (int64_t) ---\n " );
169+ }
170+ myHipMemsetD64 (buff, myPe, numEle);
171+ HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
172+ MPI_Barrier (MPI_COMM_WORLD);
173+
123174 AtomicFetchThreadKernel<int64_t ><<<blockNum, threadNum>>>(myPe, buffObj);
124175 HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
125176 MPI_Barrier (MPI_COMM_WORLD);
126- printf (" after rank[%d] int64: %ld %ld\n " , myPe, *(reinterpret_cast <int64_t *>(buff)),
127- *(reinterpret_cast <int64_t *>(buff) + 2 ));
177+
178+ if (myPe == 0 ) {
179+ int64_t result = *(reinterpret_cast <int64_t *>(buff) + 2 );
180+ printf (" ✓ Legacy API int64_t test completed. Result at index 2: %ld\n " , result);
181+ }
128182
129- // Test uint32_t atomic nonfetch
130- buffSize = numEle * sizeof (uint32_t );
183+ // ===== Test 4: Pure Address API with int64_t =====
184+ if (myPe == 0 ) {
185+ printf (" \n --- Test 4: Pure Address API (int64_t) ---\n " );
186+ }
187+ myHipMemsetD64 (buff, myPe, numEle);
188+ HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
189+ MPI_Barrier (MPI_COMM_WORLD);
190+
191+ AtomicFetchThreadKernel_PureAddr<int64_t ><<<blockNum, threadNum>>>(
192+ myPe, reinterpret_cast <int64_t *>(buff));
193+ HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
194+ MPI_Barrier (MPI_COMM_WORLD);
195+
196+ if (myPe == 0 ) {
197+ int64_t result = *(reinterpret_cast <int64_t *>(buff) + 2 );
198+ printf (" ✓ Pure Address API int64_t test completed. Result at index 2: %ld\n " , result);
199+ }
200+
201+ // ===== Test 5: Legacy API with uint32_t =====
202+ if (myPe == 0 ) {
203+ printf (" \n --- Test 5: Legacy API (uint32_t) ---\n " );
204+ }
131205 HIP_RUNTIME_CHECK (hipMemsetD32 (reinterpret_cast <uint32_t *>(buff), myPe, numEle));
132206 HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
133- printf (" before rank[%d] uint32: %u %u\n " , myPe, *(reinterpret_cast <uint32_t *>(buff)),
134- *(reinterpret_cast <uint32_t *>(buff)));
135- // Run uint32 atomic nonfetch
207+ MPI_Barrier (MPI_COMM_WORLD);
208+
136209 AtomicFetchThreadKernel<uint32_t ><<<blockNum, threadNum>>>(myPe, buffObj);
137210 HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
138211 MPI_Barrier (MPI_COMM_WORLD);
139- printf (" after rank[%d] uint32: %u %u\n " , myPe, *(reinterpret_cast <uint32_t *>(buff)),
140- *(reinterpret_cast <uint32_t *>(buff) + 2 ));
212+
213+ if (myPe == 0 ) {
214+ uint32_t result = *(reinterpret_cast <uint32_t *>(buff) + 2 );
215+ printf (" ✓ Legacy API uint32_t test completed. Result at index 2: %u\n " , result);
216+ }
141217
142- // Test int32_t atomic nonfetch
143- buffSize = numEle * sizeof (int32_t );
218+ // ===== Test 6: Pure Address API with uint32_t =====
219+ if (myPe == 0 ) {
220+ printf (" \n --- Test 6: Pure Address API (uint32_t) ---\n " );
221+ }
222+ HIP_RUNTIME_CHECK (hipMemsetD32 (reinterpret_cast <uint32_t *>(buff), myPe, numEle));
223+ HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
224+ MPI_Barrier (MPI_COMM_WORLD);
225+
226+ AtomicFetchThreadKernel_PureAddr<uint32_t ><<<blockNum, threadNum>>>(
227+ myPe, reinterpret_cast <uint32_t *>(buff));
228+ HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
229+ MPI_Barrier (MPI_COMM_WORLD);
230+
231+ if (myPe == 0 ) {
232+ uint32_t result = *(reinterpret_cast <uint32_t *>(buff) + 2 );
233+ printf (" ✓ Pure Address API uint32_t test completed. Result at index 2: %u\n " , result);
234+ }
235+
236+ // ===== Test 7: Legacy API with int32_t =====
237+ if (myPe == 0 ) {
238+ printf (" \n --- Test 7: Legacy API (int32_t) ---\n " );
239+ }
144240 HIP_RUNTIME_CHECK (hipMemsetD32 (reinterpret_cast <int32_t *>(buff), myPe, numEle));
145241 HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
146- printf (" before rank[%d] int32: %d %d\n " , myPe, *(reinterpret_cast <int32_t *>(buff)),
147- *(reinterpret_cast <int32_t *>(buff)));
148- // Run int32 atomic nonfetch
242+ MPI_Barrier (MPI_COMM_WORLD);
243+
149244 AtomicFetchThreadKernel<int32_t ><<<blockNum, threadNum>>>(myPe, buffObj);
150245 HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
151246 MPI_Barrier (MPI_COMM_WORLD);
152- printf (" after rank[%d] int32: %d %d\n " , myPe, *(reinterpret_cast <int32_t *>(buff)),
153- *(reinterpret_cast <int32_t *>(buff) + 2 ));
247+
248+ if (myPe == 0 ) {
249+ int32_t result = *(reinterpret_cast <int32_t *>(buff) + 2 );
250+ printf (" ✓ Legacy API int32_t test completed. Result at index 2: %d\n " , result);
251+ }
252+
253+ // ===== Test 8: Pure Address API with int32_t =====
254+ if (myPe == 0 ) {
255+ printf (" \n --- Test 8: Pure Address API (int32_t) ---\n " );
256+ }
257+ HIP_RUNTIME_CHECK (hipMemsetD32 (reinterpret_cast <int32_t *>(buff), myPe, numEle));
258+ HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
259+ MPI_Barrier (MPI_COMM_WORLD);
154260
155- MPI_Barrier (MPI_COMM_WORLD); // Ensure all processes complete this iteration before next
261+ AtomicFetchThreadKernel_PureAddr<int32_t ><<<blockNum, threadNum>>>(
262+ myPe, reinterpret_cast <int32_t *>(buff));
263+ HIP_RUNTIME_CHECK (hipDeviceSynchronize ());
264+ MPI_Barrier (MPI_COMM_WORLD);
265+
156266 if (myPe == 0 ) {
157- printf (" Iteration %d completed\n " , iteration + 1 );
267+ int32_t result = *(reinterpret_cast <int32_t *>(buff) + 2 );
268+ printf (" ✓ Pure Address API int32_t test completed. Result at index 2: %d\n " , result);
269+ }
270+
271+ MPI_Barrier (MPI_COMM_WORLD);
272+ if (myPe == 0 ) {
273+ printf (" \n Iteration %d completed successfully!\n " , iteration + 1 );
158274 }
159275 sleep (1 );
160276 }
161277
278+ if (myPe == 0 ) {
279+ printf (" \n =================================================================\n " );
280+ printf (" All Atomic Fetch tests completed!\n " );
281+ printf (" =================================================================\n " );
282+ }
283+
162284 // Finalize
163285 ShmemFree (buff);
164286 ShmemFinalize ();
0 commit comments