Skip to content

Commit 5f3d7cf

Browse files
committed
Feature: support symmetric heap and unified memory space
1 parent d80bd08 commit 5f3d7cf

15 files changed, with 2648 additions and 208 deletions.

examples/ops/dispatch_combine/test_dispatch_combine_internode.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
import torch.distributed as dist
2727
import argparse
2828

29+
os.environ["MORI_SHMEM_HEAP_SIZE"] = "4G"
2930

3031
kernel_type_map = {
3132
"v0": mori.ops.EpDispatchCombineKernelType.InterNode,

examples/shmem/atomic_fetch_thread.cpp

Lines changed: 160 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ void myHipMemsetD64(void* dst, unsigned long long value, size_t count) {
4444
count);
4545
}
4646

47+
// Legacy API: Using SymmMemObjPtr + offset
4748
template <typename T>
4849
__global__ void AtomicFetchThreadKernel(int myPe, const SymmMemObjPtr memObj) {
4950
constexpr int sendPe = 0;
@@ -56,16 +57,38 @@ __global__ void AtomicFetchThreadKernel(int myPe, const SymmMemObjPtr memObj) {
5657
T ret = ShmemAtomicTypeFetchThread<T>(memObj, 2 * sizeof(T), 1, 0, AMO_FETCH_ADD, recvPe);
5758
__threadfence_system();
5859
if (ret == gridDim.x * blockDim.x) {
59-
printf("globalTid: %d ret = %lu atomic fetch is ok!~\n", globalTid, (uint64_t)ret);
60+
printf("Legacy API: globalTid: %d ret = %lu atomic fetch is ok!~\n", globalTid, (uint64_t)ret);
6061
}
61-
62-
// __syncthreads();
6362
} else {
6463
while (AtomicLoadRelaxed(reinterpret_cast<T*>(memObj->localPtr) + 2) !=
6564
gridDim.x * blockDim.x + 1) {
6665
}
6766
if (globalTid == 0) {
68-
printf("atomic fetch is ok!~\n");
67+
printf("Legacy API: atomic fetch is ok!~\n");
68+
}
69+
}
70+
}
71+
72+
// New API: Using pure addresses
73+
template <typename T>
74+
__global__ void AtomicFetchThreadKernel_PureAddr(int myPe, T* localBuff) {
75+
constexpr int sendPe = 0;
76+
constexpr int recvPe = 1;
77+
78+
int globalTid = blockIdx.x * blockDim.x + threadIdx.x;
79+
80+
if (myPe == sendPe) {
81+
T* dest = localBuff + 2;
82+
T ret = ShmemAtomicTypeFetchThread<T>(dest, 1, 0, AMO_FETCH_ADD, recvPe);
83+
__threadfence_system();
84+
if (ret == gridDim.x * blockDim.x) {
85+
printf("Pure Address API: globalTid: %d ret = %lu atomic fetch is ok!~\n", globalTid, (uint64_t)ret);
86+
}
87+
} else {
88+
while (AtomicLoadRelaxed(localBuff + 2) != gridDim.x * blockDim.x + 1) {
89+
}
90+
if (globalTid == 0) {
91+
printf("Pure Address API: atomic fetch is ok!~\n");
6992
}
7093
}
7194
}
@@ -89,76 +112,175 @@ void testAtomicFetchThread() {
89112
int numEle = threadNum * blockNum;
90113
int buffSize = numEle * sizeof(uint64_t);
91114

115+
if (myPe == 0) {
116+
printf("=================================================================\n");
117+
printf("Testing both Legacy and Pure Address APIs (Atomic Fetch)\n");
118+
printf("=================================================================\n");
119+
}
120+
92121
void* buff = ShmemMalloc(buffSize);
93-
myHipMemsetD64(buff, myPe, numEle);
94-
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
95-
printf("before rank[%d] %lu %lu\n", myPe, *(reinterpret_cast<uint64_t*>(buff)),
96-
*(reinterpret_cast<uint64_t*>(buff) + numEle - 1));
97122
SymmMemObjPtr buffObj = ShmemQueryMemObjPtr(buff);
98123
assert(buffObj.IsValid());
99124

100-
for (int iteration = 0; iteration < 10; iteration++) {
125+
// Run atomic fetch operations for different types
126+
for (int iteration = 0; iteration < 3; iteration++) {
101127
if (myPe == 0) {
102-
printf("========== Iteration %d ==========\n", iteration + 1);
128+
printf("\n========== Iteration %d ==========\n", iteration + 1);
103129
}
104130

105-
// Run uint64 atomic nonfetch
131+
// ===== Test 1: Legacy API with uint64_t =====
132+
if (myPe == 0) {
133+
printf("\n--- Test 1: Legacy API (uint64_t) ---\n");
134+
}
106135
myHipMemsetD64(buff, myPe, numEle);
107136
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
108-
printf("before rank[%d] uint64: %lu %lu\n", myPe, *(reinterpret_cast<uint64_t*>(buff)),
109-
*(reinterpret_cast<uint64_t*>(buff)));
137+
MPI_Barrier(MPI_COMM_WORLD);
138+
110139
AtomicFetchThreadKernel<uint64_t><<<blockNum, threadNum>>>(myPe, buffObj);
111140
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
112141
MPI_Barrier(MPI_COMM_WORLD);
113-
printf("after rank[%d] uint64: %lu %lu\n", myPe, *(reinterpret_cast<uint64_t*>(buff)),
114-
*(reinterpret_cast<uint64_t*>(buff) + 2));
142+
143+
if (myPe == 0) {
144+
uint64_t result = *(reinterpret_cast<uint64_t*>(buff) + 2);
145+
printf("✓ Legacy API uint64_t test completed. Result at index 2: %lu\n", result);
146+
}
115147

116-
// Test int64_t atomic nonfetch
117-
buffSize = numEle * sizeof(int64_t);
148+
// ===== Test 2: Pure Address API with uint64_t =====
149+
if (myPe == 0) {
150+
printf("\n--- Test 2: Pure Address API (uint64_t) ---\n");
151+
}
118152
myHipMemsetD64(buff, myPe, numEle);
119153
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
120-
printf("before rank[%d] int64: %ld %ld\n", myPe, *(reinterpret_cast<int64_t*>(buff)),
121-
*(reinterpret_cast<int64_t*>(buff)));
122-
// Run int64 atomic nonfetch
154+
MPI_Barrier(MPI_COMM_WORLD);
155+
156+
AtomicFetchThreadKernel_PureAddr<uint64_t><<<blockNum, threadNum>>>(
157+
myPe, reinterpret_cast<uint64_t*>(buff));
158+
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
159+
MPI_Barrier(MPI_COMM_WORLD);
160+
161+
if (myPe == 0) {
162+
uint64_t result = *(reinterpret_cast<uint64_t*>(buff) + 2);
163+
printf("✓ Pure Address API uint64_t test completed. Result at index 2: %lu\n", result);
164+
}
165+
166+
// ===== Test 3: Legacy API with int64_t =====
167+
if (myPe == 0) {
168+
printf("\n--- Test 3: Legacy API (int64_t) ---\n");
169+
}
170+
myHipMemsetD64(buff, myPe, numEle);
171+
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
172+
MPI_Barrier(MPI_COMM_WORLD);
173+
123174
AtomicFetchThreadKernel<int64_t><<<blockNum, threadNum>>>(myPe, buffObj);
124175
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
125176
MPI_Barrier(MPI_COMM_WORLD);
126-
printf("after rank[%d] int64: %ld %ld\n", myPe, *(reinterpret_cast<int64_t*>(buff)),
127-
*(reinterpret_cast<int64_t*>(buff) + 2));
177+
178+
if (myPe == 0) {
179+
int64_t result = *(reinterpret_cast<int64_t*>(buff) + 2);
180+
printf("✓ Legacy API int64_t test completed. Result at index 2: %ld\n", result);
181+
}
128182

129-
// Test uint32_t atomic nonfetch
130-
buffSize = numEle * sizeof(uint32_t);
183+
// ===== Test 4: Pure Address API with int64_t =====
184+
if (myPe == 0) {
185+
printf("\n--- Test 4: Pure Address API (int64_t) ---\n");
186+
}
187+
myHipMemsetD64(buff, myPe, numEle);
188+
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
189+
MPI_Barrier(MPI_COMM_WORLD);
190+
191+
AtomicFetchThreadKernel_PureAddr<int64_t><<<blockNum, threadNum>>>(
192+
myPe, reinterpret_cast<int64_t*>(buff));
193+
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
194+
MPI_Barrier(MPI_COMM_WORLD);
195+
196+
if (myPe == 0) {
197+
int64_t result = *(reinterpret_cast<int64_t*>(buff) + 2);
198+
printf("✓ Pure Address API int64_t test completed. Result at index 2: %ld\n", result);
199+
}
200+
201+
// ===== Test 5: Legacy API with uint32_t =====
202+
if (myPe == 0) {
203+
printf("\n--- Test 5: Legacy API (uint32_t) ---\n");
204+
}
131205
HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<uint32_t*>(buff), myPe, numEle));
132206
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
133-
printf("before rank[%d] uint32: %u %u\n", myPe, *(reinterpret_cast<uint32_t*>(buff)),
134-
*(reinterpret_cast<uint32_t*>(buff)));
135-
// Run uint32 atomic nonfetch
207+
MPI_Barrier(MPI_COMM_WORLD);
208+
136209
AtomicFetchThreadKernel<uint32_t><<<blockNum, threadNum>>>(myPe, buffObj);
137210
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
138211
MPI_Barrier(MPI_COMM_WORLD);
139-
printf("after rank[%d] uint32: %u %u\n", myPe, *(reinterpret_cast<uint32_t*>(buff)),
140-
*(reinterpret_cast<uint32_t*>(buff) + 2));
212+
213+
if (myPe == 0) {
214+
uint32_t result = *(reinterpret_cast<uint32_t*>(buff) + 2);
215+
printf("✓ Legacy API uint32_t test completed. Result at index 2: %u\n", result);
216+
}
141217

142-
// Test int32_t atomic nonfetch
143-
buffSize = numEle * sizeof(int32_t);
218+
// ===== Test 6: Pure Address API with uint32_t =====
219+
if (myPe == 0) {
220+
printf("\n--- Test 6: Pure Address API (uint32_t) ---\n");
221+
}
222+
HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<uint32_t*>(buff), myPe, numEle));
223+
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
224+
MPI_Barrier(MPI_COMM_WORLD);
225+
226+
AtomicFetchThreadKernel_PureAddr<uint32_t><<<blockNum, threadNum>>>(
227+
myPe, reinterpret_cast<uint32_t*>(buff));
228+
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
229+
MPI_Barrier(MPI_COMM_WORLD);
230+
231+
if (myPe == 0) {
232+
uint32_t result = *(reinterpret_cast<uint32_t*>(buff) + 2);
233+
printf("✓ Pure Address API uint32_t test completed. Result at index 2: %u\n", result);
234+
}
235+
236+
// ===== Test 7: Legacy API with int32_t =====
237+
if (myPe == 0) {
238+
printf("\n--- Test 7: Legacy API (int32_t) ---\n");
239+
}
144240
HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<int32_t*>(buff), myPe, numEle));
145241
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
146-
printf("before rank[%d] int32: %d %d\n", myPe, *(reinterpret_cast<int32_t*>(buff)),
147-
*(reinterpret_cast<int32_t*>(buff)));
148-
// Run int32 atomic nonfetch
242+
MPI_Barrier(MPI_COMM_WORLD);
243+
149244
AtomicFetchThreadKernel<int32_t><<<blockNum, threadNum>>>(myPe, buffObj);
150245
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
151246
MPI_Barrier(MPI_COMM_WORLD);
152-
printf("after rank[%d] int32: %d %d\n", myPe, *(reinterpret_cast<int32_t*>(buff)),
153-
*(reinterpret_cast<int32_t*>(buff) + 2));
247+
248+
if (myPe == 0) {
249+
int32_t result = *(reinterpret_cast<int32_t*>(buff) + 2);
250+
printf("✓ Legacy API int32_t test completed. Result at index 2: %d\n", result);
251+
}
252+
253+
// ===== Test 8: Pure Address API with int32_t =====
254+
if (myPe == 0) {
255+
printf("\n--- Test 8: Pure Address API (int32_t) ---\n");
256+
}
257+
HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<int32_t*>(buff), myPe, numEle));
258+
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
259+
MPI_Barrier(MPI_COMM_WORLD);
154260

155-
MPI_Barrier(MPI_COMM_WORLD); // Ensure all processes complete this iteration before next
261+
AtomicFetchThreadKernel_PureAddr<int32_t><<<blockNum, threadNum>>>(
262+
myPe, reinterpret_cast<int32_t*>(buff));
263+
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
264+
MPI_Barrier(MPI_COMM_WORLD);
265+
156266
if (myPe == 0) {
157-
printf("Iteration %d completed\n", iteration + 1);
267+
int32_t result = *(reinterpret_cast<int32_t*>(buff) + 2);
268+
printf("✓ Pure Address API int32_t test completed. Result at index 2: %d\n", result);
269+
}
270+
271+
MPI_Barrier(MPI_COMM_WORLD);
272+
if (myPe == 0) {
273+
printf("\nIteration %d completed successfully!\n", iteration + 1);
158274
}
159275
sleep(1);
160276
}
161277

278+
if (myPe == 0) {
279+
printf("\n=================================================================\n");
280+
printf("All Atomic Fetch tests completed!\n");
281+
printf("=================================================================\n");
282+
}
283+
162284
// Finalize
163285
ShmemFree(buff);
164286
ShmemFinalize();

0 commit comments

Comments
 (0)