Skip to content

Commit 4623c67

Browse files
authored
Fix OOB (#454)
* Fix OOB
* Add more assertions
* More checks on channels
1 parent 987f8f0 commit 4623c67

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

csrc/deep_ep.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,13 @@ Buffer::Buffer(int rank,
3535
int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);
3636

3737
// Common checks
38+
EP_STATIC_ASSERT(NUM_BUFFER_ALIGNMENT_BYTES % sizeof(int4) == 0, "Invalid alignment");
3839
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
3940
(num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));
4041
EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
4142
(low_latency_mode or num_rdma_bytes <= std::numeric_limits<int>::max()));
43+
EP_HOST_ASSERT(num_nvl_bytes / sizeof(int4) < std::numeric_limits<int>::max());
44+
EP_HOST_ASSERT(num_rdma_bytes / sizeof(int4) < std::numeric_limits<int>::max());
4245
EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode));
4346
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
4447
if (num_rdma_bytes > 0)
@@ -57,6 +60,10 @@ Buffer::Buffer(int rank,
5760
CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
5861
num_device_sms = device_prop.multiProcessorCount;
5962

63+
// Number of per-channel bytes cannot be large
64+
EP_HOST_ASSERT(ceil_div<int64_t>(num_nvl_bytes, num_device_sms / 2) < std::numeric_limits<int>::max());
65+
EP_HOST_ASSERT(ceil_div<int64_t>(num_rdma_bytes, num_device_sms / 2) < std::numeric_limits<int>::max());
66+
6067
if (num_nvl_bytes > 0) {
6168
// Local IPC: alloc local memory and set local IPC handles
6269
CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));

csrc/kernels/buffer.cuh

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,18 @@ private:
1111
uint8_t* ptr;
1212

1313
public:
14-
int total_bytes;
14+
int64_t total_bytes;
1515

1616
__device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}
1717

1818
__device__ __forceinline__ Buffer(void*& gbl_ptr, int num_elems, int offset = 0) {
1919
total_bytes = num_elems * sizeof(dtype_t);
20-
ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
21-
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
20+
ptr = static_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
21+
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
2222
}
2323

2424
__device__ __forceinline__ Buffer advance_also(void*& gbl_ptr) {
25-
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
25+
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
2626
return *this;
2727
}
2828

@@ -35,30 +35,30 @@ template <typename dtype_t, int kNumRanks = 1>
3535
struct AsymBuffer {
3636
private:
3737
uint8_t* ptrs[kNumRanks];
38-
int num_bytes;
38+
int64_t num_bytes;
3939

4040
public:
41-
int total_bytes;
41+
int64_t total_bytes;
4242

4343
__device__ __forceinline__ AsymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) {
4444
EP_STATIC_ASSERT(kNumRanks == 1, "");
4545
num_bytes = num_elems * sizeof(dtype_t);
4646

47-
int per_channel_bytes = num_bytes * num_ranks;
47+
int64_t per_channel_bytes = num_bytes * num_ranks;
4848
total_bytes = per_channel_bytes * num_sms;
49-
ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
50-
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
49+
ptrs[0] = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
50+
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
5151
}
5252

5353
__device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) {
5454
EP_STATIC_ASSERT(kNumRanks > 1, "");
5555
num_bytes = num_elems * sizeof(dtype_t);
5656

57-
int per_channel_bytes = num_bytes * num_ranks;
57+
int64_t per_channel_bytes = num_bytes * num_ranks;
5858
total_bytes = per_channel_bytes * num_sms;
5959
for (int i = 0; i < kNumRanks; ++i) {
60-
ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
61-
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
60+
ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
61+
gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
6262
}
6363
}
6464

@@ -69,14 +69,14 @@ public:
6969
}
7070

7171
__device__ __forceinline__ AsymBuffer advance_also(void*& gbl_ptr) {
72-
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
72+
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
7373
return *this;
7474
}
7575

7676
template <int kNumAlsoRanks>
7777
__device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
7878
for (int i = 0; i < kNumAlsoRanks; ++i)
79-
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
79+
gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
8080
return *this;
8181
}
8282

@@ -97,19 +97,19 @@ private:
9797
// NOTES: for non-decoupled case, `recv_ptr` is not used
9898
uint8_t* send_ptr;
9999
uint8_t* recv_ptr;
100-
int num_bytes;
100+
int64_t num_bytes;
101101

102102
public:
103-
int total_bytes;
103+
int64_t total_bytes;
104104

105105
__device__ __forceinline__ SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) {
106106
num_bytes = num_elems * sizeof(dtype_t);
107107

108-
int per_channel_bytes = num_bytes * num_ranks;
108+
int64_t per_channel_bytes = num_bytes * num_ranks;
109109
total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
110-
send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
111-
recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
112-
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
110+
send_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
111+
recv_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
112+
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
113113
}
114114

115115
__device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {

0 commit comments

Comments
 (0)