Commit 6cb5d3a

Author: Suffian Khan

Fix multi-tensor LAMB reduction to be deterministic (#6028)
* define ordering of reduction across blocks
* save state
* remove debug code
* remove debug code
* review comments
* significant correction for reduction only over blocks on same tensor
* addressing comments
* update rocm/lamb.cc to build as well
* remove times 2048*size in multitensor test until threshold error in rocm resolved
* convert tuple => struct as per recommendation
* update comment
* apply perfect forwarding for launch_multitensor to permit passing ref rather than pointer
* remove excess template arguments from rocm lamb.cc launch_multitensor as well
* fixes for AMD build
* pr comments
* run formatter from vscode
* formatter on cuda files
1 parent c8ac34d commit 6cb5d3a
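
The core of the fix: the previous multi-tensor reduction accumulated per-block partial sums with atomicAdd, so the order in which blocks committed their contributions varied from run to run. Floating-point addition is not associative, so a different order can yield a bitwise-different norm. The following host-side C++ sketch (not part of the commit, purely illustrative) shows why summation order matters:

```cpp
#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>

// Sums the same partials in two different orders to show that float
// addition is not associative: the totals can differ in the low-order bits,
// similar to the run-to-run variation an unordered atomicAdd across blocks
// can produce.
int main() {
  std::vector<float> partials(1024);
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& p : partials) p = dist(rng);

  float fixed_order_sum = 0.0f;
  for (float p : partials) fixed_order_sum += p;

  std::vector<float> shuffled = partials;
  std::shuffle(shuffled.begin(), shuffled.end(), rng);
  float shuffled_sum = 0.0f;
  for (float p : shuffled) shuffled_sum += p;

  // The two printed values typically differ in the last few digits.
  std::printf("fixed order:    %.9g\nshuffled order: %.9g\n", fixed_order_sum, shuffled_sum);
  return 0;
}
```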

File tree: 9 files changed, +236 -104 lines

onnxruntime/core/providers/cuda/multi_tensor/common.cuh

Lines changed: 4 additions & 4 deletions
@@ -77,7 +77,7 @@ void launch_multi_tensor_functor(
     std::vector<int>& tensor_sizes,
     std::vector<std::vector<void*>>& grouped_tensor_pointers,
     TMultiTensorFunctor multipleTensorKernel,
-    TFunctorParams... kernelParams) {
+    TFunctorParams&&... kernelParams) {
   ORT_ENFORCE(tensor_sizes.size() > 0);
   ORT_ENFORCE(tensor_sizes.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
   ORT_ENFORCE(grouped_tensor_pointers.size() > 0);
@@ -121,15 +121,15 @@ void launch_multi_tensor_functor(
       chunk_group.chunk_count = block_index;

       if (block_index == chunk_group.max_block_count) {
-        multipleTensorKernel(chunk_group, kernelParams...);
+        multipleTensorKernel(chunk_group, std::forward<TFunctorParams>(kernelParams)...);
         block_index = 0;
       }
     }

     // After ++tensor_group_index, tensor_group_index becomes the count of tensor group in chunk_group.
     ++tensor_group_index;
     if (tensor_group_index == chunk_group.max_tensor_group_count) {
-      multipleTensorKernel(chunk_group, kernelParams...);
+      multipleTensorKernel(chunk_group, std::forward<TFunctorParams>(kernelParams)...);
       block_index = 0;
       tensor_group_index = 0;
     }
@@ -138,7 +138,7 @@ void launch_multi_tensor_functor(
   // This round of processing tensor group is finished.
   // All the groups remain in chunk group should be processed right now.
   if (block_index != 0) {
-    multipleTensorKernel(chunk_group, kernelParams...);
+    multipleTensorKernel(chunk_group, std::forward<TFunctorParams>(kernelParams)...);
     block_index = 0;
     tensor_group_index = 0;
   }
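
For context, the change above turns the trailing kernel parameters into a forwarding-reference pack and perfectly forwards them at each call. A minimal standalone sketch of the same pattern follows; the names (`launch`, `PrintFunctor`) are illustrative and not part of the ONNX Runtime API:

```cpp
#include <iostream>
#include <string>
#include <utility>

// A toy "launcher" that accepts extra arguments through a forwarding-reference
// pack and perfectly forwards them to the functor, so callers can pass lvalue
// references (e.g. a const CudaKernel&) without decaying them to values or
// resorting to pointers.
template <typename TFunctor, typename... TParams>
void launch(TFunctor functor, TParams&&... params) {
  functor(std::forward<TParams>(params)...);
}

struct PrintFunctor {
  void operator()(const std::string& tag, int value) const {
    std::cout << tag << ": " << value << "\n";
  }
};

int main() {
  std::string tag = "chunk";
  // `tag` is forwarded as an lvalue reference; no copy is made, and the
  // parameter types are deduced rather than spelled out at the call site.
  launch(PrintFunctor{}, tag, 42);
  return 0;
}
```

Because the pack is now deduced, call sites such as the isfinite.cc and lamb.cc changes below can drop the explicitly spelled-out parameter types from their template argument lists.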

orttraining/orttraining/training_ops/cuda/math/isfinite.cc

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ Status IsAllFiniteOp<TSrc>::ComputeInternal(OpKernelContext* context) const {

   // Check if all values are finite and write true to output.
   // Otherwise, false will be written.
-  launch_multi_tensor_functor<1, TFunctor, bool*>(
+  launch_multi_tensor_functor<1, TFunctor>(
       2048 * 32, tensor_sizes, grouped_tensor_pointers, functor, output_data);

   return Status::OK();

orttraining/orttraining/training_ops/cuda/optimizer/lamb.cc

Lines changed: 20 additions & 8 deletions
@@ -254,7 +254,7 @@ Status launch_lamb_compute_direction(
   typedef LambMultiTensorComputeDirectionFunctor<CudaT2, CudaT3, CudaT4, CudaT_GRAD_NORM> LambStage1;
   LambStage1 lamb_stage1;

-  launch_multi_tensor_functor<tensor_count_per_group, LambStage1, const CudaT2*, const CudaT_GRAD_NORM*, float, float, float, float>(
+  launch_multi_tensor_functor<tensor_count_per_group, LambStage1>(
       2048 * 32,
       tensor_sizes_in_buckets[key],
       buckets[key],
@@ -267,6 +267,7 @@ Status launch_lamb_compute_direction(

 template <typename CudaTNorm, typename CudaTIn1, typename CudaTIn2>
 Status launch_lamb_reduction(
+    const CudaKernel& kernel,
     const int group_count,
     std::vector<int>& tensor_sizes,
     std::vector<CudaTNorm*>& p_w_norms,
@@ -332,7 +333,10 @@ Status launch_lamb_reduction(
         2048 * 32,
         tensor_sizes_in_buckets,
         buckets,
-        reducer);
+        reducer,
+        kernel,
+        reduction_buffer,
+        reduction_buffer_size);
   }

   return Status::OK();
@@ -412,9 +416,7 @@ Status launch_lamb_update(
       LambStage2;
   LambStage2 lamb_stage2;

-  launch_multi_tensor_functor<
-      tensor_count_per_group, LambStage2,
-      const CudaT1*, const float, const float>(
+  launch_multi_tensor_functor<tensor_count_per_group, LambStage2>(
       2048 * 32,
       tensor_sizes_in_bucket,
       buckets,
@@ -542,9 +544,18 @@ Status LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>::Compute
     max_tensor_size = std::max(max_tensor_size, static_cast<int>(w.Shape().Size()));
   }

-  // Allocate a buffer in byte for reduction API calls.
-  const auto reduction_buffer_size =
-      compute_reduction_buffer_size<CudaT2>(max_tensor_size);
+  const size_t reduction_buffer_size = [&]() {
+    // Allocate a buffer in byte for reduction API calls.
+    size_t rbs = compute_reduction_buffer_size<CudaT2>(max_tensor_size);
+
+    // Enlarge reduction buffer to accomodate multi-tensor reduction kernel as well
+    const int tensor_group_size = 4;  // w, d, w_norm, d_norm
+    const int max_blocks = ChunkGroup<tensor_group_size>::max_block_count;
+    const size_t multitensor_block_reduce_buffer_size = 2 * max_blocks * sizeof(CudaT2);
+    rbs = std::max(rbs, multitensor_block_reduce_buffer_size);
+
+    return rbs;
+  }();

   // Allocate reduction buffer whose size is reduction_buffer_size bytes.
   IAllocatorUniquePtr<void> reduction_buffer = GetScratchBuffer<void>(reduction_buffer_size);
@@ -640,6 +651,7 @@ Status LambOptimizer<T1, T2, T3, T4, T_GRAD_NORM, T_MIXED_PRECISION_FP>::Compute
           do_bias_correction_));

       ORT_RETURN_IF_ERROR(launch_lamb_reduction(
+          *this,
           group_count,
           tensor_sizes,
           p_w_norms,
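
The sizing lambda above keeps a single scratch allocation large enough for both consumers: the existing per-tensor reduction calls and the new multi-tensor kernel, which needs room for one w partial and one d partial per block. A small sketch of the same max-of-two-requirements computation; `single_tensor_reduction_bytes` is a stand-in for `compute_reduction_buffer_size`, whose real definition lives elsewhere in the code base, and the constants are placeholders:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Stand-in for ONNX Runtime's compute_reduction_buffer_size<CudaT2>(); treated
// here as an opaque requirement that grows with the largest tensor.
static size_t single_tensor_reduction_bytes(int max_tensor_size) {
  return static_cast<size_t>(max_tensor_size) / 32 * sizeof(float) + 512;  // illustrative only
}

int main() {
  const int max_tensor_size = 1 << 20;

  // Requirement 1: scratch space for the per-tensor reduction API calls.
  size_t bytes = single_tensor_reduction_bytes(max_tensor_size);

  // Requirement 2: the multi-tensor reduction writes two partial sums
  // (w and d) per block, so it needs 2 * max_block_count elements.
  const int max_blocks = 2048;  // stands in for ChunkGroup<4>::max_block_count
  const size_t multi_tensor_bytes = 2 * static_cast<size_t>(max_blocks) * sizeof(float);

  bytes = std::max(bytes, multi_tensor_bytes);
  std::printf("reduction buffer: %zu bytes\n", bytes);
  return 0;
}
```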

orttraining/orttraining/training_ops/cuda/optimizer/lamb.cu

Lines changed: 120 additions & 38 deletions
@@ -2,9 +2,12 @@
 // Licensed under the MIT License.

 #include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/cuda_allocator.h"
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/atomic/common.cuh"
+#include "core/providers/cuda/reduction/reduction_utils.cuh"
 #include "orttraining/training_ops/cuda/math/isfinite.cuh"
+#include "orttraining/training_ops/cuda/optimizer/common.h"
 #include "orttraining/training_ops/cuda/optimizer/common.cuh"
 #include "orttraining/training_ops/cuda/optimizer/lamb.h"

@@ -50,8 +53,8 @@ __device__ __forceinline__ void _LambComputeDirectionRule(
   const T3 m2_new_tmp_corrected = m2_new_tmp / beta_correction;

   // Save regularized update direction to output.
-  const T2 d_tmp = lambda * w +
-      T1(m1_new_tmp_corrected / (_Sqrt(m2_new_tmp_corrected) + epsilon));
+  const T2 d_tmp = lambda * w +
+                   T1(m1_new_tmp_corrected / (_Sqrt(m2_new_tmp_corrected) + epsilon));

   // Things are updated only if the direction is finite.
   if (_IsFiniteScalar(d_tmp)) {
@@ -145,22 +148,22 @@ void LambComputeDirection(
 }

 #define SPECIALIZED_LAMB_COMPUTE_DIRECTION(T1, T2, T3, T_GRAD_NORM) \
-  template void LambComputeDirection(                              \
-      const T1* weights,                                           \
-      const T2* grads,                                             \
-      const T3* moment_1,                                          \
-      const T3* moment_2,                                          \
-      const T1* loss_scale,                                        \
-      const T_GRAD_NORM* grad_norm,                                \
-      T3 alpha,                                                    \
-      T3 beta,                                                     \
-      T1 lambda,                                                   \
-      T3 epsilon,                                                  \
-      T3 alpha_correction,                                         \
-      T3 beta_correction,                                          \
-      T2* weights_out,                                             \
-      T3* moment_1_out,                                            \
-      T3* moment_2_out,                                            \
+  template void LambComputeDirection(                               \
+      const T1* weights,                                            \
+      const T2* grads,                                              \
+      const T3* moment_1,                                           \
+      const T3* moment_2,                                           \
+      const T1* loss_scale,                                         \
+      const T_GRAD_NORM* grad_norm,                                 \
+      T3 alpha,                                                     \
+      T3 beta,                                                      \
+      T1 lambda,                                                    \
+      T3 epsilon,                                                   \
+      T3 alpha_correction,                                          \
+      T3 beta_correction,                                           \
+      T2* weights_out,                                              \
+      T3* moment_1_out,                                             \
+      T3* moment_2_out,                                             \
       size_t count);

 SPECIALIZED_LAMB_COMPUTE_DIRECTION(float, float, float, float)
@@ -182,9 +185,8 @@ __device__ __forceinline__ void _LambUpdateRule(
     T2* w_new,
     T3* g_new,
     T_MIXED_PRECISION_FP* w_mixed_precision_new) {
-  // Confidence coefficeint of this update.
-  const T2 ratio = (w_norm != T2(0.0f) && r_norm != T2(0.0f)) ?
-      T2(eta) * _Max(T2(ratio_min), _Min(T2(ratio_max), _Sqrt(w_norm / r_norm))) : T2(eta);
+  // Confidence coefficeint of this update.
+  const T2 ratio = (w_norm != T2(0.0f) && r_norm != T2(0.0f)) ? T2(eta) * _Max(T2(ratio_min), _Min(T2(ratio_max), _Sqrt(w_norm / r_norm))) : T2(eta);

   // Compute delta using the saved update direction.
   const T2 delta = -ratio * T2(d);
@@ -313,7 +315,7 @@ __global__ void LambMultiTensorComputeDirectionImpl(
   T3* m2_new = reinterpret_cast<T3*>(chunk_group.tensor_ptrs[5][group_index]) + chunk_start;
   const T1 scale = _ComputeGradScale<T1, T_GRAD_NORM, T1>(loss_scale, g_norm);

-  #pragma unroll
+#pragma unroll
   for (int i = threadIdx.x; i < chunk_size && i + chunk_start < tensor_size; i += blockDim.x) {
     _LambComputeDirectionRule(
         scale,
@@ -359,16 +361,16 @@ void LambMultiTensorComputeDirectionFunctor<T1, T2, T3, T_GRAD_NORM>::operator()
       beta_correction);
 }

-#define INSTANTIATE_LAMB_STAGE1_MULTI_TENSOR_FUNCTOR(T1, T2, T3, T_GRAD_NORM)                \
+#define INSTANTIATE_LAMB_STAGE1_MULTI_TENSOR_FUNCTOR(T1, T2, T3, T_GRAD_NORM)                 \
   template void LambMultiTensorComputeDirectionFunctor<T1, T2, T3, T_GRAD_NORM>::operator()( \
-      ChunkGroup<6> chunk_group,                                                             \
-      const T1* loss_scale,                                                                  \
-      const T_GRAD_NORM* g_norm,                                                             \
-      const T1 lambda,                                                                       \
-      const T3 alpha,                                                                        \
-      const T3 beta,                                                                         \
-      const T3 epsilon,                                                                      \
-      const T3 alpha_correction,                                                             \
+      ChunkGroup<6> chunk_group,                                                              \
+      const T1* loss_scale,                                                                   \
+      const T_GRAD_NORM* g_norm,                                                              \
+      const T1 lambda,                                                                        \
+      const T3 alpha,                                                                         \
+      const T3 beta,                                                                          \
+      const T3 epsilon,                                                                       \
+      const T3 alpha_correction,                                                              \
       const T3 beta_correction);

 INSTANTIATE_LAMB_STAGE1_MULTI_TENSOR_FUNCTOR(float, float, float, float)
@@ -440,9 +442,15 @@ INSTANTIATE_LAMB_MULTI_TENSOR_UPDATE_FUNCTOR(double, double, double, half)
 INSTANTIATE_LAMB_MULTI_TENSOR_UPDATE_FUNCTOR(half, float, half, half)
 INSTANTIATE_LAMB_MULTI_TENSOR_UPDATE_FUNCTOR(float, float, half, half)

+// w_buffer[i], d_buffer[i] is used to store the squared sum of all elements processed by the i-th block.
+// sync_range_and_lock is used for a well ordered reduction over blocks spanning the same tensor
 template <typename TIn1, typename TIn2, typename TOut1, typename TOut2, typename TBuf>
 __launch_bounds__(ChunkGroup<4>::thread_count_per_block)
-__global__ void LambMultiTensorReductionImpl(ChunkGroup<4> chunk_group) {
+__global__ void LambMultiTensorReductionImpl(
+    ChunkGroup<4> chunk_group,
+    TOut1* w_buffer,
+    TOut2* d_buffer,
+    LambMultiTensorSyncRangeAndLock* sync_range_and_lock) {
   const int group_index = chunk_group.block_index_to_tensor_group_index[blockIdx.x];
   const int tensor_size = chunk_group.tensor_sizes[group_index];
   const int chunk_size = chunk_group.chunk_size;
@@ -469,7 +477,7 @@ __global__ void LambMultiTensorReductionImpl(ChunkGroup<4> chunk_group) {
     }
   }

-    // Thread count in a block must be a multiple of GPU_WARP_SIZE.
+  // Thread count in a block must be a multiple of GPU_WARP_SIZE.
 #pragma unroll
   for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
     w_sum += WARP_SHFL_DOWN(w_sum, stride);
@@ -502,14 +510,77 @@ __global__ void LambMultiTensorReductionImpl(ChunkGroup<4> chunk_group) {
     __syncthreads();
   }

+  // ascertain the range of blocks with the associated tensor
+  // note: if non-ordered reduction is OK, then atomicAdd over blocks could suffice
+  const int leading_block_in_tensor = sync_range_and_lock[group_index].leading_block;
+  const int num_blocks_in_tensor = sync_range_and_lock[group_index].number_blocks;
+
+  if (num_blocks_in_tensor == 1) {
+    if (threadIdx.x == 0) {
+      *w_norm = TOut1(w_shared_memory_[0]);
+      *d_norm = TOut2(d_shared_memory_[0]);
+    }
+    return;
+  }
+
   if (threadIdx.x == 0) {
-    atomic_add(w_norm, TOut1(w_shared_memory_[0]));
-    atomic_add(d_norm, TOut2(d_shared_memory_[0]));
+    w_buffer[blockIdx.x] = w_shared_memory_[0];
+    d_buffer[blockIdx.x] = d_shared_memory_[0];
   }
+
+  __threadfence();
+  __syncthreads();
+
+  // use lock to determine if this is last block for given tensor
+  __shared__ bool is_last_block_done;
+
+  if (threadIdx.x == 0) {
+    int* p_lock = &sync_range_and_lock[group_index].completed_blocks;
+    int counter = atomicAdd(p_lock, 1);
+    is_last_block_done = (counter == num_blocks_in_tensor - 1);
+  }
+  __syncthreads();
+
+  // only last block to finish for associated tensor enters below
+  if (is_last_block_done) {
+    const int pow2_bound = least_pow2_bound(num_blocks_in_tensor);
+    int blockid = leading_block_in_tensor + threadIdx.x;
+    for (int stride = pow2_bound / 2; stride > 0; stride /= 2) {
+      if (threadIdx.x < stride && threadIdx.x + stride < num_blocks_in_tensor) {
+        w_buffer[blockid] += w_buffer[blockid + stride];
+        d_buffer[blockid] += d_buffer[blockid + stride];
+      }
+      __syncthreads();
+    }
+
+    if (threadIdx.x == 0) {
+      *w_norm = TOut1(w_buffer[leading_block_in_tensor]);
+      *d_norm = TOut2(d_buffer[leading_block_in_tensor]);
+    }
+  }
+}
+
+CudaKernel::CudaAsyncBuffer<LambMultiTensorSyncRangeAndLock> compute_tensor_range_and_lock(ChunkGroup<4> chunk_group, const CudaKernel& kernel) {
+  const int num_blocks = chunk_group.chunk_count;
+
+  // sync_range_and_lock is a struct consisting of (start_block, num_blocks, lock) for each tensor
+  // Note: Adding such info to chunk group causes overflow (unless max tensors is reduced)
+  const int max_tensors = ChunkGroup<4>::max_tensor_group_count;
+  LambMultiTensorSyncRangeAndLock initial = {0, 0, 0};
+  CudaKernel::CudaAsyncBuffer<LambMultiTensorSyncRangeAndLock> sync_range_and_lock(&kernel, initial, max_tensors);
+  for (int block_index = num_blocks - 1; block_index >= 0; block_index--) {
+    int tensor_index = chunk_group.block_index_to_tensor_group_index[block_index];
+    auto& tensor_block_span = sync_range_and_lock.CpuPtr()[tensor_index];
+    tensor_block_span.leading_block = block_index;
+    tensor_block_span.number_blocks++;
+  }
+  sync_range_and_lock.CopyToGpu();
+
+  return sync_range_and_lock;
 }

 template <typename TIn1, typename TIn2, typename TOut1, typename TOut2, typename TBuf>
-void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()(ChunkGroup<4> chunk_group) {
+void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()(ChunkGroup<4> chunk_group, const CudaKernel& kernel, void* reduction_buffer, size_t reduction_buffer_size) {
   // thread count per block.
   constexpr int thread_count = ChunkGroup<4>::thread_count_per_block;
   // shared memory's size per block.
@@ -519,11 +590,22 @@ void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()
   ORT_ENFORCE(thread_count % GPU_WARP_SIZE == 0);
   ORT_ENFORCE((thread_count & (thread_count - 1)) == 0);

-  LambMultiTensorReductionImpl<TIn1, TIn2, TOut1, TOut2, TBuf><<<chunk_group.chunk_count, thread_count, shared_memory_size>>>(chunk_group);
+  const int num_blocks = chunk_group.chunk_count;
+  const size_t w_buffer_size = num_blocks * sizeof(TOut1);
+  const size_t d_buffer_size = num_blocks * sizeof(TOut2);
+
+  ORT_ENFORCE(w_buffer_size + d_buffer_size <= reduction_buffer_size);
+
+  TOut1* w_buffer = reinterpret_cast<TOut1*>(reduction_buffer);
+  TOut2* d_buffer = reinterpret_cast<TOut2*>(w_buffer + num_blocks);
+
+  auto sync_range_and_lock = compute_tensor_range_and_lock(chunk_group, kernel);
+  LambMultiTensorReductionImpl<TIn1, TIn2, TOut1, TOut2, TBuf><<<chunk_group.chunk_count, thread_count, shared_memory_size>>>(
+      chunk_group, w_buffer, d_buffer, sync_range_and_lock.GpuPtr());
 }

 #define INSTANTIATE_LAMB_MULTI_TENSOR_REDUCTION_FUNCTOR(TIn1, TIn2, TOut1, TOut2, TBuf) \
-  template void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()(ChunkGroup<4> chunk_group);
+  template void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()(ChunkGroup<4> chunk_group, const CudaKernel& kernel, void* reduction_buffer, size_t reduction_buffer_size);

 INSTANTIATE_LAMB_MULTI_TENSOR_REDUCTION_FUNCTOR(float, float, float, float, float)
 INSTANTIATE_LAMB_MULTI_TENSOR_REDUCTION_FUNCTOR(double, double, double, double, double)
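
The kernel above makes the cross-block step deterministic: each block writes its partial into a per-block slot, and only the last block to finish for a tensor (detected via the per-tensor lock) folds those slots together in a fixed tree order. A host-side C++ sketch of that final fold, using an array of partials indexed by block to stand in for `w_buffer`/`d_buffer` (illustrative only, not the CUDA code; the host `least_pow2_bound` below is an assumed equivalent of the device helper):

```cpp
#include <cstdio>
#include <vector>

// Smallest power of two >= n (mirrors the role of least_pow2_bound in the
// kernel; this host version is an assumption, not the library function).
static int least_pow2_bound(int n) {
  int p = 1;
  while (p < n) p <<= 1;
  return p;
}

// Folds per-block partial sums for one tensor in a fixed tree order.
// Because the pairing pattern depends only on block indices, the result is
// bitwise identical across runs, unlike accumulating the same partials with
// atomicAdd in whatever order blocks happen to finish.
float reduce_tensor_partials(std::vector<float>& partials, int leading_block, int num_blocks) {
  const int bound = least_pow2_bound(num_blocks);
  for (int stride = bound / 2; stride > 0; stride /= 2) {
    // In the kernel, `i` corresponds to threadIdx.x within the last block.
    for (int i = 0; i < stride && i + stride < num_blocks; ++i) {
      partials[leading_block + i] += partials[leading_block + i + stride];
    }
  }
  return partials[leading_block];
}

int main() {
  // Partials as if 5 blocks covered one tensor, starting at block 3.
  std::vector<float> partials = {0.f, 0.f, 0.f, 1.5f, 0.25f, 2.0f, 0.125f, 4.0f};
  std::printf("tensor partial sum: %g\n", reduce_tensor_partials(partials, 3, 5));
  return 0;
}
```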
