@@ -753,59 +753,75 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
"""
split_precomputation = """
    at::acc_type<cache_t, true> freq = 1.0;
-   at::acc_type<cache_t, true> l2_wd = 0.0;
    at::acc_type<cache_t, true> tail_id_threshold_val = tail_id_threshold;
-   CUDA_KERNEL_ASSERT(max_counter > 0.0); // avoid divide by zero error
+   CUDA_KERNEL_ASSERT(max_counter != 0.0); // avoid divide by zero error
    if (is_tail_id_thresh_ratio == 1) {
        tail_id_threshold_val = floorf(tail_id_threshold * max_counter);
    }
-   if (counter_halflife > 0 && threadIdx.x == 0) {
-       // if id occurs multiple times in a batch, iter_delta=1
-       const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
-       prev_iter[idx] = iter * 1.0;
-       const auto counter_log_rho = logf(2.0) / counter_halflife;
-       row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
-       freq = counter_halflife / row_counter[idx];
-       if (weight_decay_mode == 1) {
-           // L2 regularization
-           l2_wd = 1.0;
+   if (threadIdx.x == 0) {
+       if (counter_halflife > 0) { // decay based on counter_halflife
+           // if id occurs multiple times in a batch, iter_delta=1
+           const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
+           prev_iter[idx] = iter * 1.0;
+           const auto counter_log_rho = logf(2.0) / counter_halflife;
+           row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
+       } else if (counter_halflife == 0) { // count only 1 (appear or not)
+           row_counter[idx] = 1.0;
+       } else { // count raw appearance without decaying
+           row_counter[idx] += 1.0;
        }
+       freq = counter_halflife / row_counter[idx];
    }
    freq = SHFL_SYNC(freq, 0);
-   l2_wd = SHFL_SYNC(l2_wd, 0);
    tail_id_threshold_val = SHFL_SYNC(tail_id_threshold_val, 0);

    at::acc_type<cache_t, true> g_local_sum_square = 0.0;
+   at::acc_type<cache_t, true> w_local_sum_square = 0.0;

    #pragma unroll kMaxVecsPerThread
    for (int32_t i = 0;
            i < kMaxVecsPerThread && 4 * kThreadGroupSize * i + threadIdx.x * 4 < D;
            ++i) {
+       auto gx = grad_sum[i].acc.x;
+       auto gy = grad_sum[i].acc.y;
+       auto gz = grad_sum[i].acc.z;
+       auto gw = grad_sum[i].acc.w;
+
        int32_t d = 4 * kThreadGroupSize * i + threadIdx.x * 4;
        Vec4T<at::acc_type<cache_t, true>> weight = weight_row_template.load(d, qparams_template);
-       auto gx = grad_sum[i].acc.x + l2_wd * freq * weight_decay * weight.acc.x;
-       auto gy = grad_sum[i].acc.y + l2_wd * freq * weight_decay * weight.acc.y;
-       auto gz = grad_sum[i].acc.z + l2_wd * freq * weight_decay * weight.acc.z;
-       auto gw = grad_sum[i].acc.w + l2_wd * freq * weight_decay * weight.acc.w;
+
+       if (weight_decay_mode == 1) {
+           // L2 regularization
+           gx += weight_decay * weight.acc.x;
+           gy += weight_decay * weight.acc.y;
+           gz += weight_decay * weight.acc.z;
+           gw += weight_decay * weight.acc.w;
+       }
        g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
+
+       if (regularization_mode == 4) { // cow_clip requires weight norm
+           w_local_sum_square += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w;
+       }
    }

-   const at::acc_type<cache_t, true> g_avg_square =
-       warpReduceAllSum<at::acc_type<cache_t, true>, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D;
+   const at::acc_type<cache_t, true> g_sum_square =
+       warpReduceAllSum<at::acc_type<cache_t, true>, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask);
+   const at::acc_type<cache_t, true> g_avg_square = g_sum_square / D;
+   const at::acc_type<cache_t, true> w_sum_square =
+       warpReduceAllSum<at::acc_type<cache_t, true>, kThreadGroupSize>(w_local_sum_square, shfl_sync_mask);

-   at::acc_type<cache_t, true> multiplier;
    at::acc_type<cache_t, true> adjusted_multiplier;
    at::acc_type<cache_t, true> exp_reg_correction;

    if (threadIdx.x == 0) {
        at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
        momentum1[idx] = new_sum_square_grads;
-       multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
-
-       adjusted_multiplier = multiplier;
-       if (learning_rate_mode >= 0) {
-           if (adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter)) {
+       const auto multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
+       const auto adjustment_enabled = adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter);

+       if (regularization_mode == 3) {
+           adjusted_multiplier = multiplier;
+           if (learning_rate_mode >= 0 && adjustment_enabled) {
                if (row_counter[idx] > tail_id_threshold_val) {
                    if (learning_rate_mode == 0) {
                        adjusted_multiplier = multiplier * max(min(powf(max_counter/(row_counter[idx] + 1.0), adjustment_ub), 10.0), 1.0);
@@ -816,20 +832,32 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
                    }
                }
            }
+       } else if (regularization_mode == 4) {
+           const auto clip_thresh = row_counter[idx] * max(weight_norm_coefficient * sqrtf(w_sum_square), lower_bound);
+           adjusted_multiplier = min(1.0f, clip_thresh / sqrtf(g_sum_square)) * multiplier;
        }

        exp_reg_correction = 1.0;
-       if (adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter)) {
+       if (regularization_mode == 3) {
+           if (adjustment_enabled) {
+               if (weight_decay_mode == 2) {
+                   // Decoupled weight decay
+                   exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
+               } else if (weight_decay_mode == 1) {
+                   // L2 regularization (coupled wd)
+                   exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
+               }
+           }
+       } else if (regularization_mode == 4) {
            if (weight_decay_mode == 2) {
                // Decoupled weight decay
-               exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
+               exp_reg_correction = 1.0 - weight_decay * learning_rate;
            } else if (weight_decay_mode == 1) {
                // L2 regularization (coupled wd)
-               exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
+               exp_reg_correction = 1.0 - weight_decay * adjusted_multiplier;
            }
        }
    }
-   multiplier = SHFL_SYNC(multiplier, 0);
    adjusted_multiplier = SHFL_SYNC(adjusted_multiplier, 0);
    exp_reg_correction = SHFL_SYNC(exp_reg_correction, 0);
"""
@@ -874,6 +902,9 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
        (FLOAT, "max_counter"),
        (FLOAT, "tail_id_threshold", 0.0),
        (INT, "is_tail_id_thresh_ratio", 0),
+       (INT, "regularization_mode", 0),
+       (FLOAT, "weight_norm_coefficient", 0.0),
+       (FLOAT, "lower_bound", 0.0),
    ]
),
"split_precomputation": split_precomputation,
@@ -891,7 +922,7 @@ def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]:

approx_split_weight_update = """
    // dummy computation to avoid unused variable warning
-   weight_new.fma_(grad, -multiplier);
+   weight_new.fma_(grad, -0.001);
    assert(false); // approx rowwise AdaGrad is not supported on GPU
"""
@@ -915,6 +946,9 @@ def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]:
        (FLOAT, "max_counter"),
        (FLOAT, "tail_id_threshold", 0.0),
        (INT, "is_tail_id_thresh_ratio", 0),
+       (INT, "regularization_mode", 0),
+       (FLOAT, "weight_norm_coefficient", 0.0),
+       (FLOAT, "lower_bound", 0.0),
    ]
),
"split_precomputation": rowwise_adagrad_with_counter_args[