 
 namespace vllm {
 
+constexpr int kMaxBlocks = 64;
+// note: we don't want to use atomics for signals because peer atomics are not
+// supported on PCIe links
 struct Signal {
-  alignas(64) union {
-    uint64_t flag;
-    unsigned char data[8];
-  } start;
-  alignas(64) union {
-    uint64_t flag;
-    unsigned char data[8];
-  } end;
+  alignas(128) uint32_t start[kMaxBlocks][8];
+  alignas(128) uint32_t end[kMaxBlocks][8];
 };
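// Size sketch for the new layout (my arithmetic; assumes the two alignas(128)
// arrays introduce no padding beyond their alignment): 2 flag arrays *
// kMaxBlocks (64) * 8 ranks * sizeof(uint32_t) = 4096 bytes per Signal, which
// a check like the following would confirm:
static_assert(sizeof(Signal) == 4096, "one start/end flag per block per rank");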
 
-struct Metadata {
-  alignas(128) Signal sg;
-  alignas(128) int counter;
-};
-static_assert(offsetof(Metadata, counter) == 128);
-static_assert(sizeof(Metadata) == 256);
-
 struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
 
-struct RankSignals {
-  volatile Signal *signals[8];
-};
+struct __align__(16) RankSignals { volatile Signal *signals[8]; };
 
 // like std::array, but aligned
 template <typename T, int sz>
@@ -135,70 +123,49 @@ DINLINE O downcast(array_t<float, O::size> val) {
   }
 }
 
-// compute flag at compile time
-__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
-  auto m = std::numeric_limits<uint64_t>::max();
-  return m >> ((8 - ngpus) * 8);
-}
-
+// This function is meant to be used as the first synchronization in the all
+// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
+// prior memory accesses. Note: volatile writes will not be reordered against
+// other volatile writes.
 template <int ngpus>
-DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
+DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
                         int rank) {
-  constexpr auto FLAG = compute_flag(ngpus);
-  if (blockIdx.x == 0) {
-    if (threadIdx.x < ngpus)
-      // simultaneously write to the corresponding byte to all other ranks.
-      // Latency = 1 p2p write
-      sg.signals[threadIdx.x]->start.data[rank] = 255;
-    else if (threadIdx.x == 32)
-      // reset
-      meta->sg.end.flag = 0;
-  }
-  if (threadIdx.x == 0) {
-    while (meta->sg.start.flag != FLAG)
+  if (threadIdx.x < ngpus) {
+    // reset flag for next time
+    self_sg->end[blockIdx.x][threadIdx.x] = 0;
+    // simultaneously write to the corresponding flag of all ranks.
+    // Latency = 1 p2p write
+    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
+    // wait until we got true from all ranks
+    while (!self_sg->start[blockIdx.x][threadIdx.x])
       ;
   }
   __syncthreads();
 }
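// Illustration of the per-block handshake above (numbers are mine, chosen for
// the example): with ngpus = 4, rank = 2 and blockIdx.x = 7, thread t (t < 4)
// clears self_sg->end[7][t], writes sg.signals[t]->start[7][2] = 1 into rank
// t's signal, then spins on self_sg->start[7][t], which rank t sets the same
// way; block 7 proceeds only once all four ranks have reached its barrier.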
 
+// This function is meant to be used as the second or the final synchronization
+// barrier in the all reduce kernel. If it's the final synchronization barrier,
+// we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
-DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
+DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
                       int rank) {
-  constexpr auto FLAG = compute_flag(ngpus);
   __syncthreads();
-  __shared__ int num;
-  if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
-  __syncthreads();
-
-  // Only the last completing block can perform the end synchronization
-  // This can ensures when the final busy wait ends, all ranks must have
-  // finished reading each other's buffer.
-  if (num == gridDim.x - 1) {
-    if (threadIdx.x == 32) {
-      // reset in a different warp
-      meta->counter = 0;
-      meta->sg.start.flag = 0;
-    } else if (threadIdx.x < ngpus) {
-      // simultaneously write to the corresponding byte to all other ranks.
-      // Latency = 1 p2p write
-      sg.signals[threadIdx.x]->end.data[rank] = 255;
-    }
-    // if this is the final sync, only one block needs it
-    // because kernel exit can serve as sync
-    if constexpr (final_sync) {
-      if (threadIdx.x == 0) {
-        while (meta->sg.end.flag != FLAG)
-          ;
-      }
-    }
-  }
-  if constexpr (!final_sync) {
-    if (threadIdx.x == 0) {
-      while (meta->sg.end.flag != FLAG)
-        ;
-    }
-    __syncthreads();
+  // eliminate the case that prior writes are not visible after signals become
+  // visible. Note that I did not manage to make this happen through a lot of
+  // testing. Might be the case that hardware provides a stronger guarantee than
+  // the memory model.
+  if constexpr (!final_sync) __threadfence_system();
+  if (threadIdx.x < ngpus) {
+    // reset flag for next time
+    self_sg->start[blockIdx.x][threadIdx.x] = 0;
+    // simultaneously write to the corresponding flag of all ranks.
+    // Latency = 1 p2p write
+    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
+    // wait until we got true from all ranks
+    while (!self_sg->end[blockIdx.x][threadIdx.x])
      ;
   }
+  if constexpr (!final_sync) __syncthreads();
 }
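// Intended call pattern inside a reduction kernel (a sketch only; the kernels
// below follow this shape):
//   start_sync<ngpus>(sg, self_sg, rank);      // peers' inputs now readable
//   ... read peer buffers / write partial results ...
//   end_sync<ngpus, /*final_sync=*/true>(sg, self_sg, rank);
// The non-final variant additionally issues __threadfence_system() so that
// writes made before the barrier are visible to peers that read them after it.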
 
 template <typename P, int ngpus, typename A>
@@ -214,32 +181,32 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
-                               volatile Metadata *meta, T *__restrict__ result,
+                               volatile Signal *self_sg, T *__restrict__ result,
                                int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
   // for all ranks, ensuring bitwise identical results
   auto dp = *_dp;
-  start_sync<ngpus>(sg, meta, rank);
+  start_sync<ngpus>(sg, self_sg, rank);
   // do the actual reduction
   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
        idx += gridDim.x * blockDim.x) {
     ((P *)result)[idx] =
         packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
   }
-  end_sync<ngpus, true>(sg, meta, rank);
+  end_sync<ngpus, true>(sg, self_sg, rank);
 }
 
 template <typename P>
 DINLINE P *get_tmp_buf(volatile Signal *sg) {
-  return (P *)(((Metadata *)sg) + 1);
+  return (P *)(((Signal *)sg) + 1);
 }
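// Layout of the shared region handed to CustomAllreduce below (a sketch; the
// allocation call and size name are illustrative, not part of this file):
//   | Signal flags | temporary buffer used by cross_device_reduce_2stage ... |
//   ^ meta          ^ meta + 1 == get_tmp_buf(meta)
//   e.g. cudaMalloc((void **)&meta, sizeof(Signal) + tmp_buffer_bytes);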
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
-                               volatile Metadata *meta, T *__restrict__ result,
+                               volatile Signal *self_sg, T *__restrict__ result,
                                int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
@@ -248,6 +215,7 @@ __global__ void __launch_bounds__(512, 1)
   int part = size / ngpus;
   int start = rank * part;
   int end = rank == ngpus - 1 ? size : start + part;
+  int largest_part = part + size % ngpus;
   const P *ptrs[ngpus];
   P *tmps[ngpus];
 #pragma unroll
@@ -257,75 +225,28 @@ __global__ void __launch_bounds__(512, 1)
     tmps[i] = get_tmp_buf<P>(sg.signals[target]);
   }
   auto tmp_out = tmps[0];
-  start_sync<ngpus>(sg, meta, rank);
+  start_sync<ngpus>(sg, self_sg, rank);
   // stage 1: reduce scatter
   for (int idx = start + tid; idx < end; idx += stride) {
     tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
   }
-  // Maybe TODO: replace this with per-block release-acquire
-  // can save about 1-2us (not a lot though)
-  end_sync<ngpus>(sg, meta, rank);
-
-  // stage 2: allgather
-  for (int idx = tid; idx < part; idx += stride) {
+  end_sync<ngpus>(sg, self_sg, rank);
+
+  // stage 2: allgather. Note: it's important to match the tid between
+  // the two stages, because visibility across devices is only guaranteed
+  // between threads that have the same tid. If thread i computes the sum of
+  // start + i in the first stage, then thread i also gathers start + i from all
+  // ranks.
+  for (int idx = tid; idx < largest_part; idx += stride) {
 #pragma unroll
     for (int i = 0; i < ngpus; i++) {
-      int dst_idx = ((rank + i) % ngpus) * part + idx;
-      ((P *)result)[dst_idx] = tmps[i][idx];
-    }
-  }
-  // process the last larger partition
-  int remaining = size - part * ngpus;
-  if (tid < remaining) {
-    int dst_idx = tid + part * ngpus;
-    ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
-  }
-
-  // faster than this
-  // for (int idx = tid; idx < size; idx += stride) {
-  //   int target_rank = idx / part;
-  //   if (target_rank == ngpus) target_rank -= 1;
-  //   ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
-  // }
-}
-
-template <typename T, int ngpus>
-__global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
-                                       volatile Metadata *meta,
-                                       T *__restrict__ result, int rank,
-                                       int size) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int stride = gridDim.x * blockDim.x;
-  using P = typename packed_t<T>::P;
-  using A = typename packed_t<T>::A;
-  auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
-  constexpr int hg = ngpus / 2;
-  // Actually not quite half butterfly.
-  // This is an all-to-all within each group containing half of the ranks
-  // followed by cross-group add. Equivalent to half butterfly when there
-  // are 4 GPUs, a common case for PCIe cards like T4 and A10.
-  const P *ptrs[hg];
-  {
-    int start = rank - rank % hg;
-#pragma unroll
-    for (int i = 0; i < hg; i++) {
-      ptrs[i] = (const P *)_dp->ptrs[i + start];
+      int gather_from_rank = ((rank + i) % ngpus);
+      if (gather_from_rank == ngpus - 1 || idx < part) {
+        int dst_idx = gather_from_rank * part + idx;
+        ((P *)result)[dst_idx] = tmps[i][idx];
+      }
     }
   }
-  start_sync<ngpus>(sg, meta, rank);
-  for (int idx = tid; idx < size; idx += stride) {
-    tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
-  }
-  end_sync<ngpus>(sg, meta, rank);
-
-  auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
-  // do the cross group reduction
-  for (int idx = tid; idx < size; idx += stride) {
-    auto tmp = tmp_out[idx];
-    packed_assign_add(tmp, src[idx]);
-    ((P *)result)[idx] = tmp;
-  }
 }
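// Worked example of the new partitioning (numbers chosen for illustration):
// with size = 10 packed elements and ngpus = 4, part = 2 and largest_part =
// 2 + 10 % 4 = 4; rank 3 reduces indices 6..9, so in stage 2 every thread
// loops up to largest_part and the `gather_from_rank == ngpus - 1 || idx < part`
// guard skips the indices that exist only on the last rank. A compile-time
// spot check of that arithmetic:
static_assert(10 / 4 + 10 % 4 == 4, "largest_part when size = 10, ngpus = 4");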
 
 using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
@@ -341,7 +262,7 @@ class CustomAllreduce {
   // below are device pointers
   RankSignals sg_;
   std::unordered_map<void *, RankData *> buffers_;
-  Metadata *meta_;
+  Signal *self_sg_;
 
   // stores the registered device pointers from all ranks
   RankData *d_rank_data_base_, *d_rank_data_end_;
@@ -352,32 +273,32 @@ class CustomAllreduce {
   /**
    * meta is a pointer to device metadata and temporary buffer for allreduce.
    *
-   * There's a total of sizeof(Metadata) of prefix before the actual data,
+   * There's a total of sizeof(Signal) of prefix before the actual data,
    * so meta + 1 points to actual temporary buffer.
    *
    * note: this class does not own any device memory. Any required buffers
    * are passed in from the constructor
    */
-  CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
+  CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
                   const cudaIpcMemHandle_t *handles,
                   const std::vector<int64_t> &offsets, int rank,
                   bool full_nvlink = true)
       : rank_(rank),
         world_size_(offsets.size()),
         full_nvlink_(full_nvlink),
-        meta_(meta),
+        self_sg_(meta),
         d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
         d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
     for (int i = 0; i < world_size_; i++) {
-      Metadata *rank_meta;
+      Signal *rank_sg;
       if (i != rank_) {
         char *handle = open_ipc_handle(&handles[i]);
         handle += offsets[i];
-        rank_meta = (Metadata *)handle;
+        rank_sg = (Signal *)handle;
       } else {
-        rank_meta = meta_;
+        rank_sg = self_sg_;
       }
-      sg_.signals[i] = &rank_meta->sg;
+      sg_.signals[i] = rank_sg;
     }
   }
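// Hypothetical host-side setup (variable names are mine; IPC handle exchange
// and buffer registration are omitted; the constructor signature is the one
// above and allreduce<T>'s is shown in the comment at the bottom of the file):
//   vllm::Signal *meta;  // sizeof(Signal) prefix followed by the temp buffer
//   cudaMalloc((void **)&meta, sizeof(vllm::Signal) + tmp_buffer_bytes);
//   vllm::CustomAllreduce fa(meta, rank_data, rank_data_sz, ipc_handles,
//                            offsets, rank);
//   fa.allreduce<half>(stream, d_in, d_out, num_elems, /*threads=*/512,
//                      /*block_limit=*/36);  // block_limit must be <= kMaxBlocks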
@@ -492,6 +413,10 @@ class CustomAllreduce {
         "custom allreduce currently requires input length to be multiple "
         "of " +
         std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
 
     RankData *ptrs;
     cudaStreamCaptureStatus status;
@@ -512,9 +437,9 @@ class CustomAllreduce {
     size /= d;
     auto bytes = size * sizeof(typename packed_t<T>::P);
     int blocks = std::min(block_limit, (size + threads - 1) / threads);
-#define KL(ngpus, name)                                                      \
-  name<T, ngpus>                                                             \
-      <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
+#define KL(ngpus, name)                                                       \
+  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
+                                                 rank_, size);
 #define REDUCE_CASE(ngpus)                                                    \
   case ngpus: {                                                               \
     if (world_size_ == 2) {                                                   \
@@ -526,8 +451,6 @@ class CustomAllreduce {
       } else {                                                                \
         KL(ngpus, cross_device_reduce_2stage);                                \
       }                                                                       \
-    } else {                                                                  \
-      KL(ngpus, cross_device_reduce_half_butterfly);                          \
     }                                                                         \
     break;                                                                    \
   }
@@ -556,7 +479,7 @@ class CustomAllreduce {
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:
- * template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
- int, int, int);
+ * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
+ half *, int, int, int);
  */
 }  // namespace vllm