Commit 0297a63

spcyppt authored and facebook-github-bot committed
Enable global weight decay to TBE (Backend) (#2516)
Summary:
Pull Request resolved: #2516
Pull Request resolved: #2498

With the existing implementation for sparse embedding tables with rowwise adagrad, weight decay is applied to a row only when its ID (and the corresponding embedding row) appears within a training batch. Rows that do not show up are neither updated nor decayed, so the embedding table only receives *local* rather than *global* weight decay.

This diff adds an option to compensate for the missed decay by scaling the weight with a `global weight decay` value, using the formula from csmiler below:

```
global_weight_decay = (1 - learning_rate * weight_decay)^(current_iter - prev_iter - 1)
```

where `prev_iter` is the last iteration in which this ID (and its corresponding embedding row) showed up.

---

**Usage:** set

```
optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
weight_decay_mode = WeightDecayMode.DECOUPLE_GLOBAL
```

e.g.,

```
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (E, D, managed_option, ComputeDevice.CUDA) for (E, D) in zip(Es, Ds)
    ],
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    learning_rate=0.1,
    eps=0.1,
    output_dtype=output_dtype,
    pooling_mode=pooling_mode,
    weight_decay_mode=WeightDecayMode.DECOUPLE_GLOBAL,
)
```

Relevant diffs: D53866750, D55660277, D55660762

Differential Revision: D56285676
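As a quick illustration of the formula above, here is a minimal Python sketch (not part of the commit) of how the compensation scale could be computed for a single row. `global_weight_decay_scale` is a hypothetical helper introduced only for this example; the real computation happens inside the generated TBE CUDA kernels.

```python
# Minimal sketch of the global weight decay compensation described above.
# Hypothetical helper for illustration only; the actual logic lives in the
# generated TBE CUDA kernels, not in Python.
def global_weight_decay_scale(
    learning_rate: float,
    weight_decay: float,
    current_iter: int,
    prev_iter: int,
) -> float:
    # Number of iterations this row was absent from, and hence not decayed.
    missed_steps = current_iter - prev_iter - 1
    return (1.0 - learning_rate * weight_decay) ** missed_steps


# A row last touched at iteration 10 and seen again at iteration 14 is scaled
# by (1 - lr * wd)^3 before the regular decoupled update is applied.
scale = global_weight_decay_scale(0.1, 0.01, current_iter=14, prev_iter=10)
```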
1 parent 0aecd17 commit 0297a63

15 files changed: +897 additions, -350 deletions

fbgemm_gpu/FbgemmGpu.cmake

Lines changed: 22 additions & 0 deletions
@@ -88,6 +88,10 @@ set(VBE_OPTIMIZERS
     rowwise_adagrad_with_counter
     sgd)

+# Optimizers with the GWD support
+set(GWD_OPTIMIZERS
+    rowwise_adagrad)
+
 # Individual optimizers (not fused with SplitTBE backward)
 set(DEFUSED_OPTIMIZERS
     rowwise_adagrad)
@@ -136,6 +140,8 @@ set(gen_gpu_kernel_source_files
     "gen_embedding_forward_dense_unweighted_codegen_cuda.cu"
     "gen_embedding_forward_split_weighted_codegen_cuda.cu"
     "gen_embedding_forward_split_unweighted_codegen_cuda.cu"
+    "gen_embedding_forward_split_weighted_gwd_codegen_cuda.cu"
+    "gen_embedding_forward_split_unweighted_gwd_codegen_cuda.cu"
     "gen_embedding_backward_dense_indice_weights_codegen_cuda.cu"
     "gen_embedding_backward_split_indice_weights_codegen_cuda.cu"
     "gen_embedding_forward_split_weighted_vbe_codegen_cuda.cu"
@@ -186,6 +192,12 @@ foreach(wdesc weighted unweighted)
     "gen_embedding_backward_${wdesc}_vbe_split_device_kernel.cuh")
 endforeach()

+# Generate GWD files
+foreach(wdesc weighted unweighted)
+  list(APPEND gen_gpu_kernel_source_files
+    "gen_embedding_forward_split_${wdesc}_gwd_kernel.cu")
+endforeach()
+
 set(gen_cpu_source_files
     "gen_embedding_forward_quantized_unweighted_codegen_cpu.cpp"
     "gen_embedding_forward_quantized_weighted_codegen_cpu.cpp"
@@ -252,6 +264,16 @@ foreach(optimizer ${VBE_OPTIMIZERS})
   endforeach()
 endforeach()

+foreach(optimizer ${GWD_OPTIMIZERS})
+  # GWD is not supported in nobag
+  foreach(wdesc weighted unweighted)
+    list(APPEND gen_gpu_kernel_source_files
+      "gen_embedding_backward_${optimizer}_split_${wdesc}_gwd_cuda.cu"
+      "gen_embedding_backward_${optimizer}_split_${wdesc}_gwd_kernel_cta.cu"
+      "gen_embedding_backward_${optimizer}_split_${wdesc}_gwd_kernel_warp.cu")
+  endforeach()
+endforeach()
+
 foreach(optimizer ${DEFUSED_OPTIMIZERS})
   list(APPEND gen_defused_optim_source_files
     "gen_embedding_optimizer_${optimizer}_split.cpp"

fbgemm_gpu/codegen/genscript/generate_backward_split.py

Lines changed: 29 additions & 2 deletions
@@ -35,14 +35,17 @@ def render_backward_templates(
         optimizer: str,
         filename_format: str,
         kwargs: Dict[str, Any],
+        is_gwd: bool = False,
     ) -> None:
         if not kwargs.get("has_gpu_support"):
             return
-        vbe_options = [True, False] if kwargs.get("has_vbe_support") else [False]
+        vbe_options = (
+            [True, False] if (kwargs.get("has_vbe_support") and not is_gwd) else [False]
+        )
         template = CodeTemplate.load(template_filepath)

         for weighted in [True, False]:
-            for nobag in [True, False]:
+            for nobag in [True, False] if (not is_gwd) else [False]:
                 for vbe in vbe_options:
                     if (not nobag or (not weighted and not vbe)) and (
                         not kwargs.get("dense") or not vbe
@@ -56,6 +59,7 @@ def render_backward_templates(
                             is_index_select=False,
                             kdesc=wdesc,
                             **kwargs,
+                            is_gwd=is_gwd,
                         )

     @staticmethod
@@ -90,6 +94,29 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
                 filename_format,
                 kwargs,
             )
+        # Generate the global weight decay CUDA kernels
+        if kwargs.get("has_global_weight_decay_support"):
+            for template_filepath, filename_format in [
+                (
+                    "training/backward/embedding_backward_split_kernel_cta_template.cu",
+                    "gen_embedding_backward_{}_split_{}_gwd_kernel_cta.cu",
+                ),
+                (
+                    "training/backward/embedding_backward_split_kernel_warp_template.cu",
+                    "gen_embedding_backward_{}_split_{}_gwd_kernel_warp.cu",
+                ),
+                (
+                    "training/backward/embedding_backward_split_template.cu",
+                    "gen_embedding_backward_{}_split_{}_gwd_cuda.cu",
+                ),
+            ]:
+                BackwardSplitGenerator.render_backward_templates(
+                    template_filepath,
+                    optimizer,
+                    filename_format,
+                    kwargs,
+                    is_gwd=True,
+                )

         # Generate optimizer kernel
         CodeTemplate.load(
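A short sketch of which configurations the modified loops actually render when `is_gwd=True`, assuming an optimizer with VBE support such as `rowwise_adagrad`; the `dense` clause from the original condition is omitted because `vbe` is always `False` here, so only the plain weighted and unweighted variants are emitted.

```python
# Sketch: (weighted, nobag, vbe) combinations rendered with is_gwd=True,
# mirroring the loop logic in render_backward_templates above.
is_gwd = True
has_vbe_support = True  # e.g. rowwise_adagrad
vbe_options = [True, False] if (has_vbe_support and not is_gwd) else [False]

configs = []
for weighted in [True, False]:
    for nobag in [True, False] if (not is_gwd) else [False]:
        for vbe in vbe_options:
            # The `dense` clause is dropped here, since vbe is always False
            # for GWD and the clause is then trivially satisfied.
            if not nobag or (not weighted and not vbe):
                configs.append((weighted, nobag, vbe))

assert configs == [(True, False, False), (False, False, False)]
```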

fbgemm_gpu/codegen/genscript/generate_forward_split.py

Lines changed: 20 additions & 0 deletions
@@ -26,6 +26,7 @@ def render_forward_templates(
         dense_options: List[bool],
         nobag_options: List[bool],
         vbe_options: List[bool],
+        is_gwd: bool = False,
     ) -> None:
         template = CodeTemplate.load(template_filepath)
         for dense in dense_options:
@@ -51,6 +52,7 @@ def render_forward_templates(
                             nobag=nobag,
                             vbe=vbe,
                             is_index_select=False,
+                            is_gwd=is_gwd,
                         )

     @staticmethod
@@ -98,6 +100,15 @@ def generate_kernels() -> None:
             nobag_options=[False],  # nobag is not used
             vbe_options=[True, False],
         )
+        # Generate the CUDA host code for global weight decay
+        ForwardSplitGenerator.render_forward_templates(
+            "training/forward/embedding_forward_split_template.cu",
+            "gen_embedding_forward_{}_gwd_codegen_cuda.cu",
+            dense_options=[False],
+            nobag_options=[False],  # nobag is not used
+            vbe_options=[False],
+            is_gwd=True,
+        )

         # Generate the meta kernels
         ForwardSplitGenerator.render_forward_templates(
@@ -116,6 +127,15 @@ def generate_kernels() -> None:
             nobag_options=[True, False],
             vbe_options=[True, False],
         )
+        # Generate the global weight decay CUDA kernels
+        ForwardSplitGenerator.render_forward_templates(
+            "training/forward/embedding_forward_split_kernel_template.cu",
+            "gen_embedding_forward_{}_gwd_kernel.cu",
+            dense_options=[False],
+            nobag_options=[False],
+            vbe_options=[False],
+            is_gwd=True,
+        )

         # Generate the v2 CUDA kernels
         ForwardSplitGenerator.render_forward_templates(
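For orientation, a hedged sketch of the file names these two new render calls should produce; the `split_weighted`/`split_unweighted` descriptors are an assumption inferred from the matching entries added to FbgemmGpu.cmake, not spelled out in this file.

```python
# Sketch (descriptors assumed): expected outputs of the two GWD render calls,
# matching the file names registered in FbgemmGpu.cmake.
host_format = "gen_embedding_forward_{}_gwd_codegen_cuda.cu"
kernel_format = "gen_embedding_forward_{}_gwd_kernel.cu"

for desc in ("split_weighted", "split_unweighted"):
    print(host_format.format(desc))    # e.g. gen_embedding_forward_split_weighted_gwd_codegen_cuda.cu
    print(kernel_format.format(desc))  # e.g. gen_embedding_forward_split_weighted_gwd_kernel.cu
```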

fbgemm_gpu/codegen/genscript/jinja_environment.py

Lines changed: 44 additions & 1 deletion
@@ -298,6 +298,48 @@ def has_experimental_support(
     return not dense and not nobag and not vbe and not is_index_select and not is_rocm


+def is_valid_gwd_config(
+    dense: bool,
+    nobag: bool,
+    vbe: bool,
+    is_index_select: bool,
+    is_rocm: bool,
+    has_global_weight_decay_support: bool = True,
+) -> bool:
+    """
+    Check if the given combination of configs is valid for global weight decay support
+    - `has_global_weight_decay_support` is whether global weight decay is available for
+      an optimizer, but not all configs of such optimizer offer global weight decay support
+    - any updates to the configs need to be reflected in embedding_backward_split_host_template.cpp
+    - global weight decay does not support dense, nobag, vbe, is_index_select, and is_rocm
+    """
+    return (
+        not dense
+        and not nobag
+        and not vbe
+        and not is_index_select
+        and not is_rocm
+        and has_global_weight_decay_support
+    )
+
+
+def compute_global_weight_decay(is_global_weight_decay_kernel: bool) -> str:
+    """
+    For global weight decay kernel, compute the global weight decay value
+    and update prev_iter to be current iteration
+    This is to used in both warp and cta kernels.
+    """
+    if is_global_weight_decay_kernel:
+        return """
+        const auto global_weight_decay = std::pow(weight_decay_base, iter - prev_iter_dev[linear_index] - 1);
+        if (threadIdx.x == 0) {
+            prev_iter_dev[linear_index] = iter;
+        }
+        """
+    else:
+        return ""
+
+
 ################################################################################
 # Register Helper Functions in Jinja Environment
 ################################################################################
@@ -311,7 +353,8 @@ def has_experimental_support(
 env.globals["dispatch_vec_blocking_kernel"] = dispatch_vec_blocking_kernel
 env.globals["is_valid_forward_config"] = is_valid_forward_config
 env.globals["has_experimental_support"] = has_experimental_support
-
+env.globals["is_valid_gwd_config"] = is_valid_gwd_config
+env.globals["compute_global_weight_decay"] = compute_global_weight_decay

 ################################################################################
 # Filter functions in Jinja Environment
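Since the two newly registered helpers are plain Python functions, a small sketch of what they return (derived directly from the definitions in this diff) may help when reading the templates that call them; the import path is an assumption based on the file's location.

```python
# Assumption: run from fbgemm_gpu/codegen/genscript/, where jinja_environment.py lives.
from jinja_environment import compute_global_weight_decay, is_valid_gwd_config

# Only the non-dense, non-nobag, non-VBE, non-index-select, non-ROCm config
# is GWD-eligible, and only for optimizers that declare GWD support.
assert is_valid_gwd_config(
    dense=False, nobag=False, vbe=False, is_index_select=False, is_rocm=False
)
assert not is_valid_gwd_config(
    dense=False, nobag=True, vbe=False, is_index_select=False, is_rocm=False
)

# compute_global_weight_decay(True) returns the CUDA snippet above, which
# scales by pow(weight_decay_base, iter - prev_iter - 1) and stamps
# prev_iter_dev; with False it returns "" so non-GWD kernels are unchanged.
assert compute_global_weight_decay(False) == ""
assert "global_weight_decay" in compute_global_weight_decay(True)
```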

fbgemm_gpu/codegen/genscript/optimizers.py

Lines changed: 19 additions & 2 deletions
@@ -41,6 +41,7 @@ def dense() -> Dict[str, Any]:
         "has_cpu_support": True,
         "has_gpu_support": True,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -84,6 +85,7 @@ def adagrad() -> Dict[str, Any]:
         "has_cpu_support": True,
         "has_gpu_support": True,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -191,7 +193,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
         if (weight_decay_mode == 1) {
             // L2 regularization
             correction = 1.0 - multiplier * weight_decay;
-        } else if (weight_decay_mode == 2) {
+        } else if (weight_decay_mode == 2 || weight_decay_mode == 5) {
             // Decoupled weight decay
             correction = 1.0 - learning_rate * weight_decay;
         } else {
@@ -221,7 +223,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
         if (weight_decay_mode == 1) {
             // L2 regularization
             correction = 1.0 - multiplier * weight_decay;
-        } else if (weight_decay_mode == 2) {
+        } else if (weight_decay_mode == 2 || weight_decay_mode == 5) {
             // Decoupled weight decay
             correction = 1.0 - learning_rate * weight_decay;
         } else {
@@ -252,6 +254,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
         "has_cpu_support": True,
         "has_gpu_support": True,
         "has_vbe_support": True,
+        "has_global_weight_decay_support": True,
     }


@@ -282,6 +285,7 @@ def approx_rowwise_adagrad() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": False,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -387,6 +391,7 @@ def rowwise_adagrad_with_weight_decay() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": False,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -422,6 +427,7 @@ def approx_rowwise_adagrad_with_weight_decay() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": False,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -592,6 +598,7 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": True,
         "has_vbe_support": True,
+        "has_global_weight_decay_support": False,
     }


@@ -640,6 +647,7 @@ def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": False,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -717,6 +725,7 @@ def rowwise_weighted_adagrad() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": False,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -740,6 +749,7 @@ def sgd() -> Dict[str, Any]:
         "has_cpu_support": True,
         "has_gpu_support": True,
         "has_vbe_support": True,
+        "has_global_weight_decay_support": False,
     }


@@ -763,6 +773,7 @@ def approx_sgd() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": False,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -840,6 +851,7 @@ def lamb() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": True,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -931,6 +943,7 @@ def partial_rowwise_lamb() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": True,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -986,6 +999,7 @@ def adam() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": True,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -1060,6 +1074,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": True,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -1124,6 +1139,7 @@ def lars_sgd() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": True,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }


@@ -1141,4 +1157,5 @@ def none_optimizer() -> Dict[str, Any]:
         "has_cpu_support": False,
         "has_gpu_support": True,
         "has_vbe_support": False,
+        "has_global_weight_decay_support": False,
     }
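To connect the two halves, here is a minimal Python rendering of the `correction` logic above. Treat the value 5 for `WeightDecayMode.DECOUPLE_GLOBAL` as an assumption inferred from this diff and the summary; the local update is identical to ordinary decoupled decay, while the *global* compensation for skipped iterations is handled by the separate GWD forward/backward kernels.

```python
# Sketch of the generated C++ correction term, transcribed to Python.
# weight_decay_mode 5 is assumed to be WeightDecayMode.DECOUPLE_GLOBAL.
def correction(weight_decay_mode: int, multiplier: float,
               learning_rate: float, weight_decay: float) -> float:
    if weight_decay_mode == 1:
        # L2 regularization
        return 1.0 - multiplier * weight_decay
    elif weight_decay_mode in (2, 5):
        # Decoupled weight decay; for mode 5 the global compensation for
        # skipped iterations is applied elsewhere, in the GWD kernels.
        return 1.0 - learning_rate * weight_decay
    else:
        # No decay (the else branch is not shown in the hunk above; assumed).
        return 1.0
```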
