Skip to content

Commit fd0fe4f

Browse files
Wang Zhou and facebook-github-bot
authored and committed
FBGEMM kernel codegen
Summary: FBGEMM kernel implementation for CowClip optimizer (https://arxiv.org/pdf/2204.06240.pdf). It is based on counter-sgd to reuse the counter state. {F1183660363} Differential Revision: D52268946
1 parent a535f22 commit fd0fe4f

File tree

1 file changed

+64
-30
lines changed

1 file changed

+64
-30
lines changed

fbgemm_gpu/codegen/embedding_common_code_generator.py

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -753,59 +753,75 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
753753
"""
754754
split_precomputation = """
755755
at::acc_type<cache_t, true> freq = 1.0;
756-
at::acc_type<cache_t, true> l2_wd = 0.0;
757756
at::acc_type<cache_t, true> tail_id_threshold_val = tail_id_threshold;
758-
CUDA_KERNEL_ASSERT(max_counter > 0.0); // avoid divide by zero error
757+
CUDA_KERNEL_ASSERT(max_counter != 0.0); // avoid divide by zero error
759758
if (is_tail_id_thresh_ratio == 1){
760759
tail_id_threshold_val = floorf(tail_id_threshold * max_counter);
761760
}
762-
if (counter_halflife > 0 && threadIdx.x == 0) {
763-
// if id occurs multiple times in a batch, iter_delta=1
764-
const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
765-
prev_iter[idx] = iter * 1.0;
766-
const auto counter_log_rho = logf(2.0) / counter_halflife;
767-
row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
768-
freq = counter_halflife / row_counter[idx];
769-
if (weight_decay_mode == 1) {
770-
// L2 regularization
771-
l2_wd = 1.0;
761+
if (threadIdx.x == 0) {
762+
if (counter_halflife > 0) { // decay based on counter_halflife
763+
// if id occurs multiple times in a batch, iter_delta=1
764+
const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
765+
prev_iter[idx] = iter * 1.0;
766+
const auto counter_log_rho = logf(2.0) / counter_halflife;
767+
row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx];
768+
} else if (counter_halflife == 0) { // count only 1 (appear or not)
769+
row_counter[idx] = 1.0;
770+
} else { // count raw appearance without decaying
771+
row_counter[idx] += 1.0;
772772
}
773+
freq = counter_halflife / row_counter[idx];
773774
}
774775
freq = SHFL_SYNC(freq, 0);
775-
l2_wd = SHFL_SYNC(l2_wd, 0);
776776
tail_id_threshold_val = SHFL_SYNC(tail_id_threshold_val, 0);
777777
778778
at::acc_type<cache_t, true> g_local_sum_square = 0.0;
779+
at::acc_type<cache_t, true> w_local_sum_square = 0.0;
779780
780781
#pragma unroll kMaxVecsPerThread
781782
for (int32_t i = 0;
782783
i < kMaxVecsPerThread && 4 * kThreadGroupSize * i + threadIdx.x * 4 < D;
783784
++i) {
785+
auto gx = grad_sum[i].acc.x;
786+
auto gy = grad_sum[i].acc.y;
787+
auto gz = grad_sum[i].acc.z;
788+
auto gw = grad_sum[i].acc.w;
789+
784790
int32_t d = 4 * kThreadGroupSize * i + threadIdx.x * 4;
785791
Vec4T<at::acc_type<cache_t, true>> weight = weight_row_template.load(d, qparams_template);
786-
auto gx = grad_sum[i].acc.x + l2_wd * freq * weight_decay * weight.acc.x;
787-
auto gy = grad_sum[i].acc.y + l2_wd * freq * weight_decay * weight.acc.y;
788-
auto gz = grad_sum[i].acc.z + l2_wd * freq * weight_decay * weight.acc.z;
789-
auto gw = grad_sum[i].acc.w + l2_wd * freq * weight_decay * weight.acc.w;
792+
793+
if (weight_decay_mode == 1) {
794+
// L2 regularization
795+
gx += weight_decay * weight.acc.x;
796+
gy += weight_decay * weight.acc.y;
797+
gz += weight_decay * weight.acc.z;
798+
gw += weight_decay * weight.acc.w;
799+
}
790800
g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
801+
802+
if (regularization_mode == 4) { // cow_clip requires weight norm
803+
w_local_sum_square += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w;
804+
}
791805
}
792806
793-
const at::acc_type<cache_t, true> g_avg_square =
794-
warpReduceAllSum<at::acc_type<cache_t, true>, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D;
807+
const at::acc_type<cache_t, true> g_sum_square =
808+
warpReduceAllSum<at::acc_type<cache_t, true>, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask);
809+
const at::acc_type<cache_t, true> g_avg_square = g_sum_square / D;
810+
const at::acc_type<cache_t, true> w_sum_square =
811+
warpReduceAllSum<at::acc_type<cache_t, true>, kThreadGroupSize>(w_local_sum_square, shfl_sync_mask);
795812
796-
at::acc_type<cache_t, true> multiplier;
797813
at::acc_type<cache_t, true> adjusted_multiplier;
798814
at::acc_type<cache_t, true> exp_reg_correction;
799815
800816
if (threadIdx.x == 0) {
801817
at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
802818
momentum1[idx] = new_sum_square_grads;
803-
multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
804-
805-
adjusted_multiplier = multiplier;
806-
if ( learning_rate_mode >=0 ) {
807-
if (adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter)) {
819+
const auto multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
820+
const auto adjustment_enabled = adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter);
808821
822+
if (regularization_mode == 3) {
823+
adjusted_multiplier = multiplier;
824+
if ( learning_rate_mode >=0 && adjustment_enabled) {
809825
if (row_counter[idx] > tail_id_threshold_val) {
810826
if ( learning_rate_mode == 0 ) {
811827
adjusted_multiplier = multiplier * max(min(powf(max_counter/(row_counter[idx] + 1.0), adjustment_ub), 10.0), 1.0);
@@ -816,20 +832,32 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
816832
}
817833
}
818834
}
835+
} else if (regularization_mode == 4) {
836+
const auto clip_thresh = row_counter[idx] * max(weight_norm_coefficient * sqrtf(w_sum_square), lower_bound);
837+
adjusted_multiplier = min(1.0f, clip_thresh / sqrtf(g_sum_square)) * multiplier;
819838
}
820839
821840
exp_reg_correction = 1.0;
822-
if (adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter)) {
841+
if (regularization_mode == 3) {
842+
if (adjustment_enabled) {
843+
if (weight_decay_mode == 2) {
844+
// Decoupled weight decay
845+
exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
846+
} else if (weight_decay_mode == 1) {
847+
// L2 regularization (coupled wd)
848+
exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
849+
}
850+
}
851+
} else if (regularization_mode == 4) {
823852
if (weight_decay_mode == 2) {
824853
// Decoupled weight decay
825-
exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
854+
exp_reg_correction = 1.0 - weight_decay * learning_rate;
826855
} else if (weight_decay_mode == 1) {
827856
// L2 regularization (coupled wd)
828-
exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
857+
exp_reg_correction = 1.0 - weight_decay * adjusted_multiplier;
829858
}
830859
}
831860
}
832-
multiplier = SHFL_SYNC(multiplier, 0);
833861
adjusted_multiplier = SHFL_SYNC(adjusted_multiplier, 0);
834862
exp_reg_correction = SHFL_SYNC(exp_reg_correction, 0);
835863
"""
@@ -874,6 +902,9 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
874902
(FLOAT, "max_counter"),
875903
(FLOAT, "tail_id_threshold", 0.0),
876904
(INT, "is_tail_id_thresh_ratio", 0),
905+
(INT, "regularization_mode", 0),
906+
(FLOAT, "weight_norm_coefficient", 0.0),
907+
(FLOAT, "lower_bound", 0.0),
877908
]
878909
),
879910
"split_precomputation": split_precomputation,
@@ -891,7 +922,7 @@ def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]:
891922

892923
approx_split_weight_update = """
893924
// dummy computation to avoid unused variable warning
894-
weight_new.fma_(grad, -multiplier);
925+
weight_new.fma_(grad, -0.001);
895926
assert(false); // approx rowwise AdaGrad is not supported on GPU
896927
"""
897928

@@ -915,6 +946,9 @@ def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]:
915946
(FLOAT, "max_counter"),
916947
(FLOAT, "tail_id_threshold", 0.0),
917948
(INT, "is_tail_id_thresh_ratio", 0),
949+
(INT, "regularization_mode", 0),
950+
(FLOAT, "weight_norm_coefficient", 0.0),
951+
(FLOAT, "lower_bound", 0.0),
918952
]
919953
),
920954
"split_precomputation": rowwise_adagrad_with_counter_args[

0 commit comments

Comments
 (0)