 // Licensed under the MIT License.
 
 #include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/cuda_allocator.h"
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/atomic/common.cuh"
+#include "core/providers/cuda/reduction/reduction_utils.cuh"
 #include "orttraining/training_ops/cuda/math/isfinite.cuh"
+#include "orttraining/training_ops/cuda/optimizer/common.h"
 #include "orttraining/training_ops/cuda/optimizer/common.cuh"
 #include "orttraining/training_ops/cuda/optimizer/lamb.h"
 
@@ -50,8 +53,8 @@ __device__ __forceinline__ void _LambComputeDirectionRule(
   const T3 m2_new_tmp_corrected = m2_new_tmp / beta_correction;
 
   // Save regularized update direction to output.
-  const T2 d_tmp = lambda * w +
-      T1(m1_new_tmp_corrected / (_Sqrt(m2_new_tmp_corrected) + epsilon));
+  const T2 d_tmp = lambda * w +
+      T1(m1_new_tmp_corrected / (_Sqrt(m2_new_tmp_corrected) + epsilon));
 
   // Things are updated only if the direction is finite.
   if (_IsFiniteScalar(d_tmp)) {
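The direction computed in the hunk above is the bias-corrected Adam step plus a decoupled weight-decay term. As a reference point, here is a minimal scalar sketch of that rule; the plain `float` types and the function name are illustrative, not the kernel's templates, and the exponential-average moment update it assumes is the standard Adam/LAMB convention rather than something shown in this hunk.

```cuda
#include <cmath>

// Standalone scalar sketch (not the kernel): the regularized LAMB direction
// d = lambda * w + m1_hat / (sqrt(m2_hat) + epsilon), with bias-corrected
// moments m1_hat = m1_new / alpha_correction and m2_hat = m2_new / beta_correction.
float LambDirectionSketch(float w, float m1_new, float m2_new,
                          float lambda, float epsilon,
                          float alpha_correction, float beta_correction) {
  const float m1_hat = m1_new / alpha_correction;  // bias-corrected first moment
  const float m2_hat = m2_new / beta_correction;   // bias-corrected second moment
  return lambda * w + m1_hat / (std::sqrt(m2_hat) + epsilon);
}
```

Per the `_IsFiniteScalar` guard in the hunk's trailing context, the weight and moment outputs are only written when this value is finite.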
@@ -145,22 +148,22 @@ void LambComputeDirection(
 }
 
 #define SPECIALIZED_LAMB_COMPUTE_DIRECTION(T1, T2, T3, T_GRAD_NORM) \
-  template void LambComputeDirection( \
-      const T1* weights, \
-      const T2* grads, \
-      const T3* moment_1, \
-      const T3* moment_2, \
-      const T1* loss_scale, \
-      const T_GRAD_NORM* grad_norm, \
-      T3 alpha, \
-      T3 beta, \
-      T1 lambda, \
-      T3 epsilon, \
-      T3 alpha_correction, \
-      T3 beta_correction, \
-      T2* weights_out, \
-      T3* moment_1_out, \
-      T3* moment_2_out, \
+  template void LambComputeDirection( \
+      const T1* weights, \
+      const T2* grads, \
+      const T3* moment_1, \
+      const T3* moment_2, \
+      const T1* loss_scale, \
+      const T_GRAD_NORM* grad_norm, \
+      T3 alpha, \
+      T3 beta, \
+      T1 lambda, \
+      T3 epsilon, \
+      T3 alpha_correction, \
+      T3 beta_correction, \
+      T2* weights_out, \
+      T3* moment_1_out, \
+      T3* moment_2_out, \
       size_t count);
 
 SPECIALIZED_LAMB_COMPUTE_DIRECTION(float, float, float, float)
@@ -182,9 +185,8 @@ __device__ __forceinline__ void _LambUpdateRule(
     T2* w_new,
     T3* g_new,
     T_MIXED_PRECISION_FP* w_mixed_precision_new) {
-  // Confidence coefficeint of this update.
-  const T2 ratio = (w_norm != T2(0.0f) && r_norm != T2(0.0f)) ?
-      T2(eta) * _Max(T2(ratio_min), _Min(T2(ratio_max), _Sqrt(w_norm / r_norm))) : T2(eta);
+  // Confidence coefficeint of this update.
+  const T2 ratio = (w_norm != T2(0.0f) && r_norm != T2(0.0f)) ? T2(eta) * _Max(T2(ratio_min), _Min(T2(ratio_max), _Sqrt(w_norm / r_norm))) : T2(eta);
 
   // Compute delta using the saved update direction.
   const T2 delta = -ratio * T2(d);
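The one-line ternary above computes LAMB's clamped trust ratio. A host-side sketch, assuming scalar `float` inputs and substituting `std::` functions for the device helpers `_Max`/`_Min`/`_Sqrt`:

```cuda
#include <algorithm>
#include <cmath>

// Host-side sketch of the confidence (trust) ratio: eta scaled by
// sqrt(w_norm / r_norm), clamped to [ratio_min, ratio_max]; if either
// norm is zero, the ratio falls back to plain eta.
float LambTrustRatioSketch(float eta, float w_norm, float r_norm,
                           float ratio_min, float ratio_max) {
  if (w_norm == 0.0f || r_norm == 0.0f) return eta;
  return eta * std::max(ratio_min, std::min(ratio_max, std::sqrt(w_norm / r_norm)));
}
```

The step applied to the weight is then `delta = -ratio * d`, as in the hunk's trailing context.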
@@ -313,7 +315,7 @@ __global__ void LambMultiTensorComputeDirectionImpl(
   T3* m2_new = reinterpret_cast<T3*>(chunk_group.tensor_ptrs[5][group_index]) + chunk_start;
   const T1 scale = _ComputeGradScale<T1, T_GRAD_NORM, T1>(loss_scale, g_norm);
 
-#pragma unroll
+#pragma unroll
   for (int i = threadIdx.x; i < chunk_size && i + chunk_start < tensor_size; i += blockDim.x) {
     _LambComputeDirectionRule(
         scale,
@@ -359,16 +361,16 @@ void LambMultiTensorComputeDirectionFunctor<T1, T2, T3, T_GRAD_NORM>::operator()
       beta_correction);
 }
 
-#define INSTANTIATE_LAMB_STAGE1_MULTI_TENSOR_FUNCTOR(T1, T2, T3, T_GRAD_NORM) \
+#define INSTANTIATE_LAMB_STAGE1_MULTI_TENSOR_FUNCTOR(T1, T2, T3, T_GRAD_NORM) \
   template void LambMultiTensorComputeDirectionFunctor<T1, T2, T3, T_GRAD_NORM>::operator()( \
-      ChunkGroup<6> chunk_group, \
-      const T1* loss_scale, \
-      const T_GRAD_NORM* g_norm, \
-      const T1 lambda, \
-      const T3 alpha, \
-      const T3 beta, \
-      const T3 epsilon, \
-      const T3 alpha_correction, \
+      ChunkGroup<6> chunk_group, \
+      const T1* loss_scale, \
+      const T_GRAD_NORM* g_norm, \
+      const T1 lambda, \
+      const T3 alpha, \
+      const T3 beta, \
+      const T3 epsilon, \
+      const T3 alpha_correction, \
       const T3 beta_correction);
 
 INSTANTIATE_LAMB_STAGE1_MULTI_TENSOR_FUNCTOR(float, float, float, float)
@@ -440,9 +442,15 @@ INSTANTIATE_LAMB_MULTI_TENSOR_UPDATE_FUNCTOR(double, double, double, half)
 INSTANTIATE_LAMB_MULTI_TENSOR_UPDATE_FUNCTOR(half, float, half, half)
 INSTANTIATE_LAMB_MULTI_TENSOR_UPDATE_FUNCTOR(float, float, half, half)
 
+// w_buffer[i], d_buffer[i] is used to store the squared sum of all elements processed by the i-th block.
+// sync_range_and_lock is used for a well ordered reduction over blocks spanning the same tensor
 template <typename TIn1, typename TIn2, typename TOut1, typename TOut2, typename TBuf>
 __launch_bounds__(ChunkGroup<4>::thread_count_per_block)
-__global__ void LambMultiTensorReductionImpl(ChunkGroup<4> chunk_group) {
+__global__ void LambMultiTensorReductionImpl(
+    ChunkGroup<4> chunk_group,
+    TOut1* w_buffer,
+    TOut2* d_buffer,
+    LambMultiTensorSyncRangeAndLock* sync_range_and_lock) {
   const int group_index = chunk_group.block_index_to_tensor_group_index[blockIdx.x];
   const int tensor_size = chunk_group.tensor_sizes[group_index];
   const int chunk_size = chunk_group.chunk_size;
@@ -469,7 +477,7 @@ __global__ void LambMultiTensorReductionImpl(ChunkGroup<4> chunk_group) {
     }
   }
 
-  // Thread count in a block must be a multiple of GPU_WARP_SIZE.
+  // Thread count in a block must be a multiple of GPU_WARP_SIZE.
 #pragma unroll
   for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
     w_sum += WARP_SHFL_DOWN(w_sum, stride);
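The loop in this hunk's context reduces per-thread partial sums across a warp. A standalone sketch of that pattern, assuming `WARP_SHFL_DOWN` wraps `__shfl_down_sync` with a full mask (that mapping is an assumption here, not taken from the diff):

```cuda
// Warp-level sum sketch: each of the 32 lanes contributes v; after the loop
// lane 0 holds the sum of all lanes. The full mask assumes every lane is active.
__device__ float WarpSumSketch(float v) {
  for (int stride = 32 / 2; stride > 0; stride /= 2) {
    v += __shfl_down_sync(0xffffffff, v, stride);
  }
  return v;
}
```

Each shuffle halves the number of live partial sums, so a 32-lane warp is fully reduced after five steps.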
@@ -502,14 +510,77 @@ __global__ void LambMultiTensorReductionImpl(ChunkGroup<4> chunk_group) {
     __syncthreads();
   }
 
+  // ascertain the range of blocks with the associated tensor
+  // note: if non-ordered reduction is OK, then atomicAdd over blocks could suffice
+  const int leading_block_in_tensor = sync_range_and_lock[group_index].leading_block;
+  const int num_blocks_in_tensor = sync_range_and_lock[group_index].number_blocks;
+
+  if (num_blocks_in_tensor == 1) {
+    if (threadIdx.x == 0) {
+      *w_norm = TOut1(w_shared_memory_[0]);
+      *d_norm = TOut2(d_shared_memory_[0]);
+    }
+    return;
+  }
+
   if (threadIdx.x == 0) {
-    atomic_add(w_norm, TOut1(w_shared_memory_[0]));
-    atomic_add(d_norm, TOut2(d_shared_memory_[0]));
+    w_buffer[blockIdx.x] = w_shared_memory_[0];
+    d_buffer[blockIdx.x] = d_shared_memory_[0];
   }
+
+  __threadfence();
+  __syncthreads();
+
+  // use lock to determine if this is last block for given tensor
+  __shared__ bool is_last_block_done;
+
+  if (threadIdx.x == 0) {
+    int* p_lock = &sync_range_and_lock[group_index].completed_blocks;
+    int counter = atomicAdd(p_lock, 1);
+    is_last_block_done = (counter == num_blocks_in_tensor - 1);
+  }
+  __syncthreads();
+
+  // only last block to finish for associated tensor enters below
+  if (is_last_block_done) {
+    const int pow2_bound = least_pow2_bound(num_blocks_in_tensor);
+    int blockid = leading_block_in_tensor + threadIdx.x;
+    for (int stride = pow2_bound / 2; stride > 0; stride /= 2) {
+      if (threadIdx.x < stride && threadIdx.x + stride < num_blocks_in_tensor) {
+        w_buffer[blockid] += w_buffer[blockid + stride];
+        d_buffer[blockid] += d_buffer[blockid + stride];
+      }
+      __syncthreads();
+    }
+
+    if (threadIdx.x == 0) {
+      *w_norm = TOut1(w_buffer[leading_block_in_tensor]);
+      *d_norm = TOut2(d_buffer[leading_block_in_tensor]);
+    }
+  }
+}
+
+CudaKernel::CudaAsyncBuffer<LambMultiTensorSyncRangeAndLock> compute_tensor_range_and_lock(ChunkGroup<4> chunk_group, const CudaKernel& kernel) {
+  const int num_blocks = chunk_group.chunk_count;
+
+  // sync_range_and_lock is a struct consisting of (start_block, num_blocks, lock) for each tensor
+  // Note: Adding such info to chunk group causes overflow (unless max tensors is reduced)
+  const int max_tensors = ChunkGroup<4>::max_tensor_group_count;
+  LambMultiTensorSyncRangeAndLock initial = {0, 0, 0};
+  CudaKernel::CudaAsyncBuffer<LambMultiTensorSyncRangeAndLock> sync_range_and_lock(&kernel, initial, max_tensors);
+  for (int block_index = num_blocks - 1; block_index >= 0; block_index--) {
+    int tensor_index = chunk_group.block_index_to_tensor_group_index[block_index];
+    auto& tensor_block_span = sync_range_and_lock.CpuPtr()[tensor_index];
+    tensor_block_span.leading_block = block_index;
+    tensor_block_span.number_blocks++;
+  }
+  sync_range_and_lock.CopyToGpu();
+
+  return sync_range_and_lock;
 }
 
 template <typename TIn1, typename TIn2, typename TOut1, typename TOut2, typename TBuf>
-void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()(ChunkGroup<4> chunk_group) {
+void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()(ChunkGroup<4> chunk_group, const CudaKernel& kernel, void* reduction_buffer, size_t reduction_buffer_size) {
   // thread count per block.
   constexpr int thread_count = ChunkGroup<4>::thread_count_per_block;
   // shared memory's size per block.
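The hunk above is the core of the change: instead of `atomic_add` into `w_norm`/`d_norm` (whose accumulation order varies from run to run), each block writes its partial squared sums to `w_buffer`/`d_buffer`, and the last block to finish for a tensor, detected via the `completed_blocks` counter, folds those partials in a fixed order. Below is a self-contained CUDA sketch of this "last block done" pattern reduced to a single tensor and a single output; the names and the simplification are illustrative, and the per-tensor `sync_range_and_lock` bookkeeping is omitted.

```cuda
#include <cuda_runtime.h>

// Sketch of the "last block done" pattern for one tensor: every block writes
// its partial sum of squares to partials[blockIdx.x]; the block that increments
// the counter last then folds the partials in a fixed order, so the result does
// not depend on block scheduling. Expects: counter zero-initialized, partials
// sized to gridDim.x, and blockDim.x a power of two <= 256.
__global__ void SumSquaresLastBlockDone(const float* x, int n,
                                        float* partials, int* counter,
                                        float* out) {
  __shared__ float shm[256];
  __shared__ bool is_last_block_done;

  // Grid-stride accumulation of x[i]^2 into a per-thread partial.
  float v = 0.0f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
    v += x[i] * x[i];
  }
  shm[threadIdx.x] = v;
  __syncthreads();

  // In-block tree reduction into shm[0].
  for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
    if (threadIdx.x < stride) shm[threadIdx.x] += shm[threadIdx.x + stride];
    __syncthreads();
  }

  if (threadIdx.x == 0) partials[blockIdx.x] = shm[0];
  __threadfence();  // make the partial visible before the counter is bumped
  __syncthreads();

  if (threadIdx.x == 0) {
    is_last_block_done = (atomicAdd(counter, 1) == gridDim.x - 1);
  }
  __syncthreads();

  if (is_last_block_done && threadIdx.x == 0) {
    float total = 0.0f;
    for (int b = 0; b < gridDim.x; ++b) total += partials[b];  // fixed order
    *out = total;
  }
}
```

A caller would zero `counter` and size `partials` to one element per block. The kernel in the diff additionally tracks a leading block and block count per tensor (set up by `compute_tensor_range_and_lock`) and folds the partials with a tree reduction bounded by `least_pow2_bound`, so many tensors can share one launch.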
@@ -519,11 +590,22 @@ void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()
   ORT_ENFORCE(thread_count % GPU_WARP_SIZE == 0);
   ORT_ENFORCE((thread_count & (thread_count - 1)) == 0);
 
-  LambMultiTensorReductionImpl<TIn1, TIn2, TOut1, TOut2, TBuf><<<chunk_group.chunk_count, thread_count, shared_memory_size>>>(chunk_group);
+  const int num_blocks = chunk_group.chunk_count;
+  const size_t w_buffer_size = num_blocks * sizeof(TOut1);
+  const size_t d_buffer_size = num_blocks * sizeof(TOut2);
+
+  ORT_ENFORCE(w_buffer_size + d_buffer_size <= reduction_buffer_size);
+
+  TOut1* w_buffer = reinterpret_cast<TOut1*>(reduction_buffer);
+  TOut2* d_buffer = reinterpret_cast<TOut2*>(w_buffer + num_blocks);
+
+  auto sync_range_and_lock = compute_tensor_range_and_lock(chunk_group, kernel);
+  LambMultiTensorReductionImpl<TIn1, TIn2, TOut1, TOut2, TBuf><<<chunk_group.chunk_count, thread_count, shared_memory_size>>>(
+      chunk_group, w_buffer, d_buffer, sync_range_and_lock.GpuPtr());
 }
 
 #define INSTANTIATE_LAMB_MULTI_TENSOR_REDUCTION_FUNCTOR(TIn1, TIn2, TOut1, TOut2, TBuf) \
-  template void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()(ChunkGroup<4> chunk_group);
+  template void LambMultiTensorReductionFunctor<TIn1, TIn2, TOut1, TOut2, TBuf>::operator()(ChunkGroup<4> chunk_group, const CudaKernel& kernel, void* reduction_buffer, size_t reduction_buffer_size);
 
 INSTANTIATE_LAMB_MULTI_TENSOR_REDUCTION_FUNCTOR(float, float, float, float, float)
 INSTANTIATE_LAMB_MULTI_TENSOR_REDUCTION_FUNCTOR(double, double, double, double, double)
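With the new `operator()` signature, the caller supplies a scratch buffer holding the per-block partials: `w_buffer` (one `TOut1` per block) followed immediately by `d_buffer` (one `TOut2` per block), which the `ORT_ENFORCE` above checks against `reduction_buffer_size`. A hypothetical sizing helper for that arithmetic, not part of this change:

```cuda
#include <cstddef>

// Hypothetical sizing helper: bytes of scratch space the reduction launch
// needs for its per-block partial sums (w_buffer then d_buffer back to back).
template <typename TOut1, typename TOut2>
size_t LambReductionScratchBytes(int num_blocks) {
  return static_cast<size_t>(num_blocks) * sizeof(TOut1) +
         static_cast<size_t>(num_blocks) * sizeof(TOut2);
}
```

For example, a launch with 1024 chunks and `float` outputs needs 1024 * 4 + 1024 * 4 = 8192 bytes.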