Skip to content

Commit 1b20639

Browse files
authored
No repeated IPC open (#2642)
1 parent b72af8f commit 1b20639

File tree

1 file changed

+25
-18
lines changed

1 file changed

+25
-18
lines changed

csrc/custom_all_reduce.cuh

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <iostream>
99
#include <limits>
10+
#include <map>
1011
#include <unordered_map>
1112
#include <vector>
1213

@@ -327,6 +328,10 @@ __global__ void __launch_bounds__(512, 1)
327328
}
328329
}
329330

331+
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
332+
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
333+
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
334+
330335
class CustomAllreduce {
331336
public:
332337
int rank_;
@@ -341,7 +346,8 @@ class CustomAllreduce {
341346
// stores the registered device pointers from all ranks
342347
RankData *d_rank_data_base_, *d_rank_data_end_;
343348
std::vector<void *> graph_unreg_buffers_;
344-
std::vector<void *> ipc_handles_;
349+
// a map from IPC handles to opened IPC pointers
350+
std::map<IPC_KEY, char *> ipc_handles_;
345351

346352
/**
347353
* meta is a pointer to device metadata and temporary buffer for allreduce.
@@ -365,10 +371,7 @@ class CustomAllreduce {
365371
for (int i = 0; i < world_size_; i++) {
366372
Metadata *rank_meta;
367373
if (i != rank_) {
368-
char *handle;
369-
CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i],
370-
cudaIpcMemLazyEnablePeerAccess));
371-
ipc_handles_.push_back(handle);
374+
char *handle = open_ipc_handle(&handles[i]);
372375
handle += offsets[i];
373376
rank_meta = (Metadata *)handle;
374377
} else {
@@ -378,6 +381,19 @@ class CustomAllreduce {
378381
}
379382
}
380383

384+
char *open_ipc_handle(const void *ipc_handle) {
385+
auto [it, new_handle] =
386+
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
387+
if (new_handle) {
388+
char *ipc_ptr;
389+
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
390+
*((const cudaIpcMemHandle_t *)ipc_handle),
391+
cudaIpcMemLazyEnablePeerAccess));
392+
it->second = ipc_ptr;
393+
}
394+
return it->second;
395+
}
396+
381397
std::pair<std::vector<uint8_t>, std::vector<int64_t>>
382398
get_graph_buffer_ipc_meta() {
383399
auto num_buffers = graph_unreg_buffers_.size();
@@ -413,11 +429,7 @@ class CustomAllreduce {
413429
RankData data;
414430
for (int i = 0; i < world_size_; i++) {
415431
if (i != rank_) {
416-
char *handle;
417-
CUDACHECK(cudaIpcOpenMemHandle(
418-
(void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data()),
419-
cudaIpcMemLazyEnablePeerAccess));
420-
ipc_handles_.push_back(handle);
432+
char *handle = open_ipc_handle(handles[i].data());
421433
handle += offsets[i];
422434
data.ptrs[i] = handle;
423435
} else {
@@ -448,13 +460,8 @@ class CustomAllreduce {
448460
auto &rd = rank_data[i];
449461
for (int j = 0; j < world_size_; j++) {
450462
if (j != rank_) {
451-
char *handle;
452-
CUDACHECK(cudaIpcOpenMemHandle(
453-
(void **)&handle,
454-
*((cudaIpcMemHandle_t *)&handles[j]
455-
[i * sizeof(cudaIpcMemHandle_t)]),
456-
cudaIpcMemLazyEnablePeerAccess));
457-
ipc_handles_.push_back(handle);
463+
char *handle =
464+
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
458465
handle += offsets[j][i];
459466
rd.ptrs[j] = handle;
460467
} else {
@@ -541,7 +548,7 @@ class CustomAllreduce {
541548
}
542549

543550
~CustomAllreduce() {
544-
for (auto ptr : ipc_handles_) {
551+
for (auto [_, ptr] : ipc_handles_) {
545552
CUDACHECK(cudaIpcCloseMemHandle(ptr));
546553
}
547554
}

0 commit comments

Comments
 (0)