
Commit 006a74e

spcyppt authored and facebook-github-bot committed
Enable global weight decay to TBE (Backend) (#2498)
Summary:

With the existing implementation for sparse embedding tables with rowwise adagrad, weight decay is applied only when an ID and its corresponding embedding row appear within a training batch. Rows that do not show up are neither updated nor decayed, so the embedding table gets only *local* but not *global* weight decay.

This diff provides an option to compensate for the missed decay by scaling the weights with a `global weight decay` value, using the formula from csmiler below:

```
global_weight_decay = (1 - learning_rate * weight_decay)^(current_iter - prev_iter - 1)
```

where `prev_iter` is the last iteration in which this ID (and its corresponding embedding row) showed up.

---

**Usage:** set

```
optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
weight_decay_mode = WeightDecayMode.DECOUPLE_GLOBAL
```

e.g.,

```python
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (E, D, managed_option, ComputeDevice.CUDA)
        for (E, D) in zip(Es, Ds)
    ],
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    learning_rate=0.1,
    eps=0.1,
    output_dtype=output_dtype,
    pooling_mode=pooling_mode,
    weight_decay_mode=WeightDecayMode.DECOUPLE_GLOBAL,
)
```

Relevant diffs: D53866750, D55660277, D55660762

Differential Revision: D56285676
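To make the compensation concrete, here is a minimal sketch of the scaling a reappearing row receives; the learning rate, weight decay, and iteration numbers are illustrative, not taken from the diff:

```python
# Sketch of the compensation applied when a row reappears.
# learning_rate, weight_decay, and iteration values are illustrative.
learning_rate = 0.1
weight_decay = 0.01

weight_decay_base = 1 - learning_rate * weight_decay

prev_iter = 10      # last iteration in which the row's ID appeared
current_iter = 14   # the row reappears here, having missed 3 updates

# Equivalent to applying decoupled weight decay once per missed iteration.
global_weight_decay = weight_decay_base ** (current_iter - prev_iter - 1)
assert global_weight_decay == weight_decay_base ** 3
# The row's weights are scaled by this factor before the regular update.
```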
1 parent bb5706a · commit 006a74e

15 files changed: +539 −32 lines

fbgemm_gpu/codegen/genscript/generate_backward_split.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -35,10 +35,13 @@ def render_backward_templates(
         optimizer: str,
         filename_format: str,
         kwargs: Dict[str, Any],
+        is_gwd: bool = False,
     ) -> None:
         if not kwargs.get("has_gpu_support"):
             return
         vbe_options = [True, False] if kwargs.get("has_vbe_support") else [False]
+        if is_gwd:
+            vbe_options = [False]
         template = CodeTemplate.load(template_filepath)
 
         for weighted in [True, False]:
@@ -56,6 +59,7 @@ def render_backward_templates(
                 is_index_select=False,
                 kdesc=wdesc,
                 **kwargs,
+                is_gwd=is_gwd,
             )
 
     @staticmethod
@@ -90,6 +94,25 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
             filename_format,
             kwargs,
         )
+        # Generate the backward split kernels
+        if kwargs.get("has_global_weight_decay_support"):
+            for template_filepath, filename_format in [
+                (
+                    "training/backward/embedding_backward_split_kernel_cta_template.cu",
+                    "gen_embedding_backward_{}_split_{}_gwd_kernel_cta.cu",
+                ),
+                (
+                    "training/backward/embedding_backward_split_kernel_warp_template.cu",
+                    "gen_embedding_backward_{}_split_{}_gwd_kernel_warp.cu",
+                ),
+            ]:
+                BackwardSplitGenerator.render_backward_templates(
+                    template_filepath,
+                    optimizer,
+                    filename_format,
+                    kwargs,
+                    is_gwd=True,
+                )
 
         # Generate optimizer kernel
         CodeTemplate.load(
```
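The `{}` placeholders in the filename formats are filled in per optimizer and per weighting descriptor by the surrounding codegen loop; a hypothetical expansion (the optimizer and descriptor names below are chosen for illustration) looks like:

```python
# Hypothetical expansion of the gwd kernel filename format; the actual
# descriptors come from the codegen loop over optimizers and weighted modes.
filename_format = "gen_embedding_backward_{}_split_{}_gwd_kernel_cta.cu"
print(filename_format.format("rowwise_adagrad", "weighted"))
# -> gen_embedding_backward_rowwise_adagrad_split_weighted_gwd_kernel_cta.cu
```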

fbgemm_gpu/codegen/genscript/generate_forward_split.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -26,6 +26,7 @@ def render_forward_templates(
         dense_options: List[bool],
         nobag_options: List[bool],
         vbe_options: List[bool],
+        is_gwd: bool = False,
     ) -> None:
         template = CodeTemplate.load(template_filepath)
         for dense in dense_options:
@@ -51,6 +52,7 @@ def render_forward_templates(
                         nobag=nobag,
                         vbe=vbe,
                         is_index_select=False,
+                        is_gwd=is_gwd,
                     )
 
     @staticmethod
@@ -116,6 +118,14 @@ def generate_kernels() -> None:
             nobag_options=[True, False],
             vbe_options=[True, False],
         )
+        ForwardSplitGenerator.render_forward_templates(
+            "training/forward/embedding_forward_split_kernel_template.cu",
+            "gen_embedding_forward_{}_gwd_kernel.cu",
+            dense_options=[False],
+            nobag_options=[False],
+            vbe_options=[False],
+            is_gwd=True,
+        )
 
         # Generate the v2 CUDA kernels
         ForwardSplitGenerator.render_forward_templates(
```
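Note that every option list passed here is a single-element `[False]`, so the cross-product of forward variants collapses to exactly one gwd kernel per configuration; a small sketch of that collapse:

```python
# The option lists form a cross-product of rendered variants; for gwd each
# list is a single [False], so exactly one combination is rendered.
from itertools import product

dense_options, nobag_options, vbe_options = [False], [False], [False]
combos = list(product(dense_options, nobag_options, vbe_options))
print(combos)  # [(False, False, False)] -- one kernel variant per optimizer
```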

fbgemm_gpu/codegen/genscript/jinja_environment.py

Lines changed: 91 additions & 1 deletion
```diff
@@ -294,6 +294,92 @@ def has_experimental_support(
     return not dense and not nobag and not vbe and not is_index_select and not is_rocm
 
 
+def is_valid_forward_gwd_config(
+    dense: bool, nobag: bool, vbe: bool, is_index_select: bool, is_rocm: bool
+) -> bool:
+    """
+    Check if the given combination of configs for `forward` has global weight decay support
+    - global weight decay does not support dense, nobag, vbe, is_index_select, and is_rocm
+    """
+    return not dense and not nobag and not vbe and not is_index_select and not is_rocm
+
+
+def is_valid_backward_gwd_config(
+    dense: bool,
+    nobag: bool,
+    vbe: bool,
+    is_index_select: bool,
+    is_rocm: bool,
+    has_global_weight_decay_support: bool,
+) -> bool:
+    """
+    Check if the given combination of configs for `backward` has global weight decay support
+    - `has_global_weight_decay_support` is whether global weight decay is available for
+      an optimizer, but not all configs of such an optimizer offer global weight decay support
+    - global weight decay does not support dense, nobag, vbe, is_index_select, and is_rocm
+    """
+    return (
+        not dense
+        and not nobag
+        and not vbe
+        and not is_index_select
+        and not is_rocm
+        and has_global_weight_decay_support
+    )
+
+
+def update_with_global_weight_decay(has_global_weight_decay_support: bool) -> str:
+    """
+    Update weights with the global weight decay value if has_global_weight_decay_support
+    """
+    if has_global_weight_decay_support:
+        return """
+        weight_new.mul_(global_weight_decay);
+        """
+    else:
+        return ""
+
+
+def compute_global_weight_decay(is_global_weight_decay_kernel: bool) -> str:
+    """
+    For the global weight decay kernel, compute the global weight decay value
+    and update prev_iter to the current iteration.
+    This is used in both warp and cta kernels.
+    """
+    if is_global_weight_decay_kernel:
+        return """
+        const auto global_weight_decay = std::pow(weight_decay_base, iter - prev_iter_dev_gwd[linear_index] - 1);
+        if (threadIdx.x == 0) {
+            prev_iter_dev_gwd[linear_index] = iter;
+        }
+        """
+    else:
+        return ""
+
+
+def pass_gwd_to_update_table(
+    is_global_weight_decay_kernel: bool,
+    has_global_weight_decay_support: bool,
+) -> str:
+    """
+    Pass the correct parameter to the update_table_kernel:
+    - pass global_weight_decay when enabled and computed
+    - pass 1.0 when not enabled
+    - pass nothing for other configs
+    This is used in both warp and cta kernels.
+    """
+    if is_global_weight_decay_kernel:
+        # global weight decay is enabled
+        return "global_weight_decay,"
+    elif has_global_weight_decay_support:
+        # The table update kernel is per optimizer (e.g., rowwise adagrad),
+        # but not all configs have gwd enabled or gwd support
+        # (e.g., other modes, nobag kernels).
+        return "1.0,"
+    else:
+        return ""
+
+
 ################################################################################
 # Register Helper Functions in Jinja Environment
 ################################################################################
@@ -307,7 +393,11 @@ def has_experimental_support(
 env.globals["dispatch_vec_blocking_kernel"] = dispatch_vec_blocking_kernel
 env.globals["is_valid_forward_config"] = is_valid_forward_config
 env.globals["has_experimental_support"] = has_experimental_support
-
+env.globals["is_valid_forward_gwd_config"] = is_valid_forward_gwd_config
+env.globals["is_valid_backward_gwd_config"] = is_valid_backward_gwd_config
+env.globals["compute_global_weight_decay"] = compute_global_weight_decay
+env.globals["update_with_global_weight_decay"] = update_with_global_weight_decay
+env.globals["pass_gwd_to_update_table"] = pass_gwd_to_update_table
 
 ################################################################################
 # Filter functions in Jinja Environment
```

fbgemm_gpu/codegen/genscript/optimizers.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -197,7 +197,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
         if (weight_decay_mode == 1) {
             // L2 regularization
             correction = 1.0 - multiplier * weight_decay;
-        } else if (weight_decay_mode == 2) {
+        } else if (weight_decay_mode == 2 || weight_decay_mode == 5) {
             // Decoupled weight decay
             correction = 1.0 - learning_rate * weight_decay;
         } else {
@@ -227,7 +227,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
         if (weight_decay_mode == 1) {
             // L2 regularization
             correction = 1.0 - multiplier * weight_decay;
-        } else if (weight_decay_mode == 2) {
+        } else if (weight_decay_mode == 2 || weight_decay_mode == 5) {
             // Decoupled weight decay
             correction = 1.0 - learning_rate * weight_decay;
         } else {
@@ -258,6 +258,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
         "has_cpu_support": True,
         "has_gpu_support": True,
         "has_vbe_support": True,
+        "has_global_weight_decay_support": True,
     }
```
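For reference, a Python mirror of the correction selection above; the mapping of `weight_decay_mode == 5` to `WeightDecayMode.DECOUPLE_GLOBAL` and the `else` branch leaving `correction` at 1.0 are assumptions consistent with this diff, not shown in it:

```python
# Python mirror of the rowwise adagrad correction selection (illustrative).
# Assumption: mode 1 = L2, mode 2 = decoupled, mode 5 = DECOUPLE_GLOBAL,
# and the else branch applies no decay (correction = 1.0).
def correction(weight_decay_mode: int, multiplier: float,
               learning_rate: float, weight_decay: float) -> float:
    if weight_decay_mode == 1:
        # L2 regularization
        return 1.0 - multiplier * weight_decay
    elif weight_decay_mode in (2, 5):
        # Decoupled weight decay (mode 5 adds global compensation elsewhere)
        return 1.0 - learning_rate * weight_decay
    else:
        return 1.0

# In mode 5 the per-step correction is unchanged; the global compensation
# (correction ** missed_steps) is applied separately when a row reappears.
print(correction(5, multiplier=0.5, learning_rate=0.1, weight_decay=0.01))
```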
