7
7
8
8
#include < iostream>
9
9
#include < limits>
10
+ #include < map>
10
11
#include < unordered_map>
11
12
#include < vector>
12
13
@@ -327,6 +328,10 @@ __global__ void __launch_bounds__(512, 1)
327
328
}
328
329
}
329
330
331
// Byte-array stand-in for cudaIpcMemHandle_t, used as the key type of the
// ipc_handles_ map. The assertion checks that the stand-in has the same size
// and alignment as the real handle type.
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t) &&
              alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
334
+
330
335
class CustomAllreduce {
331
336
public:
332
337
int rank_;
@@ -341,7 +346,8 @@ class CustomAllreduce {
341
346
// stores the registered device pointers from all ranks
342
347
RankData *d_rank_data_base_, *d_rank_data_end_;
343
348
std::vector<void *> graph_unreg_buffers_;
344
- std::vector<void *> ipc_handles_;
349
+ // a map from IPC handles to opened IPC pointers
350
+ std::map<IPC_KEY, char *> ipc_handles_;
345
351
346
352
/* *
347
353
* meta is a pointer to device metadata and temporary buffer for allreduce.
@@ -365,10 +371,7 @@ class CustomAllreduce {
365
371
for (int i = 0 ; i < world_size_; i++) {
366
372
Metadata *rank_meta;
367
373
if (i != rank_) {
368
- char *handle;
369
- CUDACHECK (cudaIpcOpenMemHandle ((void **)&handle, handles[i],
370
- cudaIpcMemLazyEnablePeerAccess));
371
- ipc_handles_.push_back (handle);
374
+ char *handle = open_ipc_handle (&handles[i]);
372
375
handle += offsets[i];
373
376
rank_meta = (Metadata *)handle;
374
377
} else {
@@ -378,6 +381,19 @@ class CustomAllreduce {
378
381
}
379
382
}
380
383
384
// Returns the locally opened pointer for a CUDA IPC memory handle, opening the
// handle the first time it is seen and reusing the cached pointer afterwards.
//
// ipc_handle: points to the bytes of a cudaIpcMemHandle_t.
// Returns: the pointer obtained from cudaIpcOpenMemHandle for this handle.
char *open_ipc_handle(const void *ipc_handle) {
  // The handle's bytes are the map key; keep the pointee const rather than
  // stripping it with a C-style cast.
  const IPC_KEY &key = *static_cast<const IPC_KEY *>(ipc_handle);
  auto it = ipc_handles_.find(key);
  if (it == ipc_handles_.end()) {
    // Open first and insert afterwards, so a failed open cannot leave a
    // nullptr placeholder in the map (the destructor closes every entry).
    char *ipc_ptr = nullptr;
    CUDACHECK(cudaIpcOpenMemHandle(
        reinterpret_cast<void **>(&ipc_ptr),
        *static_cast<const cudaIpcMemHandle_t *>(ipc_handle),
        cudaIpcMemLazyEnablePeerAccess));
    it = ipc_handles_.emplace(key, ipc_ptr).first;
  }
  return it->second;
}
396
+
381
397
std::pair<std::vector<uint8_t >, std::vector<int64_t >>
382
398
get_graph_buffer_ipc_meta () {
383
399
auto num_buffers = graph_unreg_buffers_.size ();
@@ -413,11 +429,7 @@ class CustomAllreduce {
413
429
RankData data;
414
430
for (int i = 0 ; i < world_size_; i++) {
415
431
if (i != rank_) {
416
- char *handle;
417
- CUDACHECK (cudaIpcOpenMemHandle (
418
- (void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data ()),
419
- cudaIpcMemLazyEnablePeerAccess));
420
- ipc_handles_.push_back (handle);
432
+ char *handle = open_ipc_handle (handles[i].data ());
421
433
handle += offsets[i];
422
434
data.ptrs [i] = handle;
423
435
} else {
@@ -448,13 +460,8 @@ class CustomAllreduce {
448
460
auto &rd = rank_data[i];
449
461
for (int j = 0 ; j < world_size_; j++) {
450
462
if (j != rank_) {
451
- char *handle;
452
- CUDACHECK (cudaIpcOpenMemHandle (
453
- (void **)&handle,
454
- *((cudaIpcMemHandle_t *)&handles[j]
455
- [i * sizeof (cudaIpcMemHandle_t)]),
456
- cudaIpcMemLazyEnablePeerAccess));
457
- ipc_handles_.push_back (handle);
463
+ char *handle =
464
+ open_ipc_handle (&handles[j][i * sizeof (cudaIpcMemHandle_t)]);
458
465
handle += offsets[j][i];
459
466
rd.ptrs [j] = handle;
460
467
} else {
@@ -541,7 +548,7 @@ class CustomAllreduce {
541
548
}
542
549
543
550
// Closes every IPC pointer recorded in ipc_handles_.
~CustomAllreduce() {
  for (auto it = ipc_handles_.begin(); it != ipc_handles_.end(); ++it) {
    CUDACHECK(cudaIpcCloseMemHandle(it->second));
  }
}
0 commit comments