Enable global weight decay to TBE (Backend) (#2498)
Summary:
In the existing implementation of sparse embedding tables with rowwise adagrad, weight decay is applied only when an ID (and its corresponding embedding row) appears within a training batch. Rows that do not show up are neither updated nor decayed, so the embedding table only gets *local* but not *global* weight decay.
This diff adds an option to compensate for the skipped decay by scaling the weights with a `global weight decay` value, using the formula from csmiler below:
```
global_weight_decay = (1 - learning_rate * weight_decay)^(current_iter - prev_iter - 1)
```
where `prev_iter` is the last iteration at which this ID (and its corresponding embedding row) showed up.
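As a quick sanity check on this formula, the following standalone Python sketch (with hypothetical values; not part of this diff) confirms that applying the single compensation factor when a row reappears is equivalent to applying decoupled weight decay at each of the skipped iterations:
```python
# Hypothetical values for illustration only.
learning_rate = 0.1
weight_decay = 0.01
prev_iter = 100      # last iteration at which the ID appeared
current_iter = 105   # iteration at which the ID appears again

# Closed form from the summary: one factor covers all skipped iterations.
global_weight_decay = (1 - learning_rate * weight_decay) ** (current_iter - prev_iter - 1)

# Step-by-step equivalent: decay once per skipped iteration.
w = 1.0
for _ in range(prev_iter + 1, current_iter):
    w *= 1 - learning_rate * weight_decay

assert abs(w - global_weight_decay) < 1e-12  # both give ~0.996006
```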
---
**Usage:**
Set
```
optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
weight_decay_mode = WeightDecayMode.DECOUPLE_GLOBAL
```
e.g.,
```
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (E, D, managed_option, ComputeDevice.CUDA) for (E, D) in zip(Es, Ds)
    ],
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    learning_rate=0.1,
    eps=0.1,
    output_dtype=output_dtype,
    pooling_mode=pooling_mode,
    weight_decay_mode=WeightDecayMode.DECOUPLE_GLOBAL,
)
```
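For reference, the identifiers used above would typically come from imports along these lines (module paths as in recent `fbgemm_gpu` releases; verify against your installed version), with `Es`, `Ds`, `managed_option`, `output_dtype`, and `pooling_mode` supplied by the caller:
```python
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
    WeightDecayMode,
)
```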
Relevant diffs:
D53866750
D55660277
D55660762
Differential Revision: D56285676