Commit f721096

[BugFix] Some fixes for custom allreduce kernels (#2760)
1 parent e90fc21 commit f721096

6 files changed, +232 -250 lines changed

csrc/custom_all_reduce.cu

Lines changed: 5 additions & 5 deletions
@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
     std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
   }
   return (fptr_t) new vllm::CustomAllreduce(
-      reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
+      reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
       rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
 }

@@ -62,9 +62,9 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
   if (inp_size % 16 != 0) return false;
   if (!_is_weak_contiguous(inp)) return false;
   if (world_size == 2 || full_nvlink) return inp_size <= max_size;
-  // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
-  // <= 512k
-  return world_size <= 4 && inp_size <= 512 * 1024;
+  // for 4 or more non NVLink-capable GPUs, custom allreduce provides little
+  // performance improvement over NCCL.
+  return false;
 }
 
 void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
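Note on the hunk above: after this change, custom allreduce is only attempted for 2-GPU setups or fully NVLink-connected topologies; four or more PCIe-only GPUs now always fall back to NCCL. Below is a minimal host-side sketch of the size/topology part of the new check. It is an illustration only: the function name is local to the sketch, it takes a byte count instead of a torch::Tensor, and the weak-contiguity check is omitted.

#include <cstdint>
#include <cstdio>

// Sketch of the new should_custom_ar decision, taking a byte count instead
// of a torch::Tensor (contiguity checking omitted).
bool should_use_custom_ar(int64_t inp_size, int64_t max_size, int world_size,
                          bool full_nvlink) {
  if (inp_size % 16 != 0) return false;  // custom kernels need 16-byte multiples
  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
  // 4 or more GPUs without full NVLink: always use NCCL now (previously
  // allowed up to 512 KiB for world_size <= 4).
  return false;
}

int main() {
  // 2 GPUs over PCIe, 256 KiB input: still eligible (prints 1).
  std::printf("%d\n", should_use_custom_ar(256 * 1024, 8 * 1024 * 1024, 2, false));
  // 4 GPUs over PCIe, 256 KiB input: was eligible before, now falls back (prints 0).
  std::printf("%d\n", should_use_custom_ar(256 * 1024, 8 * 1024 * 1024, 4, false));
  return 0;
}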
@@ -126,7 +126,7 @@ void dispose(fptr_t _fa) {
   delete fa;
 }
 
-int meta_size() { return sizeof(vllm::Metadata); }
+int meta_size() { return sizeof(vllm::Signal); }
 
 void register_buffer(fptr_t _fa, torch::Tensor &t,
                      const std::vector<std::string> &handles,

csrc/custom_all_reduce.cuh

Lines changed: 75 additions & 152 deletions
@@ -23,29 +23,17 @@
 
 namespace vllm {
 
+constexpr int kMaxBlocks = 64;
+// note: we don't want to use atomics for signals because peer atomics are not
+// supported on PCIe links
 struct Signal {
-  alignas(64) union {
-    uint64_t flag;
-    unsigned char data[8];
-  } start;
-  alignas(64) union {
-    uint64_t flag;
-    unsigned char data[8];
-  } end;
+  alignas(128) uint32_t start[kMaxBlocks][8];
+  alignas(128) uint32_t end[kMaxBlocks][8];
 };
 
-struct Metadata {
-  alignas(128) Signal sg;
-  alignas(128) int counter;
-};
-static_assert(offsetof(Metadata, counter) == 128);
-static_assert(sizeof(Metadata) == 256);
-
 struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
 
-struct RankSignals {
-  volatile Signal *signals[8];
-};
+struct __align__(16) RankSignals { volatile Signal *signals[8]; };
 
 // like std::array, but aligned
 template <typename T, int sz>
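Note on the new Signal struct: each of up to kMaxBlocks = 64 thread blocks now gets its own row of 8 per-rank flags for the start and end barriers, replacing the old packed 64-bit flag pair plus block counter. A standalone size check, assuming only the layout shown above (SignalMirror is a local copy for this sketch, not a vllm type):

#include <cstdint>

// Local mirror of the new vllm::Signal layout, used only to check sizes here.
constexpr int kMaxBlocks = 64;
struct SignalMirror {
  alignas(128) uint32_t start[kMaxBlocks][8];
  alignas(128) uint32_t end[kMaxBlocks][8];
};

// Each flag table is 64 * 8 * 4 = 2048 bytes, so the whole signal prefix is
// 4096 bytes; meta_size() in custom_all_reduce.cu now reports this value, and
// get_tmp_buf() skips exactly this much before the temporary buffer.
static_assert(sizeof(SignalMirror) == 2 * kMaxBlocks * 8 * sizeof(uint32_t),
              "signal block should be 4 KiB");
static_assert(alignof(SignalMirror) == 128, "signal block is 128B aligned");

int main() { return 0; }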
@@ -135,70 +123,49 @@ DINLINE O downcast(array_t<float, O::size> val) {
   }
 }
 
-// compute flag at compile time
-__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
-  auto m = std::numeric_limits<uint64_t>::max();
-  return m >> ((8 - ngpus) * 8);
-}
-
+// This function is meant to be used as the first synchronization in the all
+// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
+// prior memory accesses. Note: volatile writes will not be reordered against
+// other volatile writes.
 template <int ngpus>
-DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
+DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
                         int rank) {
-  constexpr auto FLAG = compute_flag(ngpus);
-  if (blockIdx.x == 0) {
-    if (threadIdx.x < ngpus)
-      // simultaneously write to the corresponding byte to all other ranks.
-      // Latency = 1 p2p write
-      sg.signals[threadIdx.x]->start.data[rank] = 255;
-    else if (threadIdx.x == 32)
-      // reset
-      meta->sg.end.flag = 0;
-  }
-  if (threadIdx.x == 0) {
-    while (meta->sg.start.flag != FLAG)
+  if (threadIdx.x < ngpus) {
+    // reset flag for next time
+    self_sg->end[blockIdx.x][threadIdx.x] = 0;
+    // simultaneously write to the corresponding flag of all ranks.
+    // Latency = 1 p2p write
+    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
+    // wait until we got true from all ranks
+    while (!self_sg->start[blockIdx.x][threadIdx.x])
       ;
   }
   __syncthreads();
 }
 
+// This function is meant to be used as the second or the final synchronization
+// barrier in the all reduce kernel. If it's the final synchronization barrier,
+// we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
-DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
+DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
                       int rank) {
-  constexpr auto FLAG = compute_flag(ngpus);
   __syncthreads();
-  __shared__ int num;
-  if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
-  __syncthreads();
-
-  // Only the last completing block can perform the end synchronization
-  // This can ensures when the final busy wait ends, all ranks must have
-  // finished reading each other's buffer.
-  if (num == gridDim.x - 1) {
-    if (threadIdx.x == 32) {
-      // reset in a different warp
-      meta->counter = 0;
-      meta->sg.start.flag = 0;
-    } else if (threadIdx.x < ngpus) {
-      // simultaneously write to the corresponding byte to all other ranks.
-      // Latency = 1 p2p write
-      sg.signals[threadIdx.x]->end.data[rank] = 255;
-    }
-    // if this is the final sync, only one block needs it
-    // because kernel exit can serve as sync
-    if constexpr (final_sync) {
-      if (threadIdx.x == 0) {
-        while (meta->sg.end.flag != FLAG)
-          ;
-      }
-    }
-  }
-  if constexpr (!final_sync) {
-    if (threadIdx.x == 0) {
-      while (meta->sg.end.flag != FLAG)
-        ;
-    }
-    __syncthreads();
+  // eliminate the case that prior writes are not visible after signals become
+  // visible. Note that I did not manage to make this happen through a lot of
+  // testing. Might be the case that hardware provides stronger guarantees than
+  // the memory model.
+  if constexpr (!final_sync) __threadfence_system();
+  if (threadIdx.x < ngpus) {
+    // reset flag for next time
+    self_sg->start[blockIdx.x][threadIdx.x] = 0;
+    // simultaneously write to the corresponding flag of all ranks.
+    // Latency = 1 p2p write
+    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
+    // wait until we got true from all ranks
+    while (!self_sg->end[blockIdx.x][threadIdx.x])
      ;
   }
+  if constexpr (!final_sync) __syncthreads();
 }
 
 template <typename P, int ngpus, typename A>
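Note on the rewritten barriers: to make the handshake easier to follow outside the kernel, here is a host-side analogue (a sketch only; std::thread and std::atomic stand in for thread blocks and volatile peer-to-peer writes, and all names below are local to the sketch). Each rank clears its own flags for the opposite phase, sets its slot in every peer's array, then spins until every peer has signalled it, mirroring start_sync and end_sync for a single block; the final_sync variant, which drops the fence and the trailing __syncthreads, is not modelled.

#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

constexpr int kNumRanks = 4;

// Host-side stand-in for one block's worth of per-rank flags.
struct HostSignal {
  std::atomic<uint32_t> start[kNumRanks];
  std::atomic<uint32_t> end[kNumRanks];
};

HostSignal signals[kNumRanks];  // signals[r] plays the role of rank r's Signal

void start_sync(int rank) {
  for (int peer = 0; peer < kNumRanks; ++peer) {
    signals[rank].end[peer].store(0);    // reset end flags for the next round
    signals[peer].start[rank].store(1);  // one "p2p write" per peer
  }
  for (int peer = 0; peer < kNumRanks; ++peer)
    while (!signals[rank].start[peer].load()) {}  // wait for every peer
}

void end_sync(int rank) {
  for (int peer = 0; peer < kNumRanks; ++peer) {
    signals[rank].start[peer].store(0);  // reset start flags for the next round
    signals[peer].end[rank].store(1);
  }
  for (int peer = 0; peer < kNumRanks; ++peer)
    while (!signals[rank].end[peer].load()) {}
}

int main() {
  std::vector<std::thread> ranks;
  for (int r = 0; r < kNumRanks; ++r)
    ranks.emplace_back([r] {
      start_sync(r);
      std::printf("rank %d passed start_sync\n", r);
      end_sync(r);
      std::printf("rank %d passed end_sync\n", r);
    });
  for (auto &t : ranks) t.join();
  return 0;
}

Because a rank only signals its peers after resetting its own flags for the opposite phase, a peer's fresh signal can never be wiped out by a late reset; that ordering is what lets the new per-block flags replace the old block counter.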
@@ -214,32 +181,32 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
-                               volatile Metadata *meta, T *__restrict__ result,
+                               volatile Signal *self_sg, T *__restrict__ result,
                                int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
   // for all ranks, ensuring bitwise identical results
   auto dp = *_dp;
-  start_sync<ngpus>(sg, meta, rank);
+  start_sync<ngpus>(sg, self_sg, rank);
   // do the actual reduction
   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
        idx += gridDim.x * blockDim.x) {
     ((P *)result)[idx] =
         packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
   }
-  end_sync<ngpus, true>(sg, meta, rank);
+  end_sync<ngpus, true>(sg, self_sg, rank);
 }
 
 template <typename P>
 DINLINE P *get_tmp_buf(volatile Signal *sg) {
-  return (P *)(((Metadata *)sg) + 1);
+  return (P *)(((Signal *)sg) + 1);
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
-                               volatile Metadata *meta, T *__restrict__ result,
+                               volatile Signal *self_sg, T *__restrict__ result,
                                int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
@@ -248,6 +215,7 @@ __global__ void __launch_bounds__(512, 1)
   int part = size / ngpus;
   int start = rank * part;
   int end = rank == ngpus - 1 ? size : start + part;
+  int largest_part = part + size % ngpus;
   const P *ptrs[ngpus];
   P *tmps[ngpus];
 #pragma unroll
@@ -257,75 +225,28 @@ __global__ void __launch_bounds__(512, 1)
     tmps[i] = get_tmp_buf<P>(sg.signals[target]);
   }
   auto tmp_out = tmps[0];
-  start_sync<ngpus>(sg, meta, rank);
+  start_sync<ngpus>(sg, self_sg, rank);
   // stage 1: reduce scatter
   for (int idx = start + tid; idx < end; idx += stride) {
     tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
   }
-  // Maybe TODO: replace this with per-block release-acquire
-  // can save about 1-2us (not a lot though)
-  end_sync<ngpus>(sg, meta, rank);
-
-  // stage 2: allgather
-  for (int idx = tid; idx < part; idx += stride) {
+  end_sync<ngpus>(sg, self_sg, rank);
+
+  // stage 2: allgather. Note: it's important to match the tid between
+  // the two stages, because visibility across devices is only guaranteed
+  // between threads that have the same tid. If thread i computes the sum of
+  // start + i in the first stage, then thread i also gathers start + i from all
+  // ranks.
+  for (int idx = tid; idx < largest_part; idx += stride) {
 #pragma unroll
     for (int i = 0; i < ngpus; i++) {
-      int dst_idx = ((rank + i) % ngpus) * part + idx;
-      ((P *)result)[dst_idx] = tmps[i][idx];
-    }
-  }
-  // process the last larger partition
-  int remaining = size - part * ngpus;
-  if (tid < remaining) {
-    int dst_idx = tid + part * ngpus;
-    ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
-  }
-
-  // faster than this
-  // for (int idx = tid; idx < size; idx += stride) {
-  //   int target_rank = idx / part;
-  //   if (target_rank == ngpus) target_rank -= 1;
-  //   ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
-  // }
-}
-
-template <typename T, int ngpus>
-__global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
-                                       volatile Metadata *meta,
-                                       T *__restrict__ result, int rank,
-                                       int size) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int stride = gridDim.x * blockDim.x;
-  using P = typename packed_t<T>::P;
-  using A = typename packed_t<T>::A;
-  auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
-  constexpr int hg = ngpus / 2;
-  // Actually not quite half butterfly.
-  // This is an all-to-all within each group containing half of the ranks
-  // followed by cross-group add. Equivalent to half butterfly when there
-  // are 4 GPUs, a common case for PCIe cards like T4 and A10.
-  const P *ptrs[hg];
-  {
-    int start = rank - rank % hg;
-#pragma unroll
-    for (int i = 0; i < hg; i++) {
-      ptrs[i] = (const P *)_dp->ptrs[i + start];
+      int gather_from_rank = ((rank + i) % ngpus);
+      if (gather_from_rank == ngpus - 1 || idx < part) {
+        int dst_idx = gather_from_rank * part + idx;
+        ((P *)result)[dst_idx] = tmps[i][idx];
+      }
     }
   }
-  start_sync<ngpus>(sg, meta, rank);
-  for (int idx = tid; idx < size; idx += stride) {
-    tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
-  }
-  end_sync<ngpus>(sg, meta, rank);
-
-  auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
-  // do the cross group reduction
-  for (int idx = tid; idx < size; idx += stride) {
-    auto tmp = tmp_out[idx];
-    packed_assign_add(tmp, src[idx]);
-    ((P *)result)[idx] = tmp;
-  }
 }
 
 using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
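Note on the allgather change in the hunk above, as a worked host-side example (a sketch only, with made-up sizes): with 10 packed elements across 4 ranks, part = 2 and the last rank also reduces the 2-element remainder, so largest_part = 4; the gather loop now runs every thread up to largest_part, and the guard skips indices that do not exist on any rank other than the last one.

#include <cstdio>

// Host-side sketch of the 2-stage allreduce partitioning after the fix.
int main() {
  const int ngpus = 4;
  const int size = 10;                           // packed elements, not divisible by ngpus
  const int part = size / ngpus;                 // 2
  const int largest_part = part + size % ngpus;  // 4: the last rank owns the tail

  for (int rank = 0; rank < ngpus; ++rank) {
    int start = rank * part;
    int end = rank == ngpus - 1 ? size : start + part;
    std::printf("rank %d reduces [%d, %d)\n", rank, start, end);
  }

  // Allgather bound check: index idx of rank gather_from_rank's partition is
  // valid only when idx < part, unless it is the larger last partition.
  for (int gather_from_rank = 0; gather_from_rank < ngpus; ++gather_from_rank)
    for (int idx = 0; idx < largest_part; ++idx)
      if (gather_from_rank == ngpus - 1 || idx < part)
        std::printf("copy tmp[%d][%d] -> result[%d]\n", gather_from_rank, idx,
                    gather_from_rank * part + idx);
  return 0;
}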
@@ -341,7 +262,7 @@ class CustomAllreduce {
   // below are device pointers
   RankSignals sg_;
   std::unordered_map<void *, RankData *> buffers_;
-  Metadata *meta_;
+  Signal *self_sg_;
 
   // stores the registered device pointers from all ranks
   RankData *d_rank_data_base_, *d_rank_data_end_;
@@ -352,32 +273,32 @@ class CustomAllreduce {
   /**
    * meta is a pointer to device metadata and temporary buffer for allreduce.
    *
-   * There's a total of sizeof(Metadata) of prefix before the actual data,
+   * There's a total of sizeof(Signal) of prefix before the actual data,
    * so meta + 1 points to actual temporary buffer.
    *
    * note: this class does not own any device memory. Any required buffers
    * are passed in from the constructor
    */
-  CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
+  CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
                   const cudaIpcMemHandle_t *handles,
                   const std::vector<int64_t> &offsets, int rank,
                   bool full_nvlink = true)
       : rank_(rank),
         world_size_(offsets.size()),
         full_nvlink_(full_nvlink),
-        meta_(meta),
+        self_sg_(meta),
         d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
         d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
     for (int i = 0; i < world_size_; i++) {
-      Metadata *rank_meta;
+      Signal *rank_sg;
       if (i != rank_) {
         char *handle = open_ipc_handle(&handles[i]);
         handle += offsets[i];
-        rank_meta = (Metadata *)handle;
+        rank_sg = (Signal *)handle;
       } else {
-        rank_meta = meta_;
+        rank_sg = self_sg_;
       }
-      sg_.signals[i] = &rank_meta->sg;
+      sg_.signals[i] = rank_sg;
     }
   }

@@ -492,6 +413,10 @@ class CustomAllreduce {
           "custom allreduce currently requires input length to be multiple "
           "of " +
           std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
 
     RankData *ptrs;
     cudaStreamCaptureStatus status;
@@ -512,9 +437,9 @@ class CustomAllreduce {
     size /= d;
     auto bytes = size * sizeof(typename packed_t<T>::P);
     int blocks = std::min(block_limit, (size + threads - 1) / threads);
-#define KL(ngpus, name) \
-  name<T, ngpus>        \
-      <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
+#define KL(ngpus, name)                                                       \
+  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
+                                                 rank_, size);
 #define REDUCE_CASE(ngpus) \
   case ngpus: { \
     if (world_size_ == 2) { \
@@ -526,8 +451,6 @@ class CustomAllreduce {
       } else { \
         KL(ngpus, cross_device_reduce_2stage); \
       } \
-    } else { \
-      KL(ngpus, cross_device_reduce_half_butterfly); \
     } \
     break; \
   }
@@ -556,7 +479,7 @@ class CustomAllreduce {
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:
- * template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
- int, int, int);
+ * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
+ half *, int, int, int);
 */
 } // namespace vllm
